From 53132054b71bffc7e81d7491d79b620c3546b5ce Mon Sep 17 00:00:00 2001 From: Estrella-xx Date: Fri, 27 Mar 2026 17:42:04 +0800 Subject: [PATCH 01/42] [NPU] Support GLM-4.7-Flash on NPU (#153) --- .../npu/attention/ascend_backend.py | 354 +++++++++++++----- .../srt/layers/rotary_embedding/base.py | 9 + 2 files changed, 266 insertions(+), 97 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 93b44de927b7..e6e7dd5ccbad 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -21,6 +21,7 @@ ) from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp +from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_info import SpecInput @@ -213,6 +214,7 @@ def __init__(self, model_runner: ModelRunner): self.forward_metadata = None self.device = model_runner.device self.page_size = model_runner.page_size + self.model_dtype = model_runner.model_config.dtype self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA if self.use_mla: self.kv_lora_rank = model_runner.model_config.kv_lora_rank @@ -261,6 +263,18 @@ def __init__(self, model_runner: ModelRunner): model_runner.token_to_kv_pool.full_to_swa_index_mapping ) + # head num padding + self.padding_size_list = [1, 2, 4, 8, 16, 32, 64, 128] + self.q_head_num_padding = None + if hasattr(model_runner.model_config, "num_attention_heads") and self.use_mla: + self.tp_q_head_num = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + for num in self.padding_size_list: + if num >= self.tp_q_head_num: + 
self.q_head_num_padding = num + break + # dllm model config self.dllm_config = DllmConfig.from_server_args(model_runner.server_args) self.is_dllm_model = False @@ -435,6 +449,35 @@ def init_forward_metadata_capture_cuda_graph( torch.cumsum(extend_seq_lens_cpu_int, dim=0).int().tolist() ) + if ( + self.q_head_num_padding is not None + and self.q_head_num_padding > self.tp_q_head_num + ): + metadata.nope_padding = torch.empty( + [ + bs, + 1, + self.q_head_num_padding - self.tp_q_head_num, + self.kv_lora_rank, + ], + dtype=( + self.model_dtype if self.model_dtype is not None else torch.bfloat16 + ), + device=seq_lens.device, + ) + metadata.rope_padding = torch.empty( + [ + bs, + 1, + self.q_head_num_padding - self.tp_q_head_num, + self.qk_rope_head_dim, + ], + dtype=( + self.model_dtype if self.model_dtype is not None else torch.bfloat16 + ), + device=seq_lens.device, + ) + self.graph_metadata[bs] = metadata self.forward_metadata = metadata @@ -989,110 +1032,209 @@ def forward_extend( -1, layer.tp_q_head_num * layer.v_head_dim ) elif sum(forward_batch.extend_prefix_lens_cpu) > 0: - num_token_padding = q.shape[0] - q, k, v = [ - data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v] - ] - q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1) - k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1) - - # 1st, compute extend tokens to get attn_output and attn_lse - num_tokens = q_nope.size(0) - attn_output = torch.zeros( - num_tokens, - layer.tp_q_head_num, - layer.v_head_dim, - dtype=q_nope.dtype, - device=q_nope.device, - ) - attn_lse = torch.zeros( - layer.tp_q_head_num, - num_tokens, - dtype=torch.float32, - device=q_nope.device, - ) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_rope, - k_nope=k_nope, - k_rope=k_rope, - value=v, - mask=self.ringmla_mask, - seqlen=self.forward_metadata.extend_seq_lens_cpu_int, - head_num=layer.tp_q_head_num, - kv_head_num=layer.tp_k_head_num, - pre_out=None, - 
prev_lse=None, - qk_scale=layer.scaling, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse, - ) + if layer.qk_head_dim == layer.v_head_dim: + q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) - # 2nd, load history kvcache(kv_a and k_pe) and calculate k_nope - k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - v_buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) - kv_cached = torch.index_select( - k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables - ) - k_rope_cached = torch.index_select( - v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables - ).flatten(0, 1) + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_buffer = forward_batch.token_to_kv_pool.get_value_buffer( + layer.layer_id + ) + kv_cached = torch.index_select( + k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables + ) + k_rope_cached = torch.index_select( + v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables + ).flatten(0, 1) - assert layer.kv_b_proj is not None - kv = layer.kv_b_proj(kv_cached)[0].view( - -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim - ) - k_nope, v = kv.split([self.qk_nope_head_dim, layer.v_head_dim], dim=-1) + assert layer.kv_b_proj is not None + kv = layer.kv_b_proj(kv_cached)[0].view( + -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim + ) + k_nope, v_pre = kv.split( + [self.qk_nope_head_dim, layer.v_head_dim], dim=-1 + ) - # 3rd, compute history kv to attn_out - k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1) - seq_len = torch.stack( - [ + k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1) + k_pre = torch.cat([k_nope, k_rope], dim=-1) + + attn_output = torch.empty( + (q.size(0), layer.tp_q_head_num, layer.v_head_dim), + device=q.device, + dtype=q.dtype, + ) + q_len_offset = 0 + prefix_len_offset = 0 + 
for q_len, prefix_len in zip( self.forward_metadata.extend_seq_lens_cpu_int, self.forward_metadata.prefix_lens, + ): + k_cur_slice = k[None, q_len_offset : q_len_offset + q_len] + v_cur_slice = v[None, q_len_offset : q_len_offset + q_len] + k_pre_slice = k_pre[ + None, prefix_len_offset : prefix_len_offset + prefix_len + ] + v_pre_slice = v_pre[ + None, prefix_len_offset : prefix_len_offset + prefix_len + ] + + k_full = torch.cat([k_pre_slice, k_cur_slice], dim=1) + v_full = torch.cat([v_pre_slice, v_cur_slice], dim=1) + + attn_output[q_len_offset : q_len_offset + q_len] = ( + torch.ops.npu.npu_fused_infer_attention_score( + q[None, q_len_offset : q_len_offset + q_len], + k_full, + v_full, + num_heads=layer.tp_q_head_num, + num_key_value_heads=layer.tp_k_head_num, + input_layout="BSND", # todo, TND not supports q_heads!=k_heads + atten_mask=self.fia_mask, + sparse_mode=3, + scale=layer.scaling, + next_tokens=0, + )[0] + ) + q_len_offset += q_len + prefix_len_offset += prefix_len + attn_output = attn_output.view( + -1, layer.tp_q_head_num * layer.v_head_dim + ) + else: + num_token_padding = q.shape[0] + q, k, v = [ + data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v] ] - ) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_rope, - k_nope=k_nope, - k_rope=k_rope, - value=v, - mask=self.ringmla_mask, - seqlen=seq_len, - head_num=layer.tp_q_head_num, - kv_head_num=layer.tp_k_head_num, - pre_out=attn_output, - prev_lse=attn_lse, - qk_scale=layer.scaling, - kernel_type="kernel_type_high_precision", - mask_type="no_mask", - calc_type="calc_type_default", - output=attn_output, - softmax_lse=attn_lse, - ) - attn_output = attn_output.reshape( - [-1, layer.tp_q_head_num, layer.v_head_dim] - ) - if num_token_padding != forward_batch.num_token_non_padded_cpu: - attn_output = torch.cat( + q_nope, q_rope = q.split( + [layer.v_head_dim, self.qk_rope_head_dim], dim=-1 + ) + k_nope, k_rope = k.split( + [layer.v_head_dim, self.qk_rope_head_dim], dim=-1 + ) + 
+ # 1st, compute extend tokens to get attn_output and attn_lse + num_tokens = q_nope.size(0) + attn_output = torch.zeros( + num_tokens, + layer.tp_q_head_num, + layer.v_head_dim, + dtype=q_nope.dtype, + device=q_nope.device, + ) + attn_lse = torch.zeros( + layer.tp_q_head_num, + num_tokens, + dtype=torch.float32, + device=q_nope.device, + ) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_rope, + k_nope=k_nope, + k_rope=k_rope, + value=v, + mask=self.ringmla_mask, + seqlen=self.forward_metadata.extend_seq_lens_cpu_int, + head_num=layer.tp_q_head_num, + kv_head_num=layer.tp_k_head_num, + pre_out=None, + prev_lse=None, + qk_scale=layer.scaling, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse, + ) + + # 2nd, load history kvcache(kv_a and k_pe) and calculate k_nope + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_buffer = forward_batch.token_to_kv_pool.get_value_buffer( + layer.layer_id + ) + kv_cached = torch.index_select( + k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables + ) + k_rope_cached = torch.index_select( + v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables + ).flatten(0, 1) + + assert layer.kv_b_proj is not None + kv = layer.kv_b_proj(kv_cached)[0].view( + -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim + ) + k_nope, v = kv.split([self.qk_nope_head_dim, layer.v_head_dim], dim=-1) + + # 3rd, compute history kv to attn_out + k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1) + seq_len = torch.stack( [ - attn_output, - attn_output.new_zeros( - num_token_padding - attn_output.shape[0], - *attn_output.shape[1:], - ), - ], - dim=0, + self.forward_metadata.extend_seq_lens_cpu_int, + self.forward_metadata.prefix_lens, + ] ) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_rope, + k_nope=k_nope, + k_rope=k_rope, + value=v, + mask=self.ringmla_mask, + 
seqlen=seq_len, + head_num=layer.tp_q_head_num, + kv_head_num=layer.tp_k_head_num, + pre_out=attn_output, + prev_lse=attn_lse, + qk_scale=layer.scaling, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + calc_type="calc_type_default", + output=attn_output, + softmax_lse=attn_lse, + ) + attn_output = attn_output.reshape( + [-1, layer.tp_q_head_num, layer.v_head_dim] + ) + if num_token_padding != forward_batch.num_token_non_padded_cpu: + attn_output = torch.cat( + [ + attn_output, + attn_output.new_zeros( + num_token_padding - attn_output.shape[0], + *attn_output.shape[1:], + ), + ], + dim=0, + ) else: - assert ( - layer.qk_head_dim != layer.v_head_dim - ), "FIA only supports qk_head_dim != v_head_dim" - if layer.v_head_dim in [256]: + if layer.qk_head_dim == layer.v_head_dim: + """FIA will support multi-bs in the later version of CANN""" + q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) + attn_output = torch.empty( + (q.size(0), layer.tp_q_head_num, layer.v_head_dim), + device=q.device, + dtype=q.dtype, + ) + q_len_offset = 0 + for q_len in forward_batch.extend_seq_lens_cpu: + attn_output[q_len_offset : q_len_offset + q_len] = ( + torch.ops.npu.npu_fused_infer_attention_score( + q[None, q_len_offset : q_len_offset + q_len], + k[None, q_len_offset : q_len_offset + q_len], + v[None, q_len_offset : q_len_offset + q_len], + num_heads=layer.tp_q_head_num, + num_key_value_heads=layer.tp_k_head_num, + input_layout="BSND", # todo, TND not supports q_heads!=k_heads + atten_mask=self.fia_mask.unsqueeze(0), + sparse_mode=3 if q_len != 1 else 0, + scale=layer.scaling, + next_tokens=0, + )[0] + ) + q_len_offset += q_len + attn_output = attn_output.view( + -1, layer.tp_q_head_num * layer.v_head_dim + ) + elif layer.v_head_dim in [256]: """Currently, in NO_QUANT situation, qk_nope_head_dim == v_head_dim, and rope exists, v_head_dim only support 512 and 128""" kv_lora_rank = k.shape[-1] - self.qk_rope_head_dim kv_c, k_rope = k.split([kv_lora_rank, 
self.qk_rope_head_dim], dim=-1) @@ -1525,6 +1667,22 @@ def forward_decode_graph( q_nope = q.view(-1, 1, layer.tp_q_head_num, self.kv_lora_rank).contiguous() q_rope = q_rope.view(-1, 1, layer.tp_q_head_num, self.qk_rope_head_dim) + assert ( + self.q_head_num_padding is None + or self.q_head_num_padding >= layer.tp_q_head_num + ) + + if ( + self.q_head_num_padding is not None + and self.q_head_num_padding > layer.tp_q_head_num + ): + q_nope = torch.cat( + [q_nope, self.forward_metadata.nope_padding], dim=2 + ).contiguous() + q_rope = torch.cat( + [q_rope, self.forward_metadata.rope_padding], dim=2 + ).contiguous() + if self.forward_metadata.seq_lens_cpu_int is None: actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list else: @@ -1538,7 +1696,7 @@ def forward_decode_graph( c_kv_cache, query_rope=q_rope, key_rope=k_rope_cache, - num_heads=layer.tp_q_head_num, + num_heads=self.q_head_num_padding, num_key_value_heads=layer.tp_k_head_num, block_table=self.forward_metadata.block_tables, block_size=self.page_size, @@ -1558,7 +1716,7 @@ def forward_decode_graph( c_kv_cache, query_rope=q_rope, key_rope=k_rope_cache, - num_heads=layer.tp_q_head_num, + num_heads=self.q_head_num_padding, num_key_value_heads=layer.tp_k_head_num, block_table=self.forward_metadata.block_tables, block_size=self.page_size, @@ -1571,6 +1729,8 @@ def forward_decode_graph( workspace=workspace, out=[output, softmax_lse], ) + + output = output[:, :, : layer.tp_q_head_num, :] return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank) def forward_decode( diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index 943fe8558f4f..1d17a1ca7f90 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -242,7 +242,13 @@ def forward_npu( rotary_mode = "half" else: rotary_mode = "interleave" + mrope_section = [0, 0, 0] + query_shape = query.shape + key_shape = key.shape + query = 
query.reshape(query.shape[0], -1) + key = key.reshape(key.shape[0], -1) + query_out, key_out = torch_npu.npu_mrope( positions, query, @@ -252,6 +258,9 @@ def forward_npu( mrope_section=mrope_section, rotary_mode=rotary_mode, ) + + query_out = query_out.reshape(query_shape) + key_out = key_out.reshape(key_shape) return query_out, key_out def forward_cpu( From 70301c531fc488acc3b8d8982d2a2e59a4b3906f Mon Sep 17 00:00:00 2001 From: McZyWu Date: Sat, 28 Mar 2026 16:34:49 +0800 Subject: [PATCH 02/42] [NPU] recover accuracy for gemma3-4b-it from 54% to 72% (reduced by transformer5.3) (#155) * [NPU] recover accuracy for gemma3-4b-it for transformer5.3 --- python/sglang/srt/models/gemma3_causal.py | 24 ++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 0481cae0eeba..db243ea3cd88 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -42,9 +42,16 @@ default_weight_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.utils import add_prefix, cpu_has_amx_support, is_cpu, make_layers +from sglang.srt.utils import ( + add_prefix, + cpu_has_amx_support, + is_cpu, + is_npu, + make_layers, +) _is_cpu = is_cpu() +_is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() @@ -573,10 +580,17 @@ def __init__( local_theta = getattr(config, "rope_local_base_freq", 10000.0) global_config = copy.deepcopy(config) - global_config.rope_parameters = { - "rope_type": "default", - "rope_theta": global_theta, - } + if not _is_npu: + global_config.rope_parameters = { + "rope_type": "default", + "rope_theta": global_theta, + } + else: + global_config.rope_parameters = { + "rope_theta": global_theta, + "factor": 8, + "rope_type": "linear", + } self.rotary_emb = Gemma3RotaryEmbedding(config=global_config) self.gradient_checkpointing = False From 822138b6d67d2b576da647ca46318896147b34cf Mon Sep 17 00:00:00 2001 From: 
amote-i <49533125+amote-i@users.noreply.github.com> Date: Sat, 28 Mar 2026 16:55:59 +0800 Subject: [PATCH 03/42] Br fix qwen2 5 ascend (#159) * fix qwen2_5_math_rm_72b --- python/sglang/srt/models/qwen2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 9574186e1caa..85a510fc7854 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -576,7 +576,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if name == "model.embed_tokens.weight": - if self.pp_group.is_last_rank and self.config.tie_word_embeddings: + if ( + not hasattr(self, "pp_group") or self.pp_group.is_last_rank + ) and self.config.tie_word_embeddings: if "lm_head.weight" in params_dict: param = params_dict["lm_head.weight"] weight_loader = getattr( From 2720785bae06084d92215abf22fd29eac77aef17 Mon Sep 17 00:00:00 2001 From: Estrella-xx Date: Sat, 28 Mar 2026 17:55:08 +0800 Subject: [PATCH 04/42] add documentation for GLM-4.7-Flash on Ascend (#162) * add documentation for GLM-4.7-Flash * update --- .../ascend_npu_glm4_7_flash_examples.md | 177 ++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 docs/platforms/ascend_npu_glm4_7_flash_examples.md diff --git a/docs/platforms/ascend_npu_glm4_7_flash_examples.md b/docs/platforms/ascend_npu_glm4_7_flash_examples.md new file mode 100644 index 000000000000..8604efb78d47 --- /dev/null +++ b/docs/platforms/ascend_npu_glm4_7_flash_examples.md @@ -0,0 +1,177 @@ +# GLM-4.7-Flash examples + +## Environment Preparation + +### Model Weight + +- `GLM-4.7-Flash`(BF16 version): [Download model weight](https://www.modelscope.cn/models/ZhipuAI/GLM-4.7-Flash). + +### Installation + +The dependencies required for the NPU runtime environment have been integrated into a Docker image and uploaded to the quay.io platform. You can directly pull it. 
+ +```bash +#Atlas 800 A3 +docker pull quay.io/ascend/sglang:main-cann8.5.0-a3 +#Atlas 800 A2 +docker pull quay.io/ascend/sglang:main-cann8.5.0-910b + +#start container +docker run -itd --shm-size=16g --privileged=true --name ${NAME} \ +--privileged=true --net=host \ +-v /var/queue_schedule:/var/queue_schedule \ +-v /etc/ascend_install.info:/etc/ascend_install.info \ +-v /usr/local/sbin:/usr/local/sbin \ +-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ +-v /usr/local/Ascend/firmware:/usr/local/Ascend/firmware \ +--device=/dev/davinci0:/dev/davinci0 \ +--device=/dev/davinci1:/dev/davinci1 \ +--device=/dev/davinci2:/dev/davinci2 \ +--device=/dev/davinci3:/dev/davinci3 \ +--device=/dev/davinci4:/dev/davinci4 \ +--device=/dev/davinci5:/dev/davinci5 \ +--device=/dev/davinci6:/dev/davinci6 \ +--device=/dev/davinci7:/dev/davinci7 \ +--device=/dev/davinci8:/dev/davinci8 \ +--device=/dev/davinci9:/dev/davinci9 \ +--device=/dev/davinci10:/dev/davinci10 \ +--device=/dev/davinci11:/dev/davinci11 \ +--device=/dev/davinci12:/dev/davinci12 \ +--device=/dev/davinci13:/dev/davinci13 \ +--device=/dev/davinci14:/dev/davinci14 \ +--device=/dev/davinci15:/dev/davinci15 \ +--device=/dev/davinci_manager:/dev/davinci_manager \ +--device=/dev/hisi_hdc:/dev/hisi_hdc \ +--entrypoint=bash \ +quay.io/ascend/sglang:${tag} +``` + +Note: When using this image, you need to update Transformers to version 5.3.0. + +``` shell +# reinstall transformers +pip install transformers==5.3.0 +``` + +## Running GLM-4.7-Flash + +### Running GLM-4.7-Flash on 1 x Atlas 800I A3. + +Run the following script to execute online inference. 

```shell
# high performance cpu
echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
sysctl -w vm.swappiness=0
sysctl -w kernel.numa_balancing=0
sysctl -w kernel.sched_migration_cost_ns=50000
# bind cpu
export SGLANG_SET_CPU_AFFINITY=1

unset https_proxy
unset http_proxy
unset HTTPS_PROXY
unset HTTP_PROXY
unset ASCEND_LAUNCH_BLOCKING
# cann
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export STREAMS_PER_DEVICE=32
export HCCL_BUFFSIZE=1000
export HCCL_OP_EXPANSION_MODE=AIV
export HCCL_SOCKET_IFNAME=lo
export GLOO_SOCKET_IFNAME=lo

python3 -m sglang.launch_server \
    --model-path $MODEL_PATH \
    --tp-size 2 \
    --attention-backend ascend \
    --device npu \
    --chunked-prefill-size 16384 \
    --max-prefill-tokens 150000 \
    --dtype bfloat16 \
    --max-running-requests 32 \
    --trust-remote-code \
    --host 127.0.0.1 \
    --mem-fraction-static 0.75 \
    --port 8000 \
    --cuda-graph-bs 1 2 4 8 16 32 \
    --watchdog-timeout 9000
```

Note: TP size is currently limited to 2 or 4.

### Running GLM-4.7-Flash on 1 x Atlas 800I A3 in slime-ascend.

#### Preparation

- [slime-ascend](https://gitcode.com/Ascend/slime-ascend) code

#### Installation

Run the following commands to install sglang. (Please replace `<slime_root>` with the path to the root directory of the slime codebase.)

```bash
git clone -b v0.5.8 https://github.com/sgl-project/sglang.git
cd sglang
mv python/pyproject_other.toml python/pyproject.toml
pip install -e python[srt_npu]
git checkout . && git checkout sglang-slime
git am <slime_root>/docker/npu_patch/v0.2.2/sglang/*
```

Note: Make sure you are using Transformers 5.3.0.

#### Execution

Run the following script to execute online **inference**. 
+ +```shell +# high performance cpu +echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor +sysctl -w vm.swappiness=0 +sysctl -w kernel.numa_balancing=0 +sysctl -w kernel.sched_migration_cost_ns=50000 +# bind cpu +export SGLANG_SET_CPU_AFFINITY=1 + +unset https_proxy +unset http_proxy +unset HTTPS_PROXY +unset HTTP_PROXY +unset ASCEND_LAUNCH_BLOCKING +# cann +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/nnal/atb/set_env.sh + +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export STREAMS_PER_DEVICE=32 +export HCCL_BUFFSIZE=1000 +export HCCL_OP_EXPANSION_MODE=AIV +export HCCL_SOCKET_IFNAME=lo +export GLOO_SOCKET_IFNAME=lo + +python3 -m sglang.launch_server \ + --model-path $MODEL_PATH \ + --tp-size 2 \ + --attention-backend ascend \ + --device npu \ + --chunked-prefill-size 16384 \ + --max-prefill-tokens 150000 \ + --dtype bfloat16 \ + --max-running-requests 32 \ + --trust-remote-code \ + --host 127.0.0.1 \ + --mem-fraction-static 0.75 \ + --port 8000 \ + --cuda-graph-bs 1 2 4 8 16 32 \ + --watchdog-timeout 9000 +``` + +Refer to [Training and Deployment Example](https://gitcode.com/Ascend/slime-ascend/blob/main/docs/ascend_tutorial/examples/glm4.7-30B-A3B.md) for training and deployment. + +### Using Benchmark + +Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md) for details. 
From 6c54ed020e24204555417429d90c196f8f368a08 Mon Sep 17 00:00:00 2001 From: longxin9715 <59550463+longxin9715@users.noreply.github.com> Date: Sat, 28 Mar 2026 18:02:35 +0800 Subject: [PATCH 05/42] Revert "Use LazyValue for routed_experts_weights_of_layer initialization" (#161) Co-authored-by: j30065060 --- python/sglang/srt/models/qwen3_moe.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 010a73074759..1373099fdf5a 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -1182,15 +1182,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): logger.warning(f"Parameter {name} not found in params_dict") if not hasattr(self, "routed_experts_weights_of_layer"): - self.routed_experts_weights_of_layer = LazyValue( - lambda: { - layer_id: self.model.layers[layer_id].mlp.get_moe_weights() - for layer_id in range(self.start_layer, self.end_layer) - if isinstance( - self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock - ) - } - ) + self.routed_experts_weights_of_layer = { + layer_id: self.model.layers[layer_id].mlp.get_moe_weights() + for layer_id in range(self.start_layer, self.end_layer) + if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock) + } @classmethod def get_model_config_for_expert_location(cls, config): From 2b186af02602eecc093d6167de1ae22d85847c55 Mon Sep 17 00:00:00 2001 From: jianzhao-xu <978716854@qq.com> Date: Mon, 30 Mar 2026 14:18:46 +0800 Subject: [PATCH 06/42] fix(grok): fallback to standard weight loading when no presharded files found (#156) Co-authored-by: Jianzhao Xu --- python/sglang/srt/models/grok.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 3cc6a48e79f9..8a6a5f0d5c88 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -981,6 +981,9 
@@ def _prepare_presharded_weights( for pattern in allow_patterns: hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if not hf_weights_files: + return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt) + if hf_weights_files[0].endswith("safetensors"): use_safetensors = True else: From 34f3a2cd35f76a5421f60994ac0bf92acd6b669d Mon Sep 17 00:00:00 2001 From: iridiumine <42236072+iridiumine@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:42:11 +0800 Subject: [PATCH 07/42] revert: revert qwen3_5.py, use separate layers (#172) --- python/sglang/srt/models/qwen3_5.py | 350 ++++++---------------------- 1 file changed, 65 insertions(+), 285 deletions(-) diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index 45b55fa3bd09..107b73378c72 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -20,11 +20,6 @@ import torch import torch.nn as nn -import triton - -from sglang.jit_kernel.triton.gdn_fused_proj import ( - fused_qkvzba_split_reshape_cat_contiguous, -) # Configs from sglang.srt.configs.qwen3_5 import ( @@ -59,10 +54,6 @@ RowParallelLinear, ) from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.parameter import ( - BlockQuantScaleParameter, - PerTensorScaleParameter, -) from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_linear_attention import RadixLinearAttention @@ -79,14 +70,11 @@ # Models from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration -from sglang.srt.server_args import get_global_server_args # Utils from sglang.srt.utils import ( LazyValue, add_prefix, - cpu_has_amx_support, - is_cpu, is_cuda, is_npu, make_layers, @@ -97,9 +85,6 @@ logger = logging.getLogger(__name__) _is_cuda = is_cuda() _is_npu = is_npu() -_is_cpu = is_cpu() -_is_amx_available = cpu_has_amx_support() - 
cached_get_processor = lru_cache(get_processor) @@ -144,47 +129,63 @@ def __init__( ) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - # projection of the input hidden states - self.in_proj_qkvz = self.create_qkvz_proj( - hidden_size=self.hidden_size, - key_dim=self.key_dim, - value_dim=self.value_dim, + # Split projection layers (following vLLM's implementation) + # Instead of fused in_proj_qkvz and in_proj_ba, use separate layers + self.in_proj_qkv = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.key_dim, self.key_dim, self.value_dim], + bias=False, quant_config=quant_config, - prefix=add_prefix("in_proj_qkvz", prefix), tp_rank=self.attn_tp_rank, tp_size=self.attn_tp_size, + prefix=add_prefix("in_proj_qkv", prefix), ) - - self.in_proj_ba = self.create_ba_proj( - hidden_size=self.hidden_size, - num_v_heads=self.num_v_heads, + self.in_proj_z = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.value_dim, + bias=False, quant_config=quant_config, - prefix=add_prefix("in_proj_ba", prefix), tp_rank=self.attn_tp_rank, tp_size=self.attn_tp_size, + prefix=add_prefix("in_proj_z", prefix), + ) + self.in_proj_b = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + quant_config=quant_config, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + prefix=add_prefix("in_proj_b", prefix), + ) + self.in_proj_a = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + quant_config=quant_config, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + prefix=add_prefix("in_proj_a", prefix), ) - - # Override weight loaders for packed checkpoint format. - # Important: for FP8, this must cover not only `.weight` but also - # `weight_scale_inv` / `weight_scale` / `input_scale` if present. 
- self._bind_packed_weight_loaders(self.in_proj_qkvz) - self._bind_packed_weight_loaders(self.in_proj_ba) # Conv1d weight loader setup query_key_settings = (self.key_dim, 0, False) value_settings = (self.value_dim, 0, False) - self._override_weight_loader( + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( self.conv1d.weight, - mamba_v2_sharded_weight_loader( - [ - query_key_settings, - query_key_settings, - value_settings, - ], - self.attn_tp_size, - self.attn_tp_rank, - ), + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + query_key_settings, + query_key_settings, + value_settings, + ], + self.attn_tp_size, + self.attn_tp_rank, + ) + }, ) # State parameters @@ -201,6 +202,7 @@ def __init__( conv_weights = self.conv1d.weight.view( self.conv1d.weight.size(0), self.conv1d.weight.size(2) ) + # RadixLinearAttention layer self.attn = RadixLinearAttention( layer_id=layer_id, num_q_heads=self.num_k_heads // self.attn_tp_size, @@ -216,6 +218,7 @@ def __init__( dt_bias=self.dt_bias, ) + # Normalization layer self.norm = RMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, @@ -225,6 +228,7 @@ def __init__( dtype=config.torch_dtype, ) + # Output projection self.out_proj = RowParallelLinear( self.value_dim, self.hidden_size, @@ -237,190 +241,16 @@ def __init__( prefix=add_prefix("out_proj", prefix), ) - @staticmethod - def _override_weight_loader(param, loader): - """Robustly override loader for: - 1) BasevLLMParameter subclasses: real storage is `_weight_loader` - 2) regular Parameters that already have mutable `weight_loader` - 3) regular Parameters without `weight_loader` yet - """ - if hasattr(param, "_weight_loader"): - # FP8 / quantized BasevLLMParameter path - param._weight_loader = loader - return - - if hasattr(param, "weight_loader"): - # Regular parameter/tensor that already has a mutable attr. - # Do NOT call set_weight_attrs here, because it asserts when - # overwriting an existing attribute. 
- param.weight_loader = loader - return - - # Fresh attribute on a normal tensor/Parameter - set_weight_attrs(param, {"weight_loader": loader}) - - def _bind_packed_weight_loaders(self, module): - """Bind packed-checkpoint-aware loaders to all relevant params of a merged module.""" - for attr_name in ("weight", "weight_scale_inv", "weight_scale", "input_scale"): - param = getattr(module, attr_name, None) - if param is None: - continue - original_loader = getattr(param, "weight_loader", None) - if original_loader is None: - continue - wrapped_loader = self._make_packed_weight_loader(module, original_loader) - self._override_weight_loader(param, wrapped_loader) - - @staticmethod - def _get_split_sizes_for_param(module, param, loaded_shard_id): - """Return checkpoint-side split sizes for this param type.""" - if isinstance(param, BlockQuantScaleParameter): - # Split by output blocks, not raw output sizes. - block_n, _ = module.quant_method.quant_config.weight_block_size - block_n = 1 if getattr(param, "format_ue8m0", False) else block_n - return [ - (module.output_sizes[idx] + block_n - 1) // block_n - for idx in loaded_shard_id - ] - - if isinstance(param, PerTensorScaleParameter): - # One logical scale per logical shard. - return [1 for _ in loaded_shard_id] - - # Normal weight / non-block quant tensor - return [module.output_sizes[idx] for idx in loaded_shard_id] - - @classmethod - def _make_packed_weight_loader(cls, module, original_weight_loader): - """Wrap the param's original loader so split checkpoints: - - in_proj_qkv + in_proj_z -> merged in_proj_qkvz - - in_proj_b + in_proj_a -> merged in_proj_ba - can load correctly for both normal and FP8 params. - """ - - def weight_loader(param, loaded_weight, loaded_shard_id=None): - # Only intercept split-checkpoint tuple shards. - # int shard_id and None should preserve original behavior. 
- if isinstance(loaded_shard_id, tuple): - split_sizes = cls._get_split_sizes_for_param( - module, param, loaded_shard_id - ) - - if len(loaded_weight.shape) == 0: - # Scalar only makes sense for a single logical shard. - assert len(split_sizes) == 1 and split_sizes[0] == 1, ( - f"Unexpected scalar for tuple shard load: " - f"{loaded_shard_id=}, {split_sizes=}" - ) - chunks = [loaded_weight.reshape(1)] - else: - split_dim = getattr(param, "output_dim", 0) - chunks = loaded_weight.split(split_sizes, dim=split_dim) - - assert len(chunks) == len(loaded_shard_id), ( - f"Chunk/shard mismatch: {len(chunks)=}, " - f"{len(loaded_shard_id)=}, {split_sizes=}" - ) - - for idx, chunk in zip(loaded_shard_id, chunks): - # Delegate each chunk to the param's original int-shard loader. - original_weight_loader(param, chunk, idx) - return - - return original_weight_loader(param, loaded_weight, loaded_shard_id) - - return weight_loader - - def create_qkvz_proj( - self, - hidden_size: int, - key_dim: int, - value_dim: int, - quant_config: QuantizationConfig | None, - prefix: str, - tp_rank: Optional[int] = None, - tp_size: Optional[int] = None, - ) -> MergedColumnParallelLinear: - return MergedColumnParallelLinear( - input_size=hidden_size, - output_sizes=[key_dim, key_dim, value_dim, value_dim], - bias=False, - quant_config=quant_config, - prefix=prefix, - tp_rank=tp_rank, - tp_size=tp_size, - ) - - def create_ba_proj( - self, - hidden_size: int, - num_v_heads: int, - quant_config: QuantizationConfig | None, - prefix: str, - tp_rank: Optional[int] = None, - tp_size: Optional[int] = None, - ) -> MergedColumnParallelLinear: - # Qwen3.5 has separate in_proj_b and in_proj_a weights in the - # checkpoint, which are loaded into the fused in_proj_ba parameter - # via stacked_params_mapping with shard_id 0 and 1 respectively. 
- return MergedColumnParallelLinear( - input_size=hidden_size, - output_sizes=[num_v_heads, num_v_heads], - bias=False, - quant_config=quant_config, - prefix=prefix, - tp_rank=tp_rank, - tp_size=tp_size, - ) - def fix_query_key_value_ordering( self, - mixed_qkvz: torch.Tensor, - mixed_ba: torch.Tensor, + mixed_qkv, + z, + b, + a, ): - """ - Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. - """ - k_tp = self.key_dim // self.attn_tp_size - v_tp = self.value_dim // self.attn_tp_size - nv_tp = self.num_v_heads // self.attn_tp_size - - # Directly split, no head group reshape - query, key, value, z = mixed_qkvz.split([k_tp, k_tp, v_tp, v_tp], dim=-1) - b, a = mixed_ba.split([nv_tp, nv_tp], dim=-1) - - # value / z reshape to (seq, num_v_heads/tp, head_v_dim) - value = value.reshape(value.size(0), -1, self.head_v_dim) - z = z.reshape(z.size(0), -1, self.head_v_dim) - - return query, key, value, z, b, a - - def _forward_input_proj(self, hidden_states: torch.Tensor): - if ( - _is_cpu - or _is_npu - or not get_global_server_args().disable_piecewise_cuda_graph - ): - DUAL_STREAM_TOKEN_THRESHOLD = 0 - else: - DUAL_STREAM_TOKEN_THRESHOLD = 1024 - - seq_len, _ = hidden_states.shape - if ( - self.alt_stream is not None - and get_is_capture_mode() - and seq_len < DUAL_STREAM_TOKEN_THRESHOLD - ): - current_stream = torch.cuda.current_stream() - self.alt_stream.wait_stream(current_stream) - projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - with torch.cuda.stream(self.alt_stream): - projected_states_ba, _ = self.in_proj_ba(hidden_states) - current_stream.wait_stream(self.alt_stream) - else: - projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - projected_states_ba, _ = self.in_proj_ba(hidden_states) - return projected_states_qkvz, projected_states_ba + raise NotImplementedError( + "Qwen3.5 Series dont need to fix query key value ordering" + ) def forward( self, @@ -433,60 +263,30 @@ def forward( 2. Core attention (custom op) 3. 
Output projection """ - projected_states_qkvz, projected_states_ba = self._forward_input_proj( - hidden_states - ) + seq_len, _ = hidden_states.shape + + mixed_qkv, _ = self.in_proj_qkv(hidden_states) + z, _ = self.in_proj_z(hidden_states) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b, _ = self.in_proj_b(hidden_states) + a, _ = self.in_proj_a(hidden_states) + + b = b.contiguous() + a = a.contiguous() - if self.num_v_heads // self.num_k_heads in [1, 2, 4] and not _is_cpu: - mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat_contiguous( - projected_states_qkvz, - projected_states_ba, - triton.cdiv(self.num_k_heads, self.attn_tp_size), - triton.cdiv(self.num_v_heads, self.attn_tp_size), - self.head_k_dim, - self.head_v_dim, - ) - elif _is_cpu and _is_amx_available: - mixed_qkv, z, b, a = ( - torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu( - projected_states_qkvz, - projected_states_ba, - self.num_k_heads // self.attn_tp_size, - self.num_v_heads // self.attn_tp_size, - self.head_k_dim, - self.head_v_dim, - ) - ) - else: - query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba - ) - query, key, value = map( - lambda x: x.reshape(x.shape[0], -1), (query, key, value) - ) - mixed_qkv = torch.cat((query, key, value), dim=-1) core_attn_out = self.attn( - forward_batch, + forward_batch=forward_batch, mixed_qkv=mixed_qkv, a=a, b=b, ) z_shape_og = z.shape - # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) - - # Add padding for DP-Attn - if core_attn_out.shape != z.shape: - core_attn_out_pad = torch.zeros_like(z) - core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out - core_attn_out = core_attn_out_pad - core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) - + core_attn_out = core_attn_out.flatten(-2) # ... 
h d -> ... (h d) output, _ = self.out_proj(core_attn_out) return output @@ -1018,11 +818,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), - # GDN - ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), - ("in_proj_qkvz.", "in_proj_z.", 3), - ("in_proj_ba.", "in_proj_b.", 0), - ("in_proj_ba.", "in_proj_a.", 1), ] loaded_params: Set[str] = set() @@ -1099,11 +894,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), - # GDN - ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), - ("in_proj_qkvz.", "in_proj_z.", 3), - ("in_proj_ba.", "in_proj_b.", 0), - ("in_proj_ba.", "in_proj_a.", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales @@ -1337,11 +1127,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), - # GDN fused projections - ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), - ("in_proj_qkvz.", "in_proj_z.", 3), - ("in_proj_ba.", "in_proj_b.", 0), - ("in_proj_ba.", "in_proj_a.", 1), ] loaded_params: Set[str] = set() @@ -1438,11 +1223,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), - # GDN fused projections - ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), - ("in_proj_qkvz.", "in_proj_z.", 3), - ("in_proj_ba.", "in_proj_b.", 0), - ("in_proj_ba.", "in_proj_a.", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales From 8df71a849a9e51fedc05c42046a43b702a68cb13 Mon Sep 17 00:00:00 2001 From: jianzhao-xu <978716854@qq.com> Date: Mon, 30 Mar 2026 17:11:51 +0800 Subject: [PATCH 08/42] [bugfix]GLM-4V model (#176) Co-authored-by: Jianzhao Xu --- python/sglang/srt/models/glm4v.py | 2 ++ 
python/sglang/srt/multimodal/processors/base_processor.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/glm4v.py b/python/sglang/srt/models/glm4v.py index 7cfa1e71c1d7..e2ff4da96d31 100644 --- a/python/sglang/srt/models/glm4v.py +++ b/python/sglang/srt/models/glm4v.py @@ -414,6 +414,7 @@ def __init__( num_heads=self.num_heads, quant_config=quant_config, prefix=add_prefix(f"blocks.{layer_idx}", prefix), + num_dummy_heads=vision_config.num_dummy_heads, rms_norm_eps=vision_config.rms_norm_eps, attn_qkv_bias=vision_config.attention_bias, use_data_parallel=use_data_parallel, @@ -553,6 +554,7 @@ def __init__( self.pp_group = get_pp_group() self.config = config self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder + vision_utils.update_vit_attn_dummy_heads_config(self.config) self.visual = Glm4vVisionModel( config.vision_config, quant_config=quant_config, diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 859c225be13d..9ce169024570 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -407,7 +407,9 @@ def process_mm_data( kwargs["device"] = "xpu" elif not _is_npu: kwargs["device"] = "cuda" - else: + elif processor.__class__.__name__ not in { + "Glm4vProcessor", + }: # Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend. 
from sglang.srt.hardware_backend.npu.modules.qwen_vl_processor import ( npu_apply_qwen_image_preprocess_patch, From 6d57e7c78ff779f6a9209013b9aca3f2c64ff0d7 Mon Sep 17 00:00:00 2001 From: chx96642264 Date: Mon, 30 Mar 2026 17:29:18 +0800 Subject: [PATCH 09/42] NPU can use piece cuda graph when the piece cuda graph is explicitly declared (#169) Co-authored-by: chx96642264 --- python/sglang/srt/server_args.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c770f3d161f4..e7db9bf9f579 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -629,6 +629,7 @@ class ServerArgs: enable_single_batch_overlap: bool = False tbo_token_distribution_threshold: float = 0.48 enable_torch_compile: bool = False + enable_piecewise_cuda_graph: bool = False disable_piecewise_cuda_graph: bool = False enforce_piecewise_cuda_graph: bool = False enable_torch_compile_debug_mode: bool = False @@ -1117,6 +1118,10 @@ def _handle_piecewise_cuda_graph(self): if self.enable_eplb or self.expert_distribution_recorder_mode is not None: self.disable_piecewise_cuda_graph = True + # NPU can use this function when the piece cuda graph is explicitly declared + if self.enable_piecewise_cuda_graph: + self.disable_piecewise_cuda_graph = False + def _handle_gpu_memory_settings(self, gpu_mem): """ Configure GPU memory-dependent settings including @@ -5391,8 +5396,8 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser.add_argument( "--enable-piecewise-cuda-graph", - action=DeprecatedAction, - help="Deprecated: Piecewise cuda graph is enabled by default. 
Use --enforce-piecewise-cuda-graph to skip auto-disable conditions.", + action="store_true", + help="Optimize the model with piecewise cuda graph for extend/prefill only.", ) parser.add_argument( "--enforce-piecewise-cuda-graph", From dbcfc313f76feb0ae9fed1da2a33aa455f4910fe Mon Sep 17 00:00:00 2001 From: khalilzhk Date: Mon, 30 Mar 2026 18:57:34 +0800 Subject: [PATCH 10/42] Bug fix for llama eagle3 (#177) --- python/sglang/srt/models/llama.py | 9 +++++++-- python/sglang/srt/models/llama_eagle3.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index d8810c508c48..f955ac750d34 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -252,8 +252,13 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = config.rope_parameters["rope_theta"] - rope_scaling = config.rope_parameters + rope_parameters = getattr(config, "rope_parameters", None) + if rope_parameters is not None: + rope_theta = rope_parameters.get("rope_theta", 10000) + rope_scaling = rope_parameters + else: + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py index a4022dba0c95..e9a383ddcfd7 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -111,13 +111,17 @@ def __init__( super().__init__() self.config = config - rope_scaling = config.rope_parameters + rope_parameters = getattr(config, "rope_parameters", None) + if rope_parameters is not None: + rope_scaling = rope_parameters + else: + rope_scaling = getattr(config, "rope_scaling", None) self.is_mrope_enabled = ( rope_scaling is not None and "mrope_section" in rope_scaling ) # fix 
rope_scaling for qwen2.5-vl if self.is_mrope_enabled: - config.rope_parameters["rope_type"] = "default" + rope_scaling["rope_type"] = "default" self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( From 30aa7925e317b0a1a004cb576dd6d06415c27656 Mon Sep 17 00:00:00 2001 From: heziiop <1624120705@qq.com> Date: Mon, 30 Mar 2026 20:28:37 +0800 Subject: [PATCH 11/42] fix eagle3 accept rate (#179) --- .../npu/attention/ascend_backend.py | 26 ++++++++++++++++--- .../eagle_draft_npu_graph_runner.py | 18 ++++++++----- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index e6e7dd5ccbad..1ac339d0bf22 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -209,10 +209,14 @@ def get_splitfuse_attn_mask( class AscendAttnBackend(AttentionBackend): - def __init__(self, model_runner: ModelRunner): + def __init__(self, model_runner: ModelRunner, speculative_step_id: int = 0): super().__init__() self.forward_metadata = None self.device = model_runner.device + self.speculative_step_id = speculative_step_id + self.speculative_step_offset_npu = torch.tensor( + speculative_step_id + 1, device="npu" + ) self.page_size = model_runner.page_size self.model_dtype = model_runner.model_config.dtype self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA @@ -301,6 +305,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens_max = forward_batch.seq_lens.max() if forward_batch.forward_mode.is_target_verify(): seq_lens_max += self.speculative_num_draft_tokens + elif ( + forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.spec_info is not None + ): + seq_lens_max += self.speculative_step_id + 1 self.forward_metadata.block_tables = ( 
forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, :seq_lens_max @@ -343,6 +352,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_target_verify(): self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens + elif ( + forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.spec_info is not None + ): + self.forward_metadata.seq_lens_cpu_int += self.speculative_step_id + 1 if ( self.use_mla @@ -498,6 +512,8 @@ def init_forward_metadata_replay_cuda_graph( max_len = seq_lens_cpu[:bs].max().item() if forward_mode.is_target_verify(): max_len += self.speculative_num_draft_tokens + elif forward_mode.is_decode_or_idle() and spec_info is not None: + max_len += self.speculative_step_id + 1 max_seq_pages = (max_len + self.page_size - 1) // self.page_size if self.is_hybrid_swa: @@ -519,6 +535,8 @@ def init_forward_metadata_replay_cuda_graph( if forward_mode.is_target_verify(): seq_lens = seq_lens + self.speculative_num_draft_tokens + elif forward_mode.is_decode_or_idle() and spec_info is not None: + seq_lens = seq_lens + self.speculative_step_offset_npu metadata.seq_lens[:bs].copy_(seq_lens[:bs]) self.forward_metadata = metadata @@ -2062,8 +2080,10 @@ def __init__( self.speculative_num_steps = speculative_num_steps self.attn_backends = [] - for _ in range(self.speculative_num_steps): - self.attn_backends.append(AscendAttnBackend(model_runner)) + for step_id in range(self.speculative_num_steps): + self.attn_backends.append( + AscendAttnBackend(model_runner, speculative_step_id=step_id) + ) def common_template(self, forward_batch: ForwardBatch, call_fn: int): assert forward_batch.spec_info is not None diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py index 4ffc5fdd2d57..77c5d4f2405a 100644 --- 
a/python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py @@ -83,22 +83,28 @@ def _get_update_attr_name(self): def _get_update_attr_type(self): return self.attr_type[AttentionArch.MLA] - def _replay_update(self, seq_lens): + def _replay_update(self, seq_lens_list): if isinstance(self.update_attr_type, torch.Tensor): - seq_lens = torch.from_numpy(np.array(seq_lens).astype(np.int32)) + seq_lens = torch.from_numpy(np.array(seq_lens_list).astype(np.int32)) self.graphs[self.bs].update( - cpu_update_input=[{self.update_attr_name: seq_lens}] + cpu_update_input=[ + {self.update_attr_name: seq_lens} for seq_lens in seq_lens_list + ] ) def _replay(self, forward_batch: ForwardBatch): self.update_attr_name = self._get_update_attr_name() self.update_attr_type = self._get_update_attr_type() if not is_deepseek_nsa(self.model_runner.model_config.hf_config): - seq_lens = forward_batch.seq_lens_cpu.tolist() + [0] * ( - self.bs - self.raw_bs + seq_lens_for_each_draft_step = [] + for speculative_step_id in range(self.speculative_num_steps - 1): + seq_lens_cpu = forward_batch.seq_lens_cpu + speculative_step_id + 1 + seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs) + seq_lens_for_each_draft_step.append(seq_lens) + thread = threading.Thread( + target=self._replay_update, args=(seq_lens_for_each_draft_step,) ) - thread = threading.Thread(target=self._replay_update, args=(seq_lens,)) thread.start() self.graphs[self.bs].replay() thread.join() From 841f4cdd34be8875e900eba37aa0265547c73255 Mon Sep 17 00:00:00 2001 From: iridiumine <42236072+iridiumine@users.noreply.github.com> Date: Mon, 30 Mar 2026 20:56:06 +0800 Subject: [PATCH 12/42] Support MTP for Qwen3.5 (#154) * feat: support Ascend NPU MTP adaptation for GDN attention backend --- .../npu/attention/ascend_gdn_backend.py | 423 ++++++++++++++++++ .../hardware_backend/npu/memory_pool_npu.py | 24 + 
.../layers/attention/attention_registry.py | 8 +- .../layers/attention/fla/fused_gdn_gating.py | 66 +++ .../attention/hybrid_linear_attn_backend.py | 72 ++- .../layers/attention/mamba/mamba2_metadata.py | 1 + python/sglang/srt/layers/layernorm.py | 9 +- python/sglang/srt/mem_cache/memory_pool.py | 9 + python/sglang/srt/models/qwen3_5_mtp.py | 15 +- python/sglang/srt/models/qwen3_next_mtp.py | 12 + 10 files changed, 629 insertions(+), 10 deletions(-) create mode 100644 python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py new file mode 100644 index 000000000000..b5fc8445cfd5 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py @@ -0,0 +1,423 @@ +from typing import Optional, Tuple, Union + +import torch +from sgl_kernel_npu.fla.fused_gdn_gating import fused_gdn_gating_npu +from sgl_kernel_npu.mamba.causal_conv1d import ( + causal_conv1d_fn_npu, + causal_conv1d_update_npu, +) + +from sglang.srt.layers.attention.fla.fused_gdn_gating import ( + fused_gdn_gating_kernel_without_sigmoid, +) +from sglang.srt.layers.attention.linear.gdn_backend import ( + GDNAttnBackend, + GDNKernelDispatcher, +) +from sglang.srt.layers.attention.linear.utils import ( + get_linear_attn_decode_backend, + get_linear_attn_prefill_backend, +) +from sglang.srt.layers.radix_linear_attention import RadixLinearAttention +from sglang.srt.mem_cache.memory_pool import MambaPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput +from sglang.srt.utils import is_cpu + +fused_gdn_gating = fused_gdn_gating_npu +causal_conv1d_fn = causal_conv1d_fn_npu +causal_conv1d_update = causal_conv1d_update_npu + + +class 
AscendGDNKernelDispatcher(GDNKernelDispatcher): + pass + + +class AscendGDNAttnBackend(GDNAttnBackend): + + def __init__(self, model_runner: ModelRunner): + super().__init__(model_runner) + decode_backend = get_linear_attn_decode_backend() + prefill_backend = get_linear_attn_prefill_backend() + self.kernel_dispatcher = AscendGDNKernelDispatcher( + decode_backend, prefill_backend + ) + + def prepare_gdn_inputs( + self, + bs: int, + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + cache_indices = self.forward_metadata.mamba_cache_indices + self.num_accepted_tokens = torch.ones( + [bs], dtype=torch.int32, device=cache_indices.device + ) + self.actual_seq_lengths = torch.ones( + [bs], dtype=torch.int32, device=cache_indices.device + ) + if forward_mode.is_target_verify(): + seq_len = spec_info.draft_token_num + self.actual_seq_lengths = self.actual_seq_lengths * seq_len + start_indices = cache_indices * seq_len + offset = torch.arange(seq_len, device=start_indices.device) + ranges = start_indices.unsqueeze(1) + offset + self.ssm_state_indices = ranges.flatten().to(torch.int32) + else: + self.ssm_state_indices = cache_indices + + def init_forward_metadata(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_draft_extend(True): + return + super().init_forward_metadata(forward_batch) + self.prepare_gdn_inputs( + forward_batch.batch_size, + forward_batch.forward_mode, + forward_batch.spec_info, + ) + self.graph_mode = False + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + if forward_mode.is_draft_extend(True): + return + super().init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + 
self.prepare_gdn_inputs(bs, forward_mode, spec_info) + self.graph_mode = True + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + if forward_mode.is_draft_extend(True): + return + super().init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) + self.prepare_gdn_inputs(bs, forward_mode, spec_info) + self.graph_mode = True + + def forward_decode( + self, + layer: RadixLinearAttention, + forward_batch: ForwardBatch, + mixed_qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + a: torch.Tensor, + b: torch.Tensor, + **kwargs, + ): + layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id) + conv_states = layer_cache.conv[0] + ssm_states = layer_cache.temporal + query_start_loc = self.forward_metadata.query_start_loc + cache_indices = self.forward_metadata.mamba_cache_indices + + assert isinstance(mixed_qkv, torch.Tensor) + conv_states_tmp = conv_states.transpose(1, 2).clone() + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states_tmp, + layer.conv_weights, + layer.bias, + layer.activation, + conv_state_indices=cache_indices, + ) + conv_states[:] = conv_states_tmp.transpose(1, 2) + + query, key, value = torch.split( + mixed_qkv, + [layer.q_dim, layer.k_dim, layer.v_dim], + dim=-1, + ) + bs = forward_batch.batch_size + query = query.view(1, bs, layer.num_q_heads, layer.head_q_dim) + key = key.view(1, bs, layer.num_k_heads, layer.head_k_dim) + value = value.view(1, bs, layer.num_v_heads, layer.head_v_dim) + + core_attn_out = self.kernel_dispatcher.decode( + q=query, + k=key, + v=value, + a=a, + b=b, + A_log=layer.A_log, + dt_bias=layer.dt_bias, + ssm_states=ssm_states, + 
cache_indices=cache_indices, + query_start_loc=query_start_loc, + ) + + self._track_mamba_state_decode( + forward_batch, conv_states, ssm_states, cache_indices + ) + return core_attn_out + + def forward_extend( + self, + layer: RadixLinearAttention, + forward_batch: ForwardBatch, + mixed_qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + a: torch.Tensor, + b: torch.Tensor, + **kwargs, + ): + assert isinstance(mixed_qkv, torch.Tensor) + seq_len = mixed_qkv.shape[0] + is_target_verify = forward_batch.forward_mode.is_target_verify() + forward_metadata = self.forward_metadata + + query_start_loc = forward_metadata.query_start_loc + cache_indices = forward_metadata.mamba_cache_indices + retrieve_next_token = forward_metadata.retrieve_next_token + retrieve_next_sibling = forward_metadata.retrieve_next_sibling + retrieve_parent_token = forward_metadata.retrieve_parent_token + + mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id) + conv_states = mamba_cache_params.conv[0] + ssm_states = mamba_cache_params.temporal + if is_target_verify: + assert isinstance(mamba_cache_params, MambaPool.SpeculativeState) + intermediate_state_cache = mamba_cache_params.intermediate_ssm + intermediate_conv_window_cache = ( + mamba_cache_params.intermediate_conv_window[0] + ) + has_initial_states = torch.ones( + seq_len // forward_batch.spec_info.draft_token_num, + dtype=torch.bool, + device=forward_batch.input_ids.device, + ) + else: + has_initial_states = forward_batch.extend_prefix_lens > 0 + if is_target_verify: + batch_size = seq_len // forward_batch.spec_info.draft_token_num + draft_token_num = forward_batch.spec_info.draft_token_num + num_token_padding = mixed_qkv.shape[0] + batch_size = cache_indices.shape[0] + if ( + not self.graph_mode + and forward_batch.num_token_non_padded_cpu != num_token_padding + ): + mixed_qkv = mixed_qkv[: forward_batch.num_token_non_padded_cpu] + a = a[: forward_batch.num_token_non_padded_cpu] + b = b[: 
forward_batch.num_token_non_padded_cpu] + seq_len = forward_batch.num_token_non_padded_cpu + + mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1) + num_accepted_tokens = torch.full( + (batch_size,), + draft_token_num, + dtype=torch.int32, + device=mixed_qkv.device, + ) + mixed_qkv = torch.ops.npu.causal_conv1d_update( + mixed_qkv_reshaped, + layer.conv_weights.transpose(0, 1).contiguous(), + conv_states, + cache_indices, + layer.bias, + num_accepted_tokens, + None, + layer.activation == "silu", + self.pad_slot_id, + ).view(seq_len, -1) + else: + mixed_qkv = mixed_qkv.transpose(0, 1) + if ( + forward_batch.mamba_track_mask is not None + and forward_batch.mamba_track_mask.any() + ): + conv_dst = forward_batch.mamba_track_indices + mixed_qkv_to_track = mixed_qkv[ + :, forward_metadata.track_conv_indices + ].transpose(0, 1) + mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0] + conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track + kernel_size = layer.conv_weights.shape[-1] + conv_states_for_prefill = conv_states[:, -(kernel_size - 1) :, :] + conv_states_tmp = conv_states_for_prefill.transpose(1, 2).contiguous() + + mixed_qkv = causal_conv1d_fn( + mixed_qkv, + layer.conv_weights, + layer.bias, + activation=layer.activation, + conv_states=conv_states_tmp, + has_initial_state=has_initial_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + ).transpose(0, 1)[:seq_len] + conv_states[:, -(kernel_size - 1) :, :] = conv_states_tmp.transpose( + 1, 2 + ).contiguous() + + if is_target_verify: + g, beta = fused_gdn_gating_kernel_without_sigmoid( + layer.A_log, a, b, layer.dt_bias + ) + beta = beta.unsqueeze(0) + num_heads, head_k_dim = layer.num_q_heads, layer.head_q_dim + num_value_heads, head_v_dim = layer.num_v_heads, layer.head_v_dim + + mixed_qkv_last_dim = mixed_qkv.shape[-1] + + mixed_qkv = mixed_qkv.view(batch_size, -1, mixed_qkv_last_dim) + beta = 
beta.view(batch_size, -1, num_value_heads) + g = g.view(batch_size, -1, num_value_heads) + + core_attn_out = self.fused_recurrent_gated_delta_rule_update( + mixed_qkv, + num_heads, + num_value_heads, + head_k_dim, + head_v_dim, + recurrent_state=ssm_states, + beta=beta, + g=g, + cache_indices=cache_indices, + intermediate_state=intermediate_state_cache, + ) + core_attn_out = core_attn_out.view(-1, num_value_heads, head_v_dim) + if (not self.graph_mode) and core_attn_out.shape[0] < num_token_padding: + core_attn_out = torch.cat( + [ + core_attn_out, + core_attn_out.new_zeros( + num_token_padding - core_attn_out.shape[0], + *core_attn_out.shape[1:], + ), + ], + dim=0, + ) + else: + query, key, value = torch.split( + mixed_qkv, + [layer.q_dim, layer.k_dim, layer.v_dim], + dim=-1, + ) + + actual_seq_len = query.shape[0] + query = query.view(1, actual_seq_len, layer.num_q_heads, layer.head_q_dim) + key = key.view(1, actual_seq_len, layer.num_k_heads, layer.head_k_dim) + value = value.view(1, actual_seq_len, layer.num_v_heads, layer.head_v_dim) + + g, beta = fused_gdn_gating(layer.A_log, a, b, layer.dt_bias) + core_attn_out, last_recurrent_state, h = self.kernel_dispatcher.extend( + q=query, + k=key, + v=value, + g=g, + beta=beta, + ssm_states=ssm_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + ) + if is_cpu() and last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.to( + ssm_states.dtype, copy=False + ) + ssm_states[cache_indices] = last_recurrent_state + if not forward_batch.spec_algorithm.is_none(): + last_recurrent_state = last_recurrent_state.transpose(-1, -2).to( + ssm_states.dtype, copy=False + ) + else: + last_recurrent_state = last_recurrent_state.to( + ssm_states.dtype, copy=False + ) + ssm_states[cache_indices] = last_recurrent_state + if h is not None: + self._track_mamba_state_extend( + forward_batch, h, ssm_states, forward_metadata + ) + + return core_attn_out + + def 
fused_recurrent_gated_delta_rule_update( + self, + mix_qkv: torch.Tensor, + num_heads, + num_value_heads, + head_k_dim, + head_v_dim, + recurrent_state: torch.Tensor, + beta: torch.Tensor, + g: torch.Tensor, + cache_indices: torch.Tensor, + intermediate_state: Optional[torch.Tensor] = None, + ): + beta = beta.to(torch.bfloat16) + g = g.to(torch.float32) + batch_size = mix_qkv.shape[0] + seq_len = mix_qkv.shape[1] + scale = 1 / (head_k_dim**0.5) + + if intermediate_state is not None: + intermediate_state = intermediate_state.view( + -1, num_value_heads, head_k_dim, head_v_dim + ) + + if self.graph_mode: + num_accepted_tokens = torch.full( + [batch_size], 1, dtype=torch.int32, device=cache_indices.device + ) + actual_seq_lengths = torch.full( + [batch_size], seq_len, dtype=torch.int32, device=cache_indices.device + ) + ssm_state_indices = self.forward_metadata.mamba_cache_indices_gdn + else: + num_accepted_tokens = self.num_accepted_tokens + actual_seq_lengths = self.actual_seq_lengths + ssm_state_indices = self.ssm_state_indices + + attn_core_out = torch.ops.npu.recurrent_gated_delta_rule( + mix_qkv, + recurrent_state, + beta=beta, + scale=scale, + actual_seq_lengths=actual_seq_lengths, + ssm_state_indices=ssm_state_indices.view(batch_size, seq_len), + nk=num_heads, + nv=num_value_heads, + intermediate_state=intermediate_state, + cache_indices=cache_indices, + num_accepted_tokens=num_accepted_tokens, + g=g, + ) + + if intermediate_state is not None: + intermediate_state = intermediate_state.view( + -1, seq_len, num_value_heads, head_k_dim, head_v_dim + ) + return attn_core_out diff --git a/python/sglang/srt/hardware_backend/npu/memory_pool_npu.py b/python/sglang/srt/hardware_backend/npu/memory_pool_npu.py index ea81f4e589e6..e4f319fa5859 100644 --- a/python/sglang/srt/hardware_backend/npu/memory_pool_npu.py +++ b/python/sglang/srt/hardware_backend/npu/memory_pool_npu.py @@ -15,6 +15,30 @@ from sglang.srt.layers.radix_attention import RadixAttention +def 
_init_npu_conv_state( + conv_state_in, conv_state_shape, speculative_num_draft_tokens: Optional[int] = None +): + extra_conv_len = 0 + if speculative_num_draft_tokens is not None: + extra_conv_len = speculative_num_draft_tokens - 1 + + # conv_state shape (layers, pool_size, conv_wind + draft_step, dim) for conv1d ascendc ops require dim as last dim + conv_state = [ + torch.zeros( + size=( + conv_state_in.shape[0], + conv_state_in.shape[1], + conv_shape[1] + extra_conv_len, + conv_shape[0], + ), + dtype=conv_state_in.dtype, + device=conv_state_in.device, + ) + for conv_shape in conv_state_shape + ] + return conv_state + + class NPUMHATokenToKVPool(MHATokenToKVPool): def __init__( diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 2353c15993fd..81ea048a91ef 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -192,7 +192,6 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac HybridLinearAttnBackend, Mamba2AttnBackend, ) - from sglang.srt.layers.attention.linear.gdn_backend import GDNAttnBackend from sglang.srt.layers.attention.linear.kda_backend import KDAAttnBackend from sglang.srt.layers.attention.linear.lightning_backend import ( LightningAttentionBackend, @@ -202,6 +201,13 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac ) from sglang.srt.utils import is_blackwell, is_npu + if is_npu(): + from sglang.srt.hardware_backend.npu.attention.ascend_gdn_backend import ( + AscendGDNAttnBackend as GDNAttnBackend, + ) + else: + from sglang.srt.layers.attention.linear.gdn_backend import GDNAttnBackend + check_environments() initialize_linear_attn_config(runner.server_args) if runner.hybrid_gdn_config is not None: diff --git a/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py b/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py 
index 6e92208ec130..a82c18ad9abb 100644 --- a/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py +++ b/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py @@ -67,3 +67,69 @@ def fused_gdn_gating( num_warps=1, ) return g, beta_output + + +@triton.jit +def fused_gdn_gating_kernel_without_sigmoid_kernel( + g, + A_log, + a, + dt_bias, + batch, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_BATCHES: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + batch_off = i_b * BLK_BATCHES + tl.arange(0, BLK_BATCHES) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + head_mask = head_off < NUM_HEADS + a_off = ( + batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :] + ) + a_mask = (batch_off[:, None] < batch) & head_mask[None, :] + blk_A_log = tl.load(A_log + head_off, mask=head_mask) + blk_bias = tl.load(dt_bias + head_off, mask=head_mask) + blk_a = tl.load(a + a_off, mask=a_mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + a_off, blk_g.to(g.dtype.element_ty), mask=a_mask) + + +def fused_gdn_gating_kernel_without_sigmoid( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + batch, num_heads = a.shape + seq_len = 1 + g = torch.empty_like(a, dtype=torch.float32) + num_cores = 48 # num_vectorcore of NPU + NUM_BLK_BATCHES = triton.cdiv(num_cores, triton.cdiv(num_heads, 8)) + BLK_BATCHES = triton.cdiv(batch, NUM_BLK_BATCHES) + grid = (NUM_BLK_BATCHES, seq_len, triton.cdiv(num_heads, 8)) + fused_gdn_gating_kernel_without_sigmoid_kernel[grid]( + g, + A_log, + a, + dt_bias, + batch, + seq_len, + num_heads, + beta, + threshold, + 
BLK_BATCHES=BLK_BATCHES, + BLK_HEADS=8, + num_warps=1, + ) + g = g.unsqueeze(0) + return g, b diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 91194c494396..4643147b27c5 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -22,13 +22,19 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpecInput -from sglang.srt.utils import is_cpu +from sglang.srt.utils import is_cpu, is_npu if not is_cpu(): from sglang.srt.layers.attention.fla.chunk_delta_h import ( CHUNK_SIZE as FLA_CHUNK_SIZE, ) +if is_npu(): + from sgl_kernel_npu.mamba.mamba_state_update_triton import ( + conv_state_rollback, + move_intermediate_cache_dynamic_h_block, + ) + logger = logging.getLogger(__name__) @@ -142,6 +148,7 @@ def __init__(self, model_runner: ModelRunner): self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool self.forward_metadata: ForwardMetadata = None self.state_indices_list = [] + self.state_indices_list_gdn = [] self.query_start_loc_list = [] self.retrieve_next_token_list = [] self.retrieve_next_sibling_list = [] @@ -409,6 +416,14 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): (i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device ) ) + self.state_indices_list_gdn.append( + torch.full( + ((i + 1) * draft_token_num,), + self.pad_slot_id, + dtype=torch.int32, + device=self.device, + ) + ) self.query_start_loc_list.append( torch.zeros((i + 2,), dtype=torch.int32, device=self.device) ) @@ -462,6 +477,8 @@ def _capture_metadata( forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): + mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) + 
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) if forward_mode.is_decode_or_idle(): self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_decode_query_start_loc[: bs + 1] @@ -470,10 +487,17 @@ def _capture_metadata( self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_verify_query_start_loc[: bs + 1] ) + start_indices = mamba_indices * spec_info.draft_token_num + offset = torch.arange( + spec_info.draft_token_num, device=start_indices.device + ) + ranges = start_indices.unsqueeze(1) + offset + ssm_state_indices = ranges.flatten().to(torch.int32) + self.state_indices_list_gdn[bs - 1][ + : len(mamba_indices) * spec_info.draft_token_num + ].copy_(ssm_state_indices) else: raise ValueError(f"Invalid forward mode: {forward_mode=}") - mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) - self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask if forward_mode.is_target_verify() and spec_info.topk > 1: @@ -491,6 +515,7 @@ def _capture_metadata( return ForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], mamba_cache_indices=self.state_indices_list[bs - 1], + mamba_cache_indices_gdn=self.state_indices_list_gdn[bs - 1], ) def _replay_metadata( @@ -507,7 +532,7 @@ def _replay_metadata( # Make sure forward metadata is correctly handled for padding reqs req_pool_indices[bs - num_padding :] = 0 mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) - mamba_indices[bs - num_padding :] = -1 + mamba_indices[bs - num_padding :] = 0 self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) if forward_mode.is_decode_or_idle(): if num_padding == 0: @@ -522,6 +547,20 @@ def _replay_metadata( bs - num_padding ) elif forward_mode.is_target_verify(): + start_indices = ( + mamba_indices[: bs - num_padding] * 
spec_info.draft_token_num + ) + offset = torch.arange( + spec_info.draft_token_num, device=start_indices.device + ) + ranges = start_indices.unsqueeze(1) + offset + ssm_state_indices = ranges.flatten().to(torch.int32) + self.state_indices_list_gdn[bs - 1][ + : len(mamba_indices[: bs - num_padding]) * spec_info.draft_token_num + ].copy_(ssm_state_indices) + self.state_indices_list_gdn[bs - 1][ + len(mamba_indices[: bs - num_padding]) * spec_info.draft_token_num : + ] = 0 if num_padding == 0: self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_verify_query_start_loc[: bs + 1] @@ -556,10 +595,11 @@ def _replay_metadata( return ForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], mamba_cache_indices=self.state_indices_list[bs - 1], + mamba_cache_indices_gdn=self.state_indices_list_gdn[bs - 1], ) def get_cuda_graph_seq_len_fill_value(self): - return 1 # Mamba attn does not use seq lens to index kv cache + return 0 # Mamba attn does not use seq lens to index kv cache def get_cpu_graph_seq_len_fill_value(self): return 1 @@ -960,6 +1000,23 @@ def update_mamba_state_after_mtp_verify( ssm_states = mamba_caches.temporal intermediate_state_cache = mamba_caches.intermediate_ssm intermediate_conv_window_cache = mamba_caches.intermediate_conv_window[0] + if is_npu(): + valid_state_indices = state_indices_tensor.to(torch.int64) # [N] + last_steps = accepted_steps.to(torch.int64) # [N] + + move_intermediate_cache_dynamic_h_block( + ssm_states, intermediate_state_cache, valid_state_indices, last_steps + ) + + draft_token_num = intermediate_state_cache.shape[2] + if valid_state_indices.numel() > 0: + conv_state_rollback( + conv_states, + valid_state_indices, + last_steps, + draft_token_num, + ) + return # Use fully fused kernel that handles masking internally # This avoids separate nonzero() and index_select() calls @@ -992,3 +1049,8 @@ def update_mamba_state_after_mtp_verify( mamba_track_indices, mamba_steps_to_track, ) + + def 
update_verify_buffers_to_fill_after_draft( + self, spec_info: SpecInput, cuda_graph_bs: Optional[int] + ): + pass diff --git a/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py b/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py index 5eeb2b65e307..5d34e1e111d3 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py +++ b/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py @@ -27,6 +27,7 @@ class ForwardMetadata: query_start_loc: torch.Tensor mamba_cache_indices: torch.Tensor + mamba_cache_indices_gdn: Optional[torch.Tensor] = None # For topk > 1 eagle retrieve_next_token: Optional[torch.Tensor] = None retrieve_next_sibling: Optional[torch.Tensor] = None diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index e4960bdb42d6..431bc5cf8e5a 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -84,6 +84,7 @@ if _is_npu: import torch_npu + from sgl_kernel_npu.norm.add_rmsnorm_bias import add_gemma_rms_norm def _forward_with_allreduce_fusion( @@ -567,11 +568,13 @@ def forward_npu( if residual is not None: if post_residual_addition is not None: residual = residual + post_residual_addition - x = x + residual - residual = x + norm_out, residual = add_gemma_rms_norm( + x, self.weight, residual, self.variance_epsilon + ) + return norm_out, residual x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon) - return x if residual is None else (x, residual) + return x def forward_xpu( self, diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 881f3cad77a2..72cb64a1257c 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -255,6 +255,15 @@ def __init__( for conv_shape in conv_state_shape ] + if _is_npu: + from sglang.srt.hardware_backend.npu.memory_pool_npu import ( + _init_npu_conv_state, + ) + + conv_state = 
_init_npu_conv_state( + conv_state[0], conv_state_shape, speculative_num_draft_tokens + ) + if _is_cpu and _cpu_has_amx_support: from sglang.srt.layers.amx_utils import _init_amx_conv_state diff --git a/python/sglang/srt/models/qwen3_5_mtp.py b/python/sglang/srt/models/qwen3_5_mtp.py index 3fa89fcda0a9..037081431e95 100644 --- a/python/sglang/srt/models/qwen3_5_mtp.py +++ b/python/sglang/srt/models/qwen3_5_mtp.py @@ -15,6 +15,7 @@ """Inference-only Qwen3_5 MTP model.""" import logging +import os from typing import Iterable, Optional, Tuple import torch @@ -31,7 +32,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen3_5 import Qwen3_5ForCausalLM -from sglang.srt.utils import add_prefix +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import add_prefix, is_npu logger = logging.getLogger(__name__) @@ -53,6 +55,9 @@ def __init__( # The MTP model is unquantized in the nvfp4 checkpoint. 
if quant_config and quant_config.get_name() == "modelopt_fp4": quant_config = None + if get_global_server_args().speculative_draft_model_quantization is None: + quant_config = None + self.quant_config = quant_config self.config = config self.tp_size = get_tensor_model_parallel_world_size() @@ -118,6 +123,10 @@ def forward( input_embeds: Optional[torch.Tensor] = None, **kwargs, ): + if is_npu() and self.quant_config is None: + # ascend mtp unquant + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0" assert input_embeds is None input_embeds = forward_batch.mm_input_embeds if ( @@ -149,6 +158,10 @@ def forward( forward_batch, hidden_states, ) + if is_npu() and self.quant_config is None: + # ascend mtp unquant + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "0" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1" return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py index b2bdbbbe8705..9270cacd6796 100644 --- a/python/sglang/srt/models/qwen3_next_mtp.py +++ b/python/sglang/srt/models/qwen3_next_mtp.py @@ -15,6 +15,7 @@ """Inference-only Qwen3Next MTP Speculative Decoding.""" import logging +import os from typing import Iterable, Optional, Tuple import torch @@ -23,6 +24,7 @@ from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.hardware_backend.npu.graph_runner.npu_graph_runner import is_npu from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -46,6 +48,8 @@ def __init__( nn.Module.__init__(self) self.config = config self.tp_size = get_tensor_model_parallel_world_size() + if 
get_global_server_args().speculative_draft_model_quantization is None: + quant_config = None self.quant_config = quant_config # if not set, model load will be broken in Qwen3NextForCausalLM load_weights() self.pp_group = get_pp_group() @@ -86,6 +90,10 @@ def forward( input_embeds: Optional[torch.Tensor] = None, **kwargs, ): + if is_npu() and self.quant_config is None: + # ascend mtp unquant + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0" if input_embeds is None: input_embeds = self.model.embed_tokens(input_ids) @@ -103,6 +111,10 @@ def forward( forward_batch, hidden_states, ) + if is_npu() and self.quant_config is None: + # ascend mtp unquant + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "0" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1" return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch From 230a528916c7feb1424b48ca109d9f90b978b32d Mon Sep 17 00:00:00 2001 From: longxin9715 <59550463+longxin9715@users.noreply.github.com> Date: Mon, 30 Mar 2026 21:13:50 +0800 Subject: [PATCH 13/42] fix bug (#181) Co-authored-by: j30065060 --- python/sglang/test/runners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 61781fea21de..fa737338d7f1 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -234,7 +234,7 @@ def _forward_gme_qwen2_vl( **kwargs, ) -> torch.Tensor: if inputs_embeds is None: - inputs_embeds = self.model.model.embed_tokens(input_ids) + inputs_embeds = self.model.model.get_input_embeddings()(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.model.visual.get_dtype()) image_embeds = self.model.visual( From 8de1b25bd5cf9fea6da0ee005c8a65beaefd6ad5 Mon Sep 17 00:00:00 2001 From: McZyWu Date: Tue, 31 Mar 2026 10:00:44 +0800 Subject: [PATCH 14/42] revert pr 19321 for accuracy temporarily (#178) * revert pr 19321 for accuracy 
temporarily --- python/sglang/srt/layers/linear.py | 38 +--- python/sglang/srt/models/qwen3_next.py | 237 ++++++++++++++++--------- 2 files changed, 158 insertions(+), 117 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index abd7568707fc..924ef64fd33b 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -10,7 +10,6 @@ from torch import nn from torch.nn.parameter import Parameter, UninitializedParameter -from sglang.kernel_api_logging import wrap_method_with_debug_kernel_once from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, @@ -177,13 +176,6 @@ def __init__( else: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) - if self.quant_method is not None: - wrap_method_with_debug_kernel_once( - self.quant_method, - "apply", - op_name=f"sglang.quant_method.{self.quant_method.__class__.__name__}.apply", - ) - def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -539,15 +531,8 @@ def weight_loader( self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: tuple[int, ...] | int | None = None, + loaded_shard_id: Optional[int] = None, ): - if isinstance(loaded_shard_id, tuple): - if hasattr(param, "load_merged_column_weight"): - return self.weight_loader_v2(param, loaded_weight, loaded_shard_id) - raise NotImplementedError( - "Shard id with multiple indices is not supported in weight_loader, " - "please use weight_loader_v2 instead." 
- ) # Special case for GGUF # initialize GGUF param after we know the quantize type @@ -714,10 +699,7 @@ def weight_loader( param_data.copy_(loaded_weight) def _load_fused_module_from_checkpoint( - self, - param: BasevLLMParameter, - loaded_weight: torch.Tensor, - output_sizes: list[int] | None = None, + self, param: BasevLLMParameter, loaded_weight: torch.Tensor ): """ Handle special case for models where MLP layers are already @@ -731,8 +713,7 @@ def _load_fused_module_from_checkpoint( current_shard_offset = 0 shard_offsets: List[Tuple[int, int, int]] = [] - output_sizes = output_sizes or self.output_sizes - for i, output_size in enumerate(output_sizes): + for i, output_size in enumerate(self.output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) current_shard_offset += output_size @@ -802,9 +783,9 @@ def weight_loader_v2( self, param: BasevLLMParameter, loaded_weight: torch.Tensor, - loaded_shard_id: tuple[int, ...] | int | None = None, + loaded_shard_id: Optional[int] = None, ): - if loaded_shard_id is None or isinstance(loaded_shard_id, tuple): + if loaded_shard_id is None: if isinstance(param, PerTensorScaleParameter): param.load_merged_column_weight( loaded_weight=loaded_weight, @@ -823,15 +804,8 @@ def weight_loader_v2( tp_size=self.tp_size, ) return - output_sizes = ( - [self.output_sizes[idx] for idx in loaded_shard_id] - if loaded_shard_id - else None - ) # TODO: @dsikka - move to parameter.py - self._load_fused_module_from_checkpoint( - param, loaded_weight, output_sizes=output_sizes - ) + self._load_fused_module_from_checkpoint(param, loaded_weight) return assert loaded_shard_id < len(self.output_sizes) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index c0bf8026179b..1b6ce088b8b0 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -3,7 +3,6 @@ from typing import Any, Iterable, Optional, Set, Tuple import torch -import triton from torch 
import nn from sglang.srt.configs.qwen3_next import Qwen3NextConfig @@ -21,7 +20,6 @@ from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, - MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) @@ -56,7 +54,6 @@ logger = logging.getLogger(__name__) -from sglang.jit_kernel.triton.gdn_fused_proj import fused_qkvzba_split_reshape_cat from sglang.srt.layers.attention.fla.fused_norm_gate import FusedRMSNormGated _is_cuda = is_cuda() @@ -65,6 +62,147 @@ _is_amx_available = cpu_has_amx_support() +import triton +import triton.language as tl + + +@triton.jit +def fused_qkvzba_split_reshape_cat_kernel( + mixed_qkv, + z, + b, + a, + mixed_qkvz, + mixed_ba, + NUM_HEADS_QK: tl.constexpr, + NUM_HEADS_V: tl.constexpr, + HEAD_QK: tl.constexpr, + HEAD_V: tl.constexpr, +): + i_bs, i_qk = tl.program_id(0), tl.program_id(1) + QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2 + BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2 + QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + q_end: tl.constexpr = HEAD_QK + blk_q_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(0, q_end) + ) + k_end: tl.constexpr = q_end + HEAD_QK + blk_k_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(q_end, k_end) + ) + v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + blk_v_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(k_end, v_end) + ) + z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + blk_z_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(v_end, z_end) + ) + blk_q_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + i_qk * HEAD_QK + + tl.arange(0, HEAD_QK) + ) + blk_k_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + NUM_HEADS_QK 
* HEAD_QK + + i_qk * HEAD_QK + + tl.arange(0, HEAD_QK) + ) + blk_v_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + NUM_HEADS_QK * HEAD_QK * 2 + + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK + + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) + ) + blk_z_st_ptr = ( + z + + i_bs * NUM_HEADS_V * HEAD_V + + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK + + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) + ) + tl.store(blk_q_st_ptr, tl.load(blk_q_ptr)) + tl.store(blk_k_st_ptr, tl.load(blk_k_ptr)) + tl.store(blk_v_st_ptr, tl.load(blk_v_ptr)) + tl.store(blk_z_st_ptr, tl.load(blk_z_ptr)) + b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK + a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK + for i in tl.static_range(b_end): + blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i + blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i + tl.store(blk_b_st_ptr, tl.load(blk_b_ptr)) + for i in tl.static_range(b_end, a_end): + blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i + blk_a_st_ptr = ( + a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end) + ) + tl.store(blk_a_st_ptr, tl.load(blk_a_ptr)) + + +def fused_qkvzba_split_reshape_cat( + mixed_qkvz, + mixed_ba, + num_heads_qk, + num_heads_v, + head_qk, + head_v, +): + batch, seq_len = mixed_qkvz.shape[0], 1 + qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v + mixed_qkv = torch.empty( + [batch * seq_len, qkv_dim_t], + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device, + ) + z = torch.empty( + [batch * seq_len, num_heads_v, head_v], + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device, + ) + b = torch.empty( + [batch * seq_len, num_heads_v], + dtype=mixed_ba.dtype, + device=mixed_ba.device, + ) + a = torch.empty_like(b) + grid = (batch * seq_len, num_heads_qk) + fused_qkvzba_split_reshape_cat_kernel[grid]( + mixed_qkv, + z, + b, + a, + mixed_qkvz, + mixed_ba, + num_heads_qk, + num_heads_v, + 
head_qk, + head_v, + num_warps=1, + num_stages=3, + ) + return mixed_qkv, z, b, a + + +if _is_npu: + from sgl_kernel_npu.fla.utils import fused_qkvzba_split_reshape_cat as fused_qkvzba_split_reshape_cat_npu + fused_qkvzba_split_reshape_cat = fused_qkvzba_split_reshape_cat_npu + class Qwen3GatedDeltaNet(nn.Module): def __init__( self, @@ -111,38 +249,28 @@ def __init__( prefix=add_prefix("conv1d", prefix), ) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 - # projection of the input hidden states - self.in_proj_qkvz = self.create_qkvz_proj( - hidden_size=self.hidden_size, - key_dim=self.key_dim, - value_dim=self.value_dim, + self.in_proj_qkvz = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=projection_size_qkvz, + bias=False, quant_config=quant_config, - prefix=add_prefix("in_proj_qkvz", prefix), tp_rank=self.attn_tp_rank, tp_size=self.attn_tp_size, + prefix=add_prefix("in_proj_qkvz", prefix), ) - - self.in_proj_ba = MergedColumnParallelLinear( + self.in_proj_ba = ColumnParallelLinear( input_size=self.hidden_size, - output_sizes=[self.num_v_heads] * 2, + output_size=projection_size_ba, bias=False, quant_config=quant_config, - prefix=add_prefix("in_proj_ba", prefix), tp_rank=self.attn_tp_rank, tp_size=self.attn_tp_size, + prefix=add_prefix("in_proj_ba", prefix), ) - # Override weight_loader for packed checkpoint format. - # Must capture original_loader BEFORE overwriting. 
- self.in_proj_qkvz.weight.weight_loader = self._make_packed_weight_loader( - self.in_proj_qkvz - ) - self.in_proj_ba.weight.weight_loader = self._make_packed_weight_loader( - self.in_proj_ba - ) - - # Conv1d weight loader setup query_key_settings = (self.key_dim, 0, False) value_settings = (self.value_dim, 0, False) @@ -216,61 +344,7 @@ def __init__( dt_bias=self.dt_bias, ) - @staticmethod - def _make_packed_weight_loader(module): - """Create a weight_loader that does contiguous TP slicing for fused - (packed-format) checkpoint weights (shard_id=None), and delegates - to the standard MergedColumnParallelLinear loader for split checkpoint - weights (shard_id=int/tuple).""" - original_loader = module.weight.weight_loader - - def weight_loader(param, loaded_weight, loaded_shard_id=None): - if loaded_shard_id is None: - # Fused checkpoint: weight is in packed (per-head-group) - # format. Do contiguous TP slice like ColumnParallelLinear. - output_dim = getattr(param, "output_dim", None) - if output_dim is not None and module.tp_size > 1: - shard_size = param.data.shape[output_dim] - start_idx = module.tp_rank * shard_size - loaded_weight = loaded_weight.narrow( - output_dim, start_idx, shard_size - ) - assert param.data.shape == loaded_weight.shape, ( - f"Shape mismatch: param {param.data.shape} vs " - f"loaded {loaded_weight.shape}" - ) - param.data.copy_(loaded_weight) - else: - # Split checkpoint (int or tuple shard_id) → standard path - original_loader(param, loaded_weight, loaded_shard_id) - - return weight_loader - - def create_qkvz_proj( - self, - hidden_size: int, - key_dim: int, - value_dim: int, - quant_config: QuantizationConfig | None, - prefix: str, - tp_rank: Optional[int] = None, - tp_size: Optional[int] = None, - ) -> MergedColumnParallelLinear: - return MergedColumnParallelLinear( - input_size=hidden_size, - output_sizes=[key_dim, key_dim, value_dim, value_dim], - bias=False, - quant_config=quant_config, - prefix=prefix, - tp_rank=tp_rank, - 
tp_size=tp_size, - ) - - def fix_query_key_value_ordering( - self, - mixed_qkvz: torch.Tensor, - mixed_ba: torch.Tensor, - ): + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. """ @@ -962,18 +1036,11 @@ def load_weights( ) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) - # self attention ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - # mlp ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), - # GDN - ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), - ("in_proj_qkvz.", "in_proj_z.", 3), - ("in_proj_ba.", "in_proj_b.", 0), - ("in_proj_ba.", "in_proj_a.", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales From e36779a44628bb004fb16be973534adaa3f8f50c Mon Sep 17 00:00:00 2001 From: McZyWu Date: Tue, 31 Mar 2026 16:46:33 +0800 Subject: [PATCH 15/42] Bug fix for not import is npu (#182) --- python/sglang/srt/models/qwen3_next_mtp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py index 9270cacd6796..cc4f0f4715e9 100644 --- a/python/sglang/srt/models/qwen3_next_mtp.py +++ b/python/sglang/srt/models/qwen3_next_mtp.py @@ -32,7 +32,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import add_prefix +from sglang.srt.utils import add_prefix, is_npu logger = logging.getLogger(__name__) From cf463fba8218d4b17aaa73050e3bd05dcb3e2034 Mon Sep 17 00:00:00 2001 From: iridiumine <42236072+iridiumine@users.noreply.github.com> Date: Tue, 31 Mar 2026 22:25:11 +0800 Subject: [PATCH 16/42] =?UTF-8?q?Revert=20"Revert=20"Use=20LazyValue=20for?= =?UTF-8?q?=20routed=5Fexperts=5Fweights=5Fof=5Flayer=20initializat?= 
=?UTF-8?q?=E2=80=A6"=20(#190)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/sglang/srt/models/qwen3_moe.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 1373099fdf5a..010a73074759 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -1182,11 +1182,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): logger.warning(f"Parameter {name} not found in params_dict") if not hasattr(self, "routed_experts_weights_of_layer"): - self.routed_experts_weights_of_layer = { - layer_id: self.model.layers[layer_id].mlp.get_moe_weights() - for layer_id in range(self.start_layer, self.end_layer) - if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock) - } + self.routed_experts_weights_of_layer = LazyValue( + lambda: { + layer_id: self.model.layers[layer_id].mlp.get_moe_weights() + for layer_id in range(self.start_layer, self.end_layer) + if isinstance( + self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock + ) + } + ) @classmethod def get_model_config_for_expert_location(cls, config): From b28eeff8f74c5f4db1765f6c7aeff44bfe79b7d6 Mon Sep 17 00:00:00 2001 From: iridiumine <42236072+iridiumine@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:45:19 +0800 Subject: [PATCH 17/42] fix: qwen3.5 precision & quant model load error (#191) --- .../quantization/modelslim/modelslim.py | 9 +- python/sglang/srt/models/qwen3_5.py | 361 ++++++++++++++---- 2 files changed, 304 insertions(+), 66 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py index 84acecccc415..2f3ef5a40f69 100644 --- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -153,9 +153,16 @@ def 
get_quant_method( key = "vision_model" elif "visual" in prefix: key = "visual" - if "vision_tower" in prefix or "mm_projector" in prefix: + if ( + "vision_tower" in prefix + or "mm_projector" in prefix + or "in_proj_qkvz" in prefix + or "in_proj_ba" in prefix + ): prefix = prefix.replace(r"attn.qkv_proj", r"wqkv") prefix = prefix.replace(r"attn.proj", r"wo") + prefix = prefix.replace(r"in_proj_qkvz", r"in_proj_qkv") + prefix = prefix.replace(r"in_proj_ba", r"in_proj_b") packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {}) prefix_in_quant_config = prefix proj_name = prefix.split(".")[-1] diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index 107b73378c72..622758c356ef 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -20,6 +20,11 @@ import torch import torch.nn as nn +import triton + +from sglang.jit_kernel.triton.gdn_fused_proj import ( + fused_qkvzba_split_reshape_cat_contiguous, +) # Configs from sglang.srt.configs.qwen3_5 import ( @@ -54,6 +59,10 @@ RowParallelLinear, ) from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.parameter import ( + BlockQuantScaleParameter, + PerTensorScaleParameter, +) from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_linear_attention import RadixLinearAttention @@ -70,11 +79,14 @@ # Models from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration +from sglang.srt.server_args import get_global_server_args # Utils from sglang.srt.utils import ( LazyValue, add_prefix, + cpu_has_amx_support, + is_cpu, is_cuda, is_npu, make_layers, @@ -85,6 +97,15 @@ logger = logging.getLogger(__name__) _is_cuda = is_cuda() _is_npu = is_npu() +_is_cpu = is_cpu() +_is_amx_available = cpu_has_amx_support() + +if _is_npu: + from sgl_kernel_npu.fla.utils import ( + fused_qkvzba_split_reshape_cat 
as fused_qkvzba_split_reshape_cat_npu, + ) + + fused_qkvzba_split_reshape_cat_contiguous = fused_qkvzba_split_reshape_cat_npu cached_get_processor = lru_cache(get_processor) @@ -129,63 +150,47 @@ def __init__( ) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - # Split projection layers (following vLLM's implementation) - # Instead of fused in_proj_qkvz and in_proj_ba, use separate layers - self.in_proj_qkv = MergedColumnParallelLinear( - input_size=self.hidden_size, - output_sizes=[self.key_dim, self.key_dim, self.value_dim], - bias=False, - quant_config=quant_config, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - prefix=add_prefix("in_proj_qkv", prefix), - ) - self.in_proj_z = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.value_dim, - bias=False, - quant_config=quant_config, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - prefix=add_prefix("in_proj_z", prefix), - ) - self.in_proj_b = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.num_v_heads, - bias=False, + # projection of the input hidden states + self.in_proj_qkvz = self.create_qkvz_proj( + hidden_size=self.hidden_size, + key_dim=self.key_dim, + value_dim=self.value_dim, quant_config=quant_config, + prefix=add_prefix("in_proj_qkvz", prefix), tp_rank=self.attn_tp_rank, tp_size=self.attn_tp_size, - prefix=add_prefix("in_proj_b", prefix), ) - self.in_proj_a = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.num_v_heads, - bias=False, + + self.in_proj_ba = self.create_ba_proj( + hidden_size=self.hidden_size, + num_v_heads=self.num_v_heads, quant_config=quant_config, + prefix=add_prefix("in_proj_ba", prefix), tp_rank=self.attn_tp_rank, tp_size=self.attn_tp_size, - prefix=add_prefix("in_proj_a", prefix), ) + # Override weight loaders for packed checkpoint format. + # Important: for FP8, this must cover not only `.weight` but also + # `weight_scale_inv` / `weight_scale` / `input_scale` if present. 
+ self._bind_packed_weight_loaders(self.in_proj_qkvz) + self._bind_packed_weight_loaders(self.in_proj_ba) + # Conv1d weight loader setup query_key_settings = (self.key_dim, 0, False) value_settings = (self.value_dim, 0, False) - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( + self._override_weight_loader( self.conv1d.weight, - { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - query_key_settings, - query_key_settings, - value_settings, - ], - self.attn_tp_size, - self.attn_tp_rank, - ) - }, + mamba_v2_sharded_weight_loader( + [ + query_key_settings, + query_key_settings, + value_settings, + ], + self.attn_tp_size, + self.attn_tp_rank, + ), ) # State parameters @@ -202,7 +207,7 @@ def __init__( conv_weights = self.conv1d.weight.view( self.conv1d.weight.size(0), self.conv1d.weight.size(2) ) - # RadixLinearAttention layer + self.attn = RadixLinearAttention( layer_id=layer_id, num_q_heads=self.num_k_heads // self.attn_tp_size, @@ -218,7 +223,6 @@ def __init__( dt_bias=self.dt_bias, ) - # Normalization layer self.norm = RMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, @@ -228,7 +232,6 @@ def __init__( dtype=config.torch_dtype, ) - # Output projection self.out_proj = RowParallelLinear( self.value_dim, self.hidden_size, @@ -241,16 +244,190 @@ def __init__( prefix=add_prefix("out_proj", prefix), ) + @staticmethod + def _override_weight_loader(param, loader): + """Robustly override loader for: + 1) BasevLLMParameter subclasses: real storage is `_weight_loader` + 2) regular Parameters that already have mutable `weight_loader` + 3) regular Parameters without `weight_loader` yet + """ + if hasattr(param, "_weight_loader"): + # FP8 / quantized BasevLLMParameter path + param._weight_loader = loader + return + + if hasattr(param, "weight_loader"): + # Regular parameter/tensor that already has a mutable attr. + # Do NOT call set_weight_attrs here, because it asserts when + # overwriting an existing attribute. 
+ param.weight_loader = loader + return + + # Fresh attribute on a normal tensor/Parameter + set_weight_attrs(param, {"weight_loader": loader}) + + def _bind_packed_weight_loaders(self, module): + """Bind packed-checkpoint-aware loaders to all relevant params of a merged module.""" + for attr_name in ("weight", "weight_scale_inv", "weight_scale", "input_scale"): + param = getattr(module, attr_name, None) + if param is None: + continue + original_loader = getattr(param, "weight_loader", None) + if original_loader is None: + continue + wrapped_loader = self._make_packed_weight_loader(module, original_loader) + self._override_weight_loader(param, wrapped_loader) + + @staticmethod + def _get_split_sizes_for_param(module, param, loaded_shard_id): + """Return checkpoint-side split sizes for this param type.""" + if isinstance(param, BlockQuantScaleParameter): + # Split by output blocks, not raw output sizes. + block_n, _ = module.quant_method.quant_config.weight_block_size + block_n = 1 if getattr(param, "format_ue8m0", False) else block_n + return [ + (module.output_sizes[idx] + block_n - 1) // block_n + for idx in loaded_shard_id + ] + + if isinstance(param, PerTensorScaleParameter): + # One logical scale per logical shard. + return [1 for _ in loaded_shard_id] + + # Normal weight / non-block quant tensor + return [module.output_sizes[idx] for idx in loaded_shard_id] + + @classmethod + def _make_packed_weight_loader(cls, module, original_weight_loader): + """Wrap the param's original loader so split checkpoints: + - in_proj_qkv + in_proj_z -> merged in_proj_qkvz + - in_proj_b + in_proj_a -> merged in_proj_ba + can load correctly for both normal and FP8 params. + """ + + def weight_loader(param, loaded_weight, loaded_shard_id=None): + # Only intercept split-checkpoint tuple shards. + # int shard_id and None should preserve original behavior. 
+ if isinstance(loaded_shard_id, tuple): + split_sizes = cls._get_split_sizes_for_param( + module, param, loaded_shard_id + ) + + if len(loaded_weight.shape) == 0: + # Scalar only makes sense for a single logical shard. + assert len(split_sizes) == 1 and split_sizes[0] == 1, ( + f"Unexpected scalar for tuple shard load: " + f"{loaded_shard_id=}, {split_sizes=}" + ) + chunks = [loaded_weight.reshape(1)] + else: + split_dim = getattr(param, "output_dim", 0) + chunks = loaded_weight.split(split_sizes, dim=split_dim) + + assert len(chunks) == len(loaded_shard_id), ( + f"Chunk/shard mismatch: {len(chunks)=}, " + f"{len(loaded_shard_id)=}, {split_sizes=}" + ) + + for idx, chunk in zip(loaded_shard_id, chunks): + # Delegate each chunk to the param's original int-shard loader. + original_weight_loader(param, chunk, idx) + return + + return original_weight_loader(param, loaded_weight, loaded_shard_id) + + return weight_loader + + def create_qkvz_proj( + self, + hidden_size: int, + key_dim: int, + value_dim: int, + quant_config: QuantizationConfig | None, + prefix: str, + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + ) -> MergedColumnParallelLinear: + return MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[key_dim, key_dim, value_dim, value_dim], + bias=False, + quant_config=quant_config, + prefix=prefix, + tp_rank=tp_rank, + tp_size=tp_size, + ) + + def create_ba_proj( + self, + hidden_size: int, + num_v_heads: int, + quant_config: QuantizationConfig | None, + prefix: str, + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + ) -> MergedColumnParallelLinear: + # Qwen3.5 has separate in_proj_b and in_proj_a weights in the + # checkpoint, which are loaded into the fused in_proj_ba parameter + # via stacked_params_mapping with shard_id 0 and 1 respectively. 
+ return MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[num_v_heads, num_v_heads], + bias=False, + quant_config=quant_config, + prefix=prefix, + tp_rank=tp_rank, + tp_size=tp_size, + ) + def fix_query_key_value_ordering( self, - mixed_qkv, - z, - b, - a, + mixed_qkvz: torch.Tensor, + mixed_ba: torch.Tensor, ): - raise NotImplementedError( - "Qwen3.5 Series dont need to fix query key value ordering" - ) + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + k_tp = self.key_dim // self.attn_tp_size + v_tp = self.value_dim // self.attn_tp_size + nv_tp = self.num_v_heads // self.attn_tp_size + + # Directly split, no head group reshape + query, key, value, z = mixed_qkvz.split([k_tp, k_tp, v_tp, v_tp], dim=-1) + b, a = mixed_ba.split([nv_tp, nv_tp], dim=-1) + + # value / z reshape to (seq, num_v_heads/tp, head_v_dim) + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + + return query, key, value, z, b, a + + def _forward_input_proj(self, hidden_states: torch.Tensor): + if ( + _is_cpu + or _is_npu + or not get_global_server_args().disable_piecewise_cuda_graph + ): + DUAL_STREAM_TOKEN_THRESHOLD = 0 + else: + DUAL_STREAM_TOKEN_THRESHOLD = 1024 + + seq_len, _ = hidden_states.shape + if ( + self.alt_stream is not None + and get_is_capture_mode() + and seq_len < DUAL_STREAM_TOKEN_THRESHOLD + ): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) + with torch.cuda.stream(self.alt_stream): + projected_states_ba, _ = self.in_proj_ba(hidden_states) + current_stream.wait_stream(self.alt_stream) + else: + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) + projected_states_ba, _ = self.in_proj_ba(hidden_states) + return projected_states_qkvz, projected_states_ba def forward( self, @@ -263,30 +440,64 @@ def forward( 2. Core attention (custom op) 3. 
Output projection """ - seq_len, _ = hidden_states.shape - - mixed_qkv, _ = self.in_proj_qkv(hidden_states) - z, _ = self.in_proj_z(hidden_states) - z = z.reshape(z.size(0), -1, self.head_v_dim) - b, _ = self.in_proj_b(hidden_states) - a, _ = self.in_proj_a(hidden_states) - - b = b.contiguous() - a = a.contiguous() + projected_states_qkvz, projected_states_ba = self._forward_input_proj( + hidden_states + ) + if self.num_v_heads // self.num_k_heads in [1, 2, 4] and not _is_cpu: + mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat_contiguous( + projected_states_qkvz, + projected_states_ba, + triton.cdiv(self.num_k_heads, self.attn_tp_size), + triton.cdiv(self.num_v_heads, self.attn_tp_size), + self.head_k_dim, + self.head_v_dim, + ) + b = b.contiguous() + a = a.contiguous() + elif _is_cpu and _is_amx_available: + mixed_qkv, z, b, a = ( + torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu( + projected_states_qkvz, + projected_states_ba, + self.num_k_heads // self.attn_tp_size, + self.num_v_heads // self.attn_tp_size, + self.head_k_dim, + self.head_v_dim, + ) + ) + else: + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba + ) + b = b.contiguous() + a = a.contiguous() + query, key, value = map( + lambda x: x.reshape(x.shape[0], -1), (query, key, value) + ) + mixed_qkv = torch.cat((query, key, value), dim=-1) core_attn_out = self.attn( - forward_batch=forward_batch, + forward_batch, mixed_qkv=mixed_qkv, a=a, b=b, ) z_shape_og = z.shape + # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) + + # Add padding for DP-Attn + if core_attn_out.shape != z.shape: + core_attn_out_pad = torch.zeros_like(z) + core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out + core_attn_out = core_attn_out_pad + core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = 
core_attn_out.flatten(-2) # ... h d -> ... (h d) + core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) + output, _ = self.out_proj(core_attn_out) return output @@ -818,6 +1029,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + # GDN + ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), + ("in_proj_qkvz.", "in_proj_z.", 3), + ("in_proj_ba.", "in_proj_b.", 0), + ("in_proj_ba.", "in_proj_a.", 1), ] loaded_params: Set[str] = set() @@ -894,6 +1110,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + # GDN + ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), + ("in_proj_qkvz.", "in_proj_z.", 3), + ("in_proj_ba.", "in_proj_b.", 0), + ("in_proj_ba.", "in_proj_a.", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales @@ -1127,6 +1348,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + # GDN fused projections + ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), + ("in_proj_qkvz.", "in_proj_z.", 3), + ("in_proj_ba.", "in_proj_b.", 0), + ("in_proj_ba.", "in_proj_a.", 1), ] loaded_params: Set[str] = set() @@ -1223,6 +1449,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + # GDN fused projections + ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), + ("in_proj_qkvz.", "in_proj_z.", 3), + ("in_proj_ba.", "in_proj_b.", 0), + ("in_proj_ba.", "in_proj_a.", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales From 2e7fbff6f4db4d6d2121472f625b961c7c8c93a5 Mon Sep 17 00:00:00 2001 From: iridiumine <42236072+iridiumine@users.noreply.github.com> Date: Wed, 1 Apr 2026 20:50:44 +0800 
Subject: [PATCH 18/42] [NPU] change fused_qkvzba_split_reshape_cat_npu to fused_qkvzba_split_reshape_cat_contiguous (#198) --- python/sglang/srt/models/qwen3_5.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index 622758c356ef..1e1cb0d7b909 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -100,13 +100,6 @@ _is_cpu = is_cpu() _is_amx_available = cpu_has_amx_support() -if _is_npu: - from sgl_kernel_npu.fla.utils import ( - fused_qkvzba_split_reshape_cat as fused_qkvzba_split_reshape_cat_npu, - ) - - fused_qkvzba_split_reshape_cat_contiguous = fused_qkvzba_split_reshape_cat_npu - cached_get_processor = lru_cache(get_processor) From 61c384adc67c832e1a90de46de0d7f3a03d087b7 Mon Sep 17 00:00:00 2001 From: iridiumine <42236072+iridiumine@users.noreply.github.com> Date: Thu, 2 Apr 2026 22:44:54 +0800 Subject: [PATCH 19/42] [NPU] Use causal_conv1d and fix qwen-next modelslim (#207) * fix: use causal_conv1d from sgl-kernel-npu --- .../npu/attention/ascend_gdn_backend.py | 30 ++++++++++--------- .../attention/hybrid_linear_attn_backend.py | 4 +-- .../quantization/modelslim/modelslim.py | 9 +----- python/sglang/srt/models/qwen3_5.py | 12 +++++++- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py index b5fc8445cfd5..fcc8c6316d7a 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py @@ -265,23 +265,25 @@ def forward_extend( conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track kernel_size = layer.conv_weights.shape[-1] conv_states_for_prefill = conv_states[:, -(kernel_size - 1) :, :] - conv_states_tmp = conv_states_for_prefill.transpose(1, 2).contiguous() + conv_states_tmp = 
conv_states_for_prefill.contiguous() - mixed_qkv = causal_conv1d_fn( - mixed_qkv, - layer.conv_weights, + x = mixed_qkv.transpose(0, 1).contiguous() + weight = layer.conv_weights.transpose(0, 1).contiguous() + activation_mode = layer.activation == "silu" + + mixed_qkv = torch.ops.npu.causal_conv1d( + x, + weight, + conv_states_tmp, + query_start_loc, + cache_indices, + has_initial_states, layer.bias, - activation=layer.activation, - conv_states=conv_states_tmp, - has_initial_state=has_initial_states, - cache_indices=cache_indices, - query_start_loc=query_start_loc, - seq_lens_cpu=forward_batch.extend_seq_lens_cpu, - ).transpose(0, 1)[:seq_len] - conv_states[:, -(kernel_size - 1) :, :] = conv_states_tmp.transpose( - 1, 2 - ).contiguous() + activation_mode, + self.pad_slot_id, + )[:seq_len] + conv_states[:, -(kernel_size - 1) :, :] = conv_states_tmp if is_target_verify: g, beta = fused_gdn_gating_kernel_without_sigmoid( layer.A_log, a, b, layer.dt_bias diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 4643147b27c5..75d0fac2b6d4 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -32,7 +32,7 @@ if is_npu(): from sgl_kernel_npu.mamba.mamba_state_update_triton import ( conv_state_rollback, - move_intermediate_cache_dynamic_h_block, + move_intermediate_cache, ) logger = logging.getLogger(__name__) @@ -1004,7 +1004,7 @@ def update_mamba_state_after_mtp_verify( valid_state_indices = state_indices_tensor.to(torch.int64) # [N] last_steps = accepted_steps.to(torch.int64) # [N] - move_intermediate_cache_dynamic_h_block( + move_intermediate_cache( ssm_states, intermediate_state_cache, valid_state_indices, last_steps ) diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py index 
2f3ef5a40f69..84acecccc415 100644 --- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -153,16 +153,9 @@ def get_quant_method( key = "vision_model" elif "visual" in prefix: key = "visual" - if ( - "vision_tower" in prefix - or "mm_projector" in prefix - or "in_proj_qkvz" in prefix - or "in_proj_ba" in prefix - ): + if "vision_tower" in prefix or "mm_projector" in prefix: prefix = prefix.replace(r"attn.qkv_proj", r"wqkv") prefix = prefix.replace(r"attn.proj", r"wo") - prefix = prefix.replace(r"in_proj_qkvz", r"in_proj_qkv") - prefix = prefix.replace(r"in_proj_ba", r"in_proj_b") packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {}) prefix_in_quant_config = prefix proj_name = prefix.split(".")[-1] diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index 1e1cb0d7b909..ebdd00e002bb 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -129,6 +129,12 @@ def __init__( self.layer_id = layer_id self.activation = config.hidden_act self.layer_norm_epsilon = config.rms_norm_eps + packed_modules_mapping = { + "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"], + "in_proj_ba": ["in_proj_b", "in_proj_a"], + } + if quant_config is not None and hasattr(quant_config, "packed_modules_mapping"): + quant_config.packed_modules_mapping["model"].update(packed_modules_mapping) # Conv1d layer self.conv_dim = self.key_dim * 2 + self.value_dim @@ -437,7 +443,11 @@ def forward( hidden_states ) - if self.num_v_heads // self.num_k_heads in [1, 2, 4] and not _is_cpu: + if ( + self.num_v_heads // self.num_k_heads in [1, 2, 4] + and not _is_cpu + and not _is_npu + ): mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat_contiguous( projected_states_qkvz, projected_states_ba, From 8ec5072ddac125ffd6dfd389260bde6adf325219 Mon Sep 17 00:00:00 2001 From: cen121212 Date: Fri, 3 Apr 2026 17:23:10 +0800 Subject: [PATCH 20/42] merge 
sgl-project main (#221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [AMD] Fix AMD CI monitor GitHub API rate limit exhaustion (#21527) * [CI] Register missing jit_kernel test files (#21547) * [diffusion] fix: return None instead of raising RuntimeError when no model info found (#21319) Co-authored-by: Mick * [rl][sgl] fix tensor mismatch after pause (#21514) * [Hicache & JIT_kernel] Support page first layout & mla jit kernel (#18311) * test: point DSV3 int8 MLA CI models to lmsys Hugging Face org (#21561) * [CI] Relax several thresholds in flaky CIs (#21562) * feat: add gc_threshold arg (#21481) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Fix flaky test_pp_single_node (#21564) * Split workflow for releasing runtime docker (#21563) * fix tp capture in vit cuda graph (#17255) * [1/n] lora support - Auto detect lora target modules (#21439) Co-authored-by: Baizhou Zhang * [fix] qwen3.5 fuse_moe_triton_tune bug (#20232) * Remove sync when enabling return_logprob (#20972) * Scope streaming backlog coalescing to incremental_streaming_output mode (#21037) Signed-off-by: Vladislav Nosivskoy Co-authored-by: Lianmin Zheng * docs: flesh out MAINTAINER.md oncall lists and link GitHub profiles (#21575) * [NVIDIA] Enable automatic NUMA configuration (#19452) * [diffusion] UX: aggregate expected dtype-cast logs during weight loading (#21552) * [diffusion] refactor: Unify `TeaCacheParams` and `WanTeaCacheParams` (#20706) Co-authored-by: Mick * [diffusion] chore: remove redundant identity preprocess_text functions(#20633) Co-authored-by: Fengyuan Yu <15fengyuan@gmail.com> * Update CODEOWNERS for transformers.py and docs (#21555) Co-authored-by: Lianmin Zheng * reduce CPU peak memory in multimodal tensor hashing (#21123) * Fix HFRunner hang when subprocess dies during init (#21582) * Fix Piecewise CUDA Graph crash with `-enable-mixed-chunk` (#20441) Co-authored-by: jianyingzhu * 
[CI] Replace upload/download-artifact with job outputs in release-docker workflow (#21579) Co-authored-by: Claude Opus 4.6 (1M context) * Patch transformers is_base_mistral in CI to avoid HF 429 rate limiting (#21586) * [CI] Move v32 cp test to deepep running suite (#21585) * [AMD] Add GLM-4.7-FP8 accuracy CI test for MI35x (#21534) Co-authored-by: Claude Opus 4.6 * [Clean] Remove deprecated environs (#21536) * [diffusion] fix: fix Flux2-Klein prompt tokenization length to 512 and add regression coverage (#21407) * [CI] hot-fix ci lint (#21608) * [diffusion] feat: support overlay model materialization (#21600) * [VLM] Optimize ShmPointerMMData for multi-pickle safety and deferred unwrap (#21465) * feat: enable CUDA graph and timestamp for the whisper model(#21190) * [NPU] Update quantization&CI documentation (#21100) Co-authored-by: Tamir Baydasov <41994229+TamirBaydasov@users.noreply.github.com> * Skip ci for .md files (#21482) * Support skip-softmax attention (#19089) * fix: piecewise_cuda_graph get correct qo_indptr (#21452) Co-authored-by: Avery Huang * fix bench_serving sglang backend to support image dataset (#21294) * [AMD] Add peft>=0.18.0 to diffusion_hip deps for transformers 5.x compat for AMD diffusion model (#21442) Co-authored-by: HaiShaw * [GDN] Fuse GDN kkt + solve_tril into one kernel (#21411) Co-authored-by: luoyuan.luo * [Diffusion] Align diffusion benchmark skill presets with nightly comparison cases (#21616) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Clean up detokenizer and remove dead multimodal_gen code (#21588) Co-authored-by: Claude Opus 4.6 (1M context) * [CI] Skip flaky elastic EP test (#21619) * feat(ci): add GB300 nightly benchmark test suites (#21487) Co-authored-by: Claude Opus 4.6 (1M context) * [CI] Lossen test_return_routed_experts threshold (#21270) * Add subprocess liveness monitor to detect scheduler crashes (#18582) Co-authored-by: 继优 Co-authored-by: shuwenn 
<47200617+alphabetc1@users.noreply.github.com> * fix: scheduler launch hang when non-current rank dies (#20287) * Wrap IPv6 addresses in gRPC, bench_serving, and log messages (#21236) Co-authored-by: hnyls2002 Co-authored-by: Liangsheng Yin * [HiCache] fix: graceful shutdown of pending async tasks in bench_mix.py (#20276) * Clean up _wait_for_scheduler_ready implementation (#21626) * fix cuda graph capturing error in sm120 mxfp8 triton path (#19835) * [sgl] disable piecewise cuda graph when a model doesn't have layers (#21565) * [Feature] Optimizations for JPEG input on NVIDIA GPU (#19749) * [VLM] perf: optimize CUDA IPC for multimodal transfer by caching IPC pool handles (#21418) * [Fix] SGLANG_USE_CUDA_IPC_TRANSPORT=1 and SGLANG_ENABLE_MM_SPLITTING=1 do not work at the same time. (#19915) * [Fix] Remove redundant allreduce fusion block and skip TP=1 (#20621) * Simplify routed experts test and move base64 encoding to tokenizer manager (#21634) Co-authored-by: Claude Opus 4.6 (1M context) * [Cleanup] Remove unused BatchMultimodalOutput and BatchMultimodalDecodeReq (#21640) Co-authored-by: Claude Opus 4.6 (1M context) * Clean up TokenizerManager: remove dead code and improve rid validation (#21639) Co-authored-by: Claude Opus 4.6 (1M context) * README: coding agent sponsorship for long-term contributors (#21642) * Fix circular reference in CustomTestCase.__init_subclass__ (#21650) Co-authored-by: wan4ch * [Fix] Fix Qwen3.5 MoE model loading and Mamba cache sharding in PP mode (#21448) Co-authored-by: zhangxiaolei123456 * [diffusion] CI: fix dashboard chart (nightly) display issues (#21653) Co-authored-by: Claude Opus 4.6 Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update sponsorship details in README.md (#21658) * [Fix] Handle pre-release tags in nightly wheel version parsing (#21656) Co-authored-by: Claude Opus 4.6 (1M context) * [Intel GPU] Enable DeepSeek R1 inference on XPU (#18461) Signed-off-by: P V R 
K Jyothendra Varma * [Doc] Update tips for developer new-comers (#21659) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * [CI] [FlashInfer v0.6.7] Use offline quantized checkpoint for MXFP8 Gemm tests (#21625) * MFU metrics in Prometheus (#19395) * fix topk softmax performance issue (#14702) * [CPU] add kernel apply_rotary_pos_emb_cpu for Qwen3-VL and Qwen3-Omni (#13121) Co-authored-by: Ma Mingfei * [CPU] Implement MXFP4 Gemm kernels for intel AMX to support GPT OSS series. (#14385) * [AMD] Fused rope kv store (#21315) Co-authored-by: wunhuang * [NPU] Update DeepSeek-V3.2 model deployment instructions in documentation (#21468) Co-authored-by: wuxue (C) * [AMD] Support AMD MXFP4 Qwen3.5-397B-A17B model (#21234) * [Fix] Fix weight_loader property assignment for qwen3-next FP8 models (#21662) Co-authored-by: Claude Opus 4.6 (1M context) * fix mamba cache leak when adder fails to add a matched req. (#21404) * fix: Mistral Small 4 fails to start due to config/weight format mismatch (#21620) Co-authored-by: mengxiancheng03 Co-authored-by: Baizhou Zhang Co-authored-by: Claude Opus 4.6 (1M context) * [diffusion] feat: enhance overlay mechanism (#21648) * [diffusion] CI: relax pr-test threshold (#21682) * [NPU][Diffusion] fix sp modulate for qwen-image-edit (#20974) Co-authored-by: 高鑫 * [NPU] fix eagle3 accept rate (#21255) * DeepSeek-R1-0528-w4a8: DeepEP Low Latency Dispatch Adopts FP8 Communication (#14162) Co-authored-by: undefined * [NPU] GLM-5 optimize with fused kernels (#18617) * [NPU][diffusion]: support parallel decoding of qwen-image (#20757) Co-authored-by: 高鑫 * [diffusion] [NPU] support ring attention on NPU with FA (#21383) * [diffusion][doc]: add ring sp performance benchmark page (#20998) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * [GLM-V and GLM-4.7] Cast to FP32 before gate projection for GLM model. 
(#21660) * fix nemotron capture for non attention layers (#21436) * [Bugfix][NPU] Skip FRACTAL_NZ format for MoE weights with unaligned dimensions (#21209) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: ronnie_zheng * [AMD] Add SGLANG_DISAGGREGATION_NUM_PRE_ALLOCATE_REQS env var for configurable KV transfer overlap (#20410) Co-authored-by: HaiShaw * [AMD][MoRI] bump MoRI to v0.1.0 (#21673) * [AMD] fix performance regression issue when run gpt-oss with "--context-length 13824" (#21691) * Remove flashinfer wheel cache cleanup that deletes other versions (#21711) Co-authored-by: Alison Shao * [misc] multiprocess compilation to speed up test (#21483) * Fix human-eval CI install on 5090 runners (#21714) Co-authored-by: Alison Shao * Revert "DeepSeek-R1-0528-w4a8: DeepEP Low Latency Dispatch Adopts FP8 Communication" (#21719) * [Fix] Update supported custom_mem_pool types for mooncake (#21728) Co-authored-by: 百麒 * [Perf]Remove H2D for Qwen3.5 SpecV2 (#20864) * [AMD] Fix CI multimodal-gen-test-1-gpu-amd for gen model (#21621) * [diffusion] fix: fix Flux.2 with tp(#21664) * Add explicit disable flag for FlashInfer allreduce fusion (#21446) * [NPU] fix conflict between empty_cache and use_mem_pool (#21507) * [AMD] Use tgemm.mm for MoEGate router gemm in deepseek_v2.py (#21657) * [CI]Remove msgm-en and mmlu tests which cause timeout (#21733) * Fix disaggregation hybrid attention ci (#21745) * Rename rerun-ut to rerun-test (#21747) * bugfix(model):fix deepstack index out of range error (#21727) Co-authored-by: xiaoqi.31 * [diffusion] fix: fix typo (#21746) Signed-off-by: Xiaodong Ye * [CI] Fix rerun-test suite detection to skip commented registrations (#21753) * [PD] Refactor Disagg Conn and Fix Hang with total_request/total_tokens Balancing (#21299) Co-authored-by: Weiliangl User * [CI] Fix ring test timeout (#21751) * Enable evict swa with piecewise cuda graph (#21754) * Fix kimi-linear launch server 
error (#21752) Co-authored-by: luoyuan.luo * [PD] Tiny cleanup after KVReceiver refactor (#21760) Signed-off-by: Shangming Cai * Fix remote weight info nnode>1 and dp>1 (#17389) * [diffusion] UX: replace deprecated ORJSONResponse with orjson_response (#21755) Co-authored-by: Claude Opus 4.6 * [diffusion] fix: fix Wan2.2-I2V-A14B video max size issue(#21390) Signed-off-by: Xiaodong Ye Co-authored-by: Mick * [HiMambaTree]: Optimize mamba host lock mechanism (#21750) * [AMD] Fix Handle missing rope_theta in get_rope_config for Grok-1 (#21518) * [bugfix] Fix rope theta config for MiniMax after transformers v5 update (#21241) * Fix ineffective is_base_mistral CI patch for HF API rate limiting (#21729) * [2/n] lora - Shared outer experts and support qwen3_30b_a3b_instruct (#21466) Co-authored-by: Baizhou Zhang * Fix cuda graph max bs capture upper bound (#21005) * [Fix] Fall back to triton MOE for GPT-OSS on Blackwell with driver >= 595 (#21780) Co-authored-by: Claude Opus 4.6 (1M context) * Cache nvidia wheels locally to skip repeated 830 MB downloads in CI (#21778) * Add Trivy vulnerability scanning to nightly dev Docker builds (#21772) Co-authored-by: Claude Opus 4.6 (1M context) * [CI] Remove more redundant PCG tests (#21554) * [moe] add customized option to moe-a2a-backend (#21786) * Add CompletionSampler for non-chat eval in run_eval (#21785) * Remove redundant test_moe_eval_accuracy_large (#21787) * Increase hicache eval to 200 examples (#21791) * Switch MooncakeSpec to EAGLE3 + Llama-3.1 (#21794) * Reduce redundant speculative decoding CI tests (#21779) * Fix killall.py crash when sglang is not yet installed (#21797) * Remove obsolete sgl-kernel legacy paths (#21528) * [jit_kernel] Optimize fused_qknorm_rope: deduplicate sincosf for interleave RoPE (#21654) * CUTLASS NVFP4 GEMM improvement of SM120 (#21314) * [gRPC] Preserve original ImportError in grpc_server.py (#21801) Signed-off-by: Chang Su * [Misc] Tiny: Add test network timeouts and dynamic max-parallel 
for 5090/2-gpu runners (#21800) * Fix draft extend cuda graph when spec_step=1 (#21709) * [Diffusion] Add `--uvicorn-access-log-exclude-prefixes` to suppress noisy access logs (#20379) * Add latency and throughput metrics to run_eval (#21793) * [diffusion] CI: improve ci reliability (#21763) * [bugfix]GLM-4V model (#17122) * Fix CVEs in Docker image: pillow, linux-libc-dev, and broken sgl-model-gateway build (#21789) Co-authored-by: Claude Opus 4.6 (1M context) * fix: only showing recent runners from ci failure analysis (#21015) * [MPS] Fix Triton stub sub-module imports on Python 3.12+ (#21551) Co-authored-by: karanb192 Co-authored-by: R0CKSTAR Co-authored-by: R0CKSTAR * [KDA] Fuse scaled_dot_kkt + solve_tril + recompute_w_u for KDA (#21604) Co-authored-by: luoyuan.luo * chore: bump flashinfer version to 0.6.7 (#21422) Co-authored-by: sglang-bot Co-authored-by: Baizhou Zhang * [3/n] lora moe - Support Qwen3-VL-30B-A3B-Instruct (#21469) Co-authored-by: Baizhou Zhang * [Feature Restoration] repetition_penalty is essential for GLM-V models (#21258) Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Co-authored-by: Xinyuan Tong Co-authored-by: hnyls2002 Co-authored-by: Liangsheng Yin * VLM: change default mm-attention backend from triton_attn to fa4 (on blackwell) (#21595) * Fix added tokens config with sensible filter (#17905) * [AMD] Optimize Qwen3-VL decode - fuse QK-norm + 3D mRoPE + KV cache write (#21458) Co-authored-by: Bingxu Chen Co-authored-by: HaiShaw * [Bugfix] Fix PP tied embeddings weight loading for qwen3.5 4B dense model (#21347) * [CI] Fix lint that was not applied in #21458 (#21818) * Bug fix for llama eagle3 (#21397) * glm_interleave for GLM-V (#21671) * style refinement for hisparse (#21198) * [Bug][VLM] Fix shared memory race condition in ShmPointerMMData broadcast for multi-GPU VLM serving (#21655) * [Bugfix] Fix effective_mamba_size over-allocation (#20858) Co-authored-by: Shangming Cai * Fix in-place mode in pause 
generation (#21705) * [diffusion] fix: respect --prompt-path (#21756) * [NPU] update ascend docs (#21807) * [VLM] remove AsyncMMDataProcessor wrapper (#21651) * Use CustomTestCase for TestSessionControl to enable CI retry (#21830) * [NPU]Add a full test pipeline on NPU, resolve issues in the NPU test architecture (#20751) * [diffusion][CI]: Add individual component accuracy CI for diffusion models (#18709) Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> * [Feature] JIT rmsnorm update (with claude) (#21834) * [Diffusion][NPU] add ring sp performance benchmark page in npu (#21811) * fix(MiMo-V2-Flash): add mimo reasoning parser (#21414) * [diffusion] hardware: support FA3 attention backend on MUSA (attn backend, 14/N) (#18648) Signed-off-by: Xiaodong Ye Co-authored-by: Mick * fix: pre-init tokenizer_manager to avoid AttributeError in shutdown (#21824) * [FlashInver v0.6.7] Integrate flashinfer_trtllm mxfp8 gemm (#21576) * [Misc] Add network timeout to eval dataset downloads (#21873) Co-authored-by: Claude Opus 4.6 (1M context) * [refactor] Clean up duplicate flashinfer trtllm moe code (#21233) * [DSA] Support trtllm sparse mla kernel for prefill batches (#21783) * [Disagg] GPU staging buffer with dynamic ring allocator for heterogeneous TP KV transfer (#19890) * Add merge prohibition policy during CI maintenance mode (#21882) * [Misc] Fix comparator e2e tests: add polars dep + fix dp-attention test (#21804) Co-authored-by: Alison Shao * revert: remove TTL-based hard pin from HiRadixCache (#21884) * Unify GSM8K eval path to Chat API for regression CI readiness (#21667) * [HiCache] fix: Clone host indices to avoid memory leak (#21624) Co-authored-by: Zhiqiang Xie * [HiCache & PD]Fixed detailed cache hit breakdown in PD scenarios. 
(#21764) * [CI] Add Llama 3.1 8B Instruct FP4 CI test on SM120 (#20648) * [CI] Add Per-Tensor, Blockwise FP8 Tests on SM120 (#20717) Co-authored-by: Brayden Zhong * Allow /rerun-test to checkout fork PR branch for trusted users (#21890) * Direct model loading from object storage with Runai Model Streamer (#17948) Signed-off-by: Noa Neria * fix pcg torch dynamo recompile in mxfp8 Triton path (#21888) Co-authored-by: Hanlin Bi * chore: bump mooncake version to 0.3.10.post1 (#21844) * [VLM] Add VLM TP=4 per-commit CI test and improve MMMU eval prompt/parser (#21841) * fix(ci): update est_time for 57 tests based on runtime analysis (#21896) Co-authored-by: Claude Opus 4.6 (1M context) * [CI] Increase multimodal server test timeout from 60 to 90 minutes (#21897) * [CI] Remove crashing Kimi K2.5 EAGLE3/MTP variants, keep TP8 and TP8+DP8 (#21898) * [diffusion] CI: add initial nvfp4 ci test for b200 (#21767) Co-authored-by: Mick * Migrate all callers from /get_server_info to /server_info (#21463) * Support PP key for file backend (#21901) * Enable multi-thread weight loading by default (#20289) * Skip Go stdlib and NVIDIA tool CVEs in Trivy scan (#21905) Co-authored-by: Claude Opus 4.6 (1M context) * [Kernel] Fuse temperature + softmax in sampling for decode speedup (#20501) * Multi tool streaming fix (#20004) * Return HTTP 400 for streaming validation errors (#21900) * [Spec][Ngram] 4/N: Remove `max_match_window_size` and `min_match_window_size`, matching all suffixes of the Trie (#21225) * Fix ngram doc for speculative_num_draft_tokens default (#21910) * [NVIDIA] Enable fp8 flashinfer_trtllm_routed MoE for MiniMax-M2.5 (#20394) * scheduler: add prefill-only update in merge batch (#21840) * [DSA] Set trtllm kernels as nsa default for Blackwell (#21914) * Revert "Rollback flashmla to older version [1/2]" (#21922) * test: add manual init test for mooncake transfer engine (#21842) Co-authored-by: yunzhi * Fix spec v2 + logprob when max_num_token is set (#20799) * Migrate 
ngram corpus from torch cpp_extension to TVM FFI jit_kernel (#21920) Co-authored-by: DarkSharpness <2040703891@qq.com> * [NPU] Support GLM-4.7-Flash on NPU (#21408) * [CI] Fix gpu deps import in cpu test (#21950) * [Parallel State Refactor 1/n] Remove stream of PyNCCL (#20866) * [diffusion] chore: fix stage profiler for multi-stage denoising (#21955) * [CI] [Tracing] Add ci for tracing and fix bugs (#21740) * Remove logging for subprocess watchdog start (#21968) * [4/n] Support gpt oss 20b lora (#21570) * [MUSA][9/N] Add FA3 attention backend support through MATE (MUSA AI Tensor Engine) (#17985) Co-authored-by: R0CKSTAR * [Feature] Stronger transformers modeling backend with TP, PP, MoE, VLMs, and torch compile (#19163) * [CI] Remove stale Ascend suite entries from test/srt/run_suite.py (#21978) Co-authored-by: Claude Opus 4.6 (1M context) * Skip broken AutoModel mapping entries when resolving Llava submodules (#21892) * [CI] Add timeouts to Slack upload urlopen and WebClient (#21903) Co-authored-by: Claude Opus 4.6 (1M context) * [Diffusion][NPU] Add support for MOVA (#21633) Co-authored-by: zhangshuai (S) * Remove maxItems=1 restriction when tool_choice is specified (#20208) * [Feature] NVFP4 Marlin fallback for non-Blackwell GPUs (SM75+) (#19652) * [PP] qwen3 vl skip layer id for pp (#19135) * [VLM] Enable per-image MM splitting by default and remove MULTI_IMAGES modality (#21899) * [Bugfix] Fix incorrect dp-attention parallel info in bench_one_batch (#21519) * Revert "[MUSA][9/N] Add FA3 attention backend support through MATE (MUSA AI Tensor Engine)" (#22002) * [NPU] Optimized the wording in the npu docs (#21998) * [Parallel State Refactor 2/n] Unify code path of AMD deterministic all reduce (#20871) * [AMD] Resolve the performance degression when launch server with "--enable-aiter-allreduce-fusion" (#21947) Co-authored-by: wunhuang * chore: bump sgl-kernel version to 0.4.1 (#21447) Co-authored-by: sglang-bot * [Workflow] Avoid triggering nightly tests in 
kernel bump workflow (#22010) * [Workflow] Fix kernel release jobs skipped on push events (#22011) Co-authored-by: Claude Opus 4.6 (1M context) * [PD]: Add support for HiSparse to directly transfer the cache from Prefill to Decode DRAM. (#21591) Co-authored-by: Tingwei Huang Co-authored-by: Shangming Cai Co-authored-by: Zhiqiang Xie * [Misc] Update CI permission (#22014) * [ROCM][RL] Shuffle Weight In-Place to Preserve Parameter Attributes (#21825) * [CI] Fix duplicate job names that bypass branch protection (#22001) * fix: remove duplicate words in comments (#22007) * [PD] Tiny register info field cleanup for mooncake backend (#22016) * [NPU] optimize glm4.7 (#19246) * [AMD] Enable FP8 KV cache and FP8 attention kernel for NSA on MI300/MI355 with TileLang backend (#21511) * [AMD] Add MiniMax-M2.5 nightly perf benchmarks for MI30x and MI35x (#21524) --------- Signed-off-by: Vladislav Nosivskoy Signed-off-by: P V R K Jyothendra Varma Signed-off-by: Xiaodong Ye Signed-off-by: Shangming Cai Signed-off-by: Chang Su Signed-off-by: Noa Neria Co-authored-by: Bingxu Chen Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: yang1002378395-cmyk Co-authored-by: Mick Co-authored-by: Bi Xue Co-authored-by: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com> Co-authored-by: Lianmin Zheng Co-authored-by: Baizhou Zhang Co-authored-by: Muqi Li Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Qiaolin Yu Co-authored-by: narutolhy <582909902@qq.com> Co-authored-by: Ethan (Yusheng) Su Co-authored-by: zhangxiaolei Co-authored-by: Vladislav Nosivskoy Co-authored-by: Trevor Morris Co-authored-by: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Co-authored-by: Fengyuan Yu Co-authored-by: Fengyuan Yu <15fengyuan@gmail.com> Co-authored-by: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com> Co-authored-by: Yuhao Yang 
<47235274+yhyang201@users.noreply.github.com> Co-authored-by: Liangsheng Yin Co-authored-by: Jianying <53503712+jianyingzhu@users.noreply.github.com> Co-authored-by: jianyingzhu Co-authored-by: Kangyan-Zhou Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Jacob0226 Co-authored-by: Aditya Sharma <89210949+adityavaid@users.noreply.github.com> Co-authored-by: Yuan Luo Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Co-authored-by: Артем Давкин <58187114+OrangeRedeng@users.noreply.github.com> Co-authored-by: Tamir Baydasov <41994229+TamirBaydasov@users.noreply.github.com> Co-authored-by: Shu Wang Co-authored-by: eigen <52445717+yyihuang@users.noreply.github.com> Co-authored-by: Avery Huang Co-authored-by: jacky.cheng Co-authored-by: HaiShaw Co-authored-by: luoyuan.luo Co-authored-by: Shangming Cai Co-authored-by: Junrong Lin <33685709+ocss884@users.noreply.github.com> Co-authored-by: Simon (Jiyou) Li Co-authored-by: 继优 Co-authored-by: shuwenn <47200617+alphabetc1@users.noreply.github.com> Co-authored-by: psaab Co-authored-by: hnyls2002 Co-authored-by: Hanlin Bi <52993433+wolfcomos@users.noreply.github.com> Co-authored-by: wili <98001977+wili-65535@users.noreply.github.com> Co-authored-by: saatwiknagpal Co-authored-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com> Co-authored-by: wan4ch Co-authored-by: Feng Su Co-authored-by: Ying Sheng Co-authored-by: Polisetty V R K Jyothendra Varma Co-authored-by: Ziang Li Co-authored-by: Aishwarya Ramasethu <56765596+aramasethu@users.noreply.github.com> Co-authored-by: Ma Mingfei Co-authored-by: blzheng Co-authored-by: kk <43161300+kkHuang-amd@users.noreply.github.com> Co-authored-by: wunhuang Co-authored-by: Michelle Wu Co-authored-by: wuxue (C) Co-authored-by: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Co-authored-by: strgrb Co-authored-by: LiYomi <106872109+LiYomi@users.noreply.github.com> Co-authored-by: mengxiancheng03 Co-authored-by: GXIN 
<37653830+gxxx-hum@users.noreply.github.com> Co-authored-by: 高鑫 Co-authored-by: heziiop Co-authored-by: xieminghe1 <141820649+xieminghe1@users.noreply.github.com> Co-authored-by: undefined Co-authored-by: Makcum888e <79456407+Makcum888e@users.noreply.github.com> Co-authored-by: yuefeng Wu <33725817+ChefWu551@users.noreply.github.com> Co-authored-by: Yuxuan Zhang <2448370773@qq.com> Co-authored-by: Vedant V Jhaveri Co-authored-by: ronnie_zheng Co-authored-by: Zhai Feiyue <80079571+ZhaiFeiyue@users.noreply.github.com> Co-authored-by: jhchouuu Co-authored-by: Alison Shao <54658187+alisonshao@users.noreply.github.com> Co-authored-by: Alison Shao Co-authored-by: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com> Co-authored-by: Alison Shao Co-authored-by: Lewis <63569348+TTThanos@users.noreply.github.com> Co-authored-by: 百麒 Co-authored-by: Jincong Chen Co-authored-by: xiazhahe <86939755+xiazhahe@users.noreply.github.com> Co-authored-by: Thomas Wang Co-authored-by: Ke Bao Co-authored-by: xiaoqi Co-authored-by: xiaoqi.31 Co-authored-by: R0CKSTAR Co-authored-by: weireweire Co-authored-by: Weiliangl User Co-authored-by: JD Co-authored-by: Zhangheng Co-authored-by: Michael <13900043+michaelzhang-ai@users.noreply.github.com> Co-authored-by: Yilong Zhao <74357408+happierpig@users.noreply.github.com> Co-authored-by: Johnsonms Co-authored-by: Brayden Zhong Co-authored-by: Chang Su Co-authored-by: KnightLTC <56717110+KnightLTC@users.noreply.github.com> Co-authored-by: Douglas Yang Co-authored-by: Karan Bansal Co-authored-by: karanb192 Co-authored-by: R0CKSTAR Co-authored-by: sglang-bot Co-authored-by: sglang-bot Co-authored-by: Xinyuan Tong Co-authored-by: sbeurnier Co-authored-by: YC Yen-Ching Tseng Co-authored-by: Wenyao Gao <105094497+edwingao28@users.noreply.github.com> Co-authored-by: Alex Nails Co-authored-by: khalilzhk Co-authored-by: Zhiqiang Xie Co-authored-by: yudian0504 <138860534+yudian0504@users.noreply.github.com> Co-authored-by: yunkchen 
Co-authored-by: wduan-hai Co-authored-by: amote-i <49533125+amote-i@users.noreply.github.com> Co-authored-by: Cherry_ming <136634645@qq.com> Co-authored-by: Ratish P <114130421+Ratish1@users.noreply.github.com> Co-authored-by: YAMY <74099316+YAMY1234@users.noreply.github.com> Co-authored-by: Alison Shao Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com> Co-authored-by: Derek Yu <81697272+DerekY2@users.noreply.github.com> Co-authored-by: Noa Neria Co-authored-by: Hanlin Bi Co-authored-by: Prozac614 Co-authored-by: David Cheung Co-authored-by: Mook <68294499+Godmook@users.noreply.github.com> Co-authored-by: Khoa Pham Co-authored-by: foraxe <73625538+foraxe@users.noreply.github.com> Co-authored-by: yunzhi Co-authored-by: DarkSharpness <2040703891@qq.com> Co-authored-by: Todobe <43903496+Todobe@users.noreply.github.com> Co-authored-by: ori <39351881+froststeam@users.noreply.github.com> Co-authored-by: Thomas Co-authored-by: zhangshuai (S) Co-authored-by: lviy <142899752+lviy@users.noreply.github.com> Co-authored-by: Tingwei Huang Co-authored-by: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com> Co-authored-by: Ricardo-M-L <69202550+Ricardo-M-L@users.noreply.github.com> Co-authored-by: Kelon Co-authored-by: cen121212 --- .claude/skills/ci-workflow-guide/SKILL.md | 2 +- .claude/skills/write-sglang-test/SKILL.md | 22 +- .github/CI_PERMISSIONS.json | 808 ++++---- .github/CODEOWNERS | 5 +- .github/MAINTAINER.md | 104 +- .github/actions/check-maintenance/action.yml | 4 +- .github/audit_permission.py | 411 +++++ .github/pull_request_template.md | 10 +- .github/update_ci_permission.py | 10 +- .github/workflows/amd-ci-job-monitor.yml | 205 +- .github/workflows/auto-tune.yml | 2 +- .../bot-bump-kernel-version-to-sglang.yml | 40 - .github/workflows/full-test-npu.yml | 355 ++++ .../workflows/nightly-test-amd-rocm720.yml | 63 +- .github/workflows/nightly-test-amd.yml | 28 +- .github/workflows/nightly-test-npu.yml | 160 +- 
.github/workflows/nightly-test-nvidia.yml | 4 +- .github/workflows/pr-test-amd-rocm720.yml | 8 +- .github/workflows/pr-test-amd.yml | 8 +- .github/workflows/pr-test-multimodal-gen.yml | 151 ++ .github/workflows/pr-test-npu.yml | 114 +- .github/workflows/pr-test-xeon.yml | 6 +- .github/workflows/pr-test-xpu.yml | 6 +- .github/workflows/pr-test.yml | 56 +- .../workflows/release-docker-npu-nightly.yml | 2 +- .github/workflows/release-docker-runtime.yml | 309 ++++ .github/workflows/release-docker.yml | 176 +- .github/workflows/release-pypi-nightly.yml | 8 +- .github/workflows/release-whl-kernel.yml | 8 +- .../{rerun-ut.yml => rerun-test.yml} | 8 +- .github/workflows/slash-command-handler.yml | 28 +- .github/workflows/trivy-scan-dev.yml | 85 + .pre-commit-config.yaml | 6 + 3rdparty/amd/wheel/sglang/pyproject.toml | 2 +- README.md | 4 +- .../bench_cutedsl_kda_decode.py | 21 +- benchmark/hicache/bench_mix.py | 9 +- .../bench_fused_temperature_softmax.py | 108 ++ .../kernels/fused_moe_triton/common_utils.py | 11 +- .../tuning_fused_moe_triton.py | 13 +- .../quantization/tuning_block_wise_kernel.py | 42 +- docker/Dockerfile | 19 +- docker/rocm.Dockerfile | 4 +- docs/advanced_features/object_storage.md | 108 ++ docs/advanced_features/quantization.md | 124 +- docs/advanced_features/server_arguments.md | 13 +- docs/advanced_features/sgl_model_gateway.md | 4 +- .../advanced_features/speculative_decoding.md | 11 +- docs/basic_usage/deepseek_v3.md | 2 +- docs/basic_usage/native_api.ipynb | 4 +- docs/developer_guide/bench_serving.md | 2 +- docs/developer_guide/contribution_guide.md | 25 +- docs/diffusion/api/cli.md | 37 + docs/diffusion/installation.md | 2 +- .../performance/attention_backends.md | 4 +- docs/diffusion/performance/index.md | 6 + .../performance/ring_sp_performance.md | 67 + docs/diffusion/quantization.md | 68 +- docs/get_started/install.md | 2 +- docs/index.rst | 5 +- .../{ => ascend}/ascend_contribution_guide.md | 17 +- docs/platforms/{ => ascend}/ascend_npu.md | 
2 +- .../{ => ascend}/ascend_npu_best_practice.md | 353 ++-- .../ascend_npu_deepseek_example.md | 1 - .../ascend_npu_environment_variables.md | 0 .../{ => ascend}/ascend_npu_glm5_examples.md | 2 +- .../ascend/ascend_npu_quantization.md | 52 + .../ascend_npu_qwen3_5_examples.md | 2 +- .../{ => ascend}/ascend_npu_qwen3_examples.md | 91 + .../{ => ascend}/ascend_npu_support.rst | 2 + .../ascend_npu_support_features.md | 1 - .../{ => ascend}/ascend_npu_support_models.md | 0 .../{ => ascend}/mindspore_backend.md | 0 docs/platforms/ascend_npu_quantization.md | 27 - .../ascend_npu_ring_sp_performance.md | 55 + docs/references/environment_variables.md | 5 +- docs/references/production_metrics.md | 38 +- .../extending/mindspore_models.md | 2 +- python/pyproject.toml | 8 +- python/pyproject_npu.toml | 2 + python/pyproject_other.toml | 3 +- python/sglang/_triton_stub.py | 30 +- python/sglang/bench_one_batch.py | 4 +- python/sglang/bench_serving.py | 36 +- python/sglang/benchmark/datasets/__init__.py | 2 + python/sglang/benchmark/datasets/image.py | 21 +- .../sglang/benchmark/datasets/longbench_v2.py | 104 ++ python/sglang/check_env.py | 20 +- python/sglang/cli/killall.py | 14 +- python/sglang/cli/utils.py | 27 +- python/sglang/jit_kernel/all_reduce.py | 4 +- .../sglang/jit_kernel/benchmark/bench_cast.py | 28 +- .../benchmark/bench_fused_add_rmsnorm.py | 75 - .../benchmark/bench_fused_qknorm_rope.py | 92 +- .../sglang/jit_kernel/benchmark/bench_norm.py | 83 +- .../benchmark/bench_nvfp4_scaled_mm.py | 6 +- .../jit_kernel/benchmark/bench_renorm.py | 83 - .../jit_kernel/benchmark/bench_rmsnorm.py | 98 - .../diffusion/bench_fused_norm_scale_shift.py | 6 +- python/sglang/jit_kernel/benchmark/utils.py | 19 +- .../csrc/elementwise/fused_qknorm_rope.cuh | 149 +- .../jit_kernel/csrc/elementwise/rmsnorm.cuh | 180 ++ .../csrc/gemm/marlin/marlin_template.h | 25 +- .../csrc/gemm/marlin_moe/marlin_template.h | 53 +- .../csrc/gemm/marlin_moe/moe_wna16_marlin.cuh | 2 + 
.../gemm/nvfp4/nvfp4_scaled_mm_common.cuh | 66 + .../gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh | 598 +----- .../csrc/gemm/nvfp4/nvfp4_scaled_mm_sm100.cuh | 284 +++ .../csrc/gemm/nvfp4/nvfp4_scaled_mm_sm120.cuh | 228 +++ python/sglang/jit_kernel/csrc/hicache.cuh | 166 +- python/sglang/jit_kernel/csrc/hisparse.cuh | 112 +- .../csrc/ngram_corpus}/ngram.cpp | 34 +- .../csrc/ngram_corpus}/ngram.h | 9 +- .../csrc/ngram_corpus/ngram_corpus_ffi.cpp | 104 ++ .../csrc/ngram_corpus}/param.h | 22 - .../csrc/ngram_corpus}/queue.h | 0 .../csrc/ngram_corpus}/result.cpp | 0 .../csrc/ngram_corpus}/result.h | 0 .../csrc/ngram_corpus}/trie.cpp | 29 +- .../csrc/ngram_corpus}/trie.h | 8 +- .../diffusion/triton/npu_fallback.py | 21 + .../diffusion/triton/scale_shift.py | 42 +- python/sglang/jit_kernel/fused_qknorm_rope.py | 41 +- python/sglang/jit_kernel/hicache.py | 64 + python/sglang/jit_kernel/moe_wna16_marlin.py | 24 +- python/sglang/jit_kernel/ngram_corpus.py | 88 + python/sglang/jit_kernel/norm.py | 21 +- python/sglang/jit_kernel/tests/test_cast.py | 4 + .../tests/test_custom_all_reduce.py | 56 +- .../tests/test_fused_qknorm_rope.py | 8 +- .../sglang/jit_kernel/tests/test_hicache.py | 247 +++ .../sglang/jit_kernel/tests/test_norm_jit.py | 145 -- python/sglang/jit_kernel/tests/test_renorm.py | 39 - .../sglang/jit_kernel/tests/test_rmsnorm.py | 95 +- python/sglang/jit_kernel/utils.py | 66 +- .../sglang/lang/backend/runtime_endpoint.py | 4 +- .../benchmark-and-profile.md | 122 +- .../scripts/bench_diffusion_denoise.py | 181 +- .../benchmarks/bench_serving.py | 3 +- .../configs/models/vaes/qwenimage.py | 2 + .../configs/pipeline_configs/base.py | 8 +- .../configs/pipeline_configs/flux.py | 25 +- .../configs/pipeline_configs/hunyuan.py | 8 +- .../configs/pipeline_configs/ltx_2.py | 5 +- .../configs/pipeline_configs/sana.py | 5 +- .../configs/pipeline_configs/wan.py | 2 +- .../multimodal_gen/configs/sample/hunyuan.py | 1 + .../configs/sample/sampling_params.py | 5 + 
.../multimodal_gen/configs/sample/teacache.py | 95 +- .../multimodal_gen/configs/sample/wan.py | 123 +- python/sglang/multimodal_gen/registry.py | 28 +- .../multimodal_gen/runtime/cache/teacache.py | 2 +- .../runtime/entrypoints/cli/generate.py | 18 +- .../entrypoints/diffusion_generator.py | 15 +- .../runtime/entrypoints/http_server.py | 6 +- .../runtime/entrypoints/openai/common_api.py | 12 +- .../runtime/entrypoints/openai/protocol.py | 2 + .../runtime/entrypoints/openai/video_api.py | 2 + .../entrypoints/post_training/weights_api.py | 12 +- .../layers/attention/backends/ascend_fa.py | 105 ++ .../multimodal_gen/runtime/layers/usp.py | 2 +- .../component_loaders/text_encoder_loader.py | 43 +- .../runtime/loader/fsdp_load.py | 94 +- .../runtime/loader/weight_utils.py | 7 +- .../runtime/models/dits/flux_2.py | 25 +- .../runtime/models/dits/mova_video_dit.py | 9 +- .../runtime/models/dits/qwen_image.py | 138 +- .../runtime/models/dits/wanvideo.py | 23 +- .../models/vaes/autoencoder_kl_qwenimage.py | 156 +- .../runtime/models/vaes/common.py | 201 +- .../pipelines_core/stages/denoising.py | 1 + .../pipelines_core/stages/denoising_av.py | 1 + .../pipelines_core/stages/denoising_dmd.py | 1 + .../pipelines_core/stages/input_validation.py | 24 +- .../stages/model_specific_stages/mova.py | 39 +- .../pipelines_core/stages/text_encoding.py | 10 +- .../multimodal_gen/runtime/platforms/musa.py | 59 +- .../multimodal_gen/runtime/platforms/npu.py | 4 + .../multimodal_gen/runtime/server_args.py | 10 + .../runtime/utils/hf_diffusers_utils.py | 29 +- .../runtime/utils/logging_utils.py | 57 +- .../runtime/utils/model_overlay.py | 650 +++++++ .../runtime/utils/perf_logger.py | 19 +- .../sglang/multimodal_gen/test/run_suite.py | 11 +- .../test/scripts/gen_diffusion_ci_outputs.py | 5 +- .../test/server/accuracy_config.py | 386 ++++ .../test/server/accuracy_hooks.py | 579 ++++++ .../test/server/accuracy_utils.py | 880 +++++++++ .../server/ascend/perf_baselines_npu.json | 65 + 
.../server/ascend/testcase_configs_npu.py | 12 + .../test/server/component_accuracy.py | 534 ++++++ .../test/server/perf_baselines.json | 383 ++-- .../test/server/test_accuracy_1_gpu_a.py | 37 + .../test/server/test_accuracy_1_gpu_b.py | 37 + .../test/server/test_accuracy_2_gpu_a.py | 37 + .../test/server/test_accuracy_2_gpu_b.py | 37 + .../test/server/test_server_c.py | 28 + .../test/server/test_server_utils.py | 37 +- .../test/server/testcase_configs.py | 2 +- .../sglang/multimodal_gen/test/slack_utils.py | 4 +- .../test/unit/test_input_validation.py | 164 ++ .../test/unit/test_resolve_prompts.py | 99 + .../test/unit/test_sampling_params.py | 65 +- python/sglang/profiler.py | 2 +- .../compilation/piecewise_context_manager.py | 3 + python/sglang/srt/configs/load_config.py | 1 + python/sglang/srt/configs/model_config.py | 43 +- python/sglang/srt/constants.py | 2 + python/sglang/srt/disaggregation/base/conn.py | 12 +- .../sglang/srt/disaggregation/common/conn.py | 62 +- .../disaggregation/common/staging_buffer.py | 768 ++++++++ .../disaggregation/common/staging_handler.py | 732 ++++++++ python/sglang/srt/disaggregation/decode.py | 269 ++- .../srt/disaggregation/encode_grpc_server.py | 2 +- .../srt/disaggregation/encode_receiver.py | 15 +- .../srt/disaggregation/encode_server.py | 9 +- python/sglang/srt/disaggregation/fake/conn.py | 25 +- .../srt/disaggregation/mooncake/conn.py | 563 +++++- .../srt/disaggregation/mooncake/utils.py | 2 +- python/sglang/srt/disaggregation/mori/conn.py | 27 +- python/sglang/srt/disaggregation/nixl/conn.py | 18 +- python/sglang/srt/disaggregation/prefill.py | 33 + python/sglang/srt/disaggregation/utils.py | 27 + .../device_communicators/custom_all_reduce.py | 118 +- .../device_communicators/pynccl.py | 80 +- .../sglang/srt/distributed/parallel_state.py | 61 +- python/sglang/srt/entrypoints/engine.py | 153 +- .../engine_info_bootstrap_server.py | 105 ++ python/sglang/srt/entrypoints/grpc_server.py | 8 +- 
python/sglang/srt/entrypoints/http_server.py | 131 +- .../sglang/srt/entrypoints/openai/protocol.py | 22 + .../srt/entrypoints/openai/serving_chat.py | 36 +- .../entrypoints/openai/serving_completions.py | 23 +- .../openai/serving_transcription.py | 121 +- python/sglang/srt/environ.py | 46 +- .../srt/function_call/base_format_detector.py | 23 +- .../srt/function_call/function_call_parser.py | 8 +- python/sglang/srt/function_call/utils.py | 13 +- .../npu/attention/ascend_backend.py | 7 + .../modules/deepseek_v2_attention_mla_npu.py | 90 +- .../sglang/srt/hardware_backend/npu/utils.py | 102 +- .../srt/layers/attention/aiter_backend.py | 5 +- .../sglang/srt/layers/attention/fla/chunk.py | 34 +- .../srt/layers/attention/fla/chunk_fwd.py | 416 +++++ .../srt/layers/attention/fla/chunk_intra.py | 661 +++++++ .../fla/chunk_intra_token_parallel.py | 197 ++ .../layers/attention/fla/fused_recurrent.py | 3 + python/sglang/srt/layers/attention/fla/kda.py | 22 +- .../sglang/srt/layers/attention/fla/utils.py | 14 + .../srt/layers/attention/fla/wy_fast.py | 6 +- .../layers/attention/flashinfer_backend.py | 60 +- .../srt/layers/attention/mamba/mamba.py | 2 +- .../srt/layers/attention/nsa/nsa_indexer.py | 146 +- .../layers/attention/nsa/tilelang_kernel.py | 349 +++- .../srt/layers/attention/nsa_backend.py | 11 + .../layers/attention/trtllm_mha_backend.py | 3 + .../layers/attention/trtllm_mla_backend.py | 4 + python/sglang/srt/layers/attention/utils.py | 735 +++++++- python/sglang/srt/layers/attention/vision.py | 6 +- python/sglang/srt/layers/fused_sampling.py | 371 ++++ python/sglang/srt/layers/linear.py | 2 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 20 - .../moe/fused_moe_triton/fused_marlin_moe.py | 43 +- .../srt/layers/moe/fused_moe_triton/layer.py | 319 ---- .../moe/moe_runner/flashinfer_trtllm.py | 10 +- .../srt/layers/moe/moe_runner/marlin.py | 10 +- .../layers/moe/token_dispatcher/standard.py | 12 +- python/sglang/srt/layers/moe/utils.py | 4 + 
.../schemes/compressed_tensors_w4a4_nvfp4.py | 38 +- .../compressed_tensors_w4a4_nvfp4_moe.py | 74 +- .../srt/layers/quantization/fp4_utils.py | 28 +- python/sglang/srt/layers/quantization/fp8.py | 54 +- .../srt/layers/quantization/fp8_kernel.py | 4 +- .../srt/layers/quantization/fp8_utils.py | 145 +- .../srt/layers/quantization/marlin_utils.py | 2 +- .../layers/quantization/marlin_utils_fp4.py | 320 ++++ .../srt/layers/quantization/modelopt_quant.py | 116 +- .../quantization/modelslim/modelslim.py | 4 +- .../sglang/srt/layers/quantization/unquant.py | 16 +- python/sglang/srt/layers/rocm_linear_utils.py | 26 +- .../srt/layers/rotary_embedding/base.py | 94 +- .../srt/layers/rotary_embedding/factory.py | 3 + .../srt/layers/rotary_embedding/mrope.py | 35 + .../layers/rotary_embedding/triton_kernels.py | 18 +- .../srt/layers/rotary_embedding/utils.py | 6 +- python/sglang/srt/layers/sampler.py | 37 +- python/sglang/srt/lora/layers.py | 103 +- python/sglang/srt/lora/lora.py | 31 +- python/sglang/srt/lora/lora_config.py | 34 +- python/sglang/srt/lora/lora_manager.py | 97 +- python/sglang/srt/lora/lora_moe_runners.py | 27 +- python/sglang/srt/lora/mem_pool.py | 186 +- .../srt/lora/triton_ops/sgemm_lora_b.py | 7 +- python/sglang/srt/lora/utils.py | 65 +- .../srt/managers/async_mm_data_processor.py | 122 -- .../srt/managers/detokenizer_manager.py | 104 +- .../srt/managers/hisparse_coordinator.py | 65 +- python/sglang/srt/managers/io_struct.py | 93 +- python/sglang/srt/managers/mm_utils.py | 263 +-- .../srt/managers/multi_tokenizer_mixin.py | 12 - .../srt/managers/multimodal_processor.py | 32 +- python/sglang/srt/managers/schedule_batch.py | 65 +- python/sglang/srt/managers/scheduler.py | 234 +-- .../scheduler_output_processor_mixin.py | 69 +- .../scheduler_runtime_checker_mixin.py | 5 +- .../managers/tokenizer_communicator_mixin.py | 22 - .../sglang/srt/managers/tokenizer_manager.py | 272 +-- python/sglang/srt/managers/tp_worker.py | 6 - 
.../srt/mem_cache/hi_mamba_radix_cache.py | 77 +- .../sglang/srt/mem_cache/hicache_storage.py | 14 +- python/sglang/srt/mem_cache/hiradix_cache.py | 131 +- .../srt/mem_cache/hisparse_memory_pool.py | 27 + .../sglang/srt/mem_cache/mamba_radix_cache.py | 16 +- python/sglang/srt/mem_cache/memory_pool.py | 63 +- .../sglang/srt/mem_cache/memory_pool_host.py | 217 ++- python/sglang/srt/mem_cache/radix_cache.py | 4 - python/sglang/srt/mem_cache/utils.py | 87 + .../srt/model_executor/cuda_graph_runner.py | 7 +- .../srt/model_executor/forward_batch_info.py | 22 +- .../sglang/srt/model_executor/model_runner.py | 96 +- .../model_runner_kv_cache_mixin.py | 36 +- .../piecewise_cuda_graph_runner.py | 27 +- python/sglang/srt/model_loader/loader.py | 246 ++- .../remote_instance_weight_loader_utils.py | 15 - python/sglang/srt/model_loader/utils.py | 156 +- .../sglang/srt/model_loader/weight_utils.py | 38 +- .../attention_forward_methods/forward_mla.py | 97 +- .../deepseek_common/deepseek_weight_loader.py | 3 +- .../srt/models/deepseek_common/utils.py | 2 + python/sglang/srt/models/deepseek_nextn.py | 24 +- python/sglang/srt/models/deepseek_v2.py | 16 +- python/sglang/srt/models/glm4_moe.py | 79 +- python/sglang/srt/models/glm4_moe_nextn.py | 21 +- python/sglang/srt/models/glm4v.py | 2 - python/sglang/srt/models/gpt_oss.py | 8 + python/sglang/srt/models/grok.py | 6 +- python/sglang/srt/models/llava.py | 101 +- python/sglang/srt/models/minicpmv.py | 18 +- python/sglang/srt/models/minimax_m2.py | 6 +- python/sglang/srt/models/qwen2.py | 1 + python/sglang/srt/models/qwen3.py | 106 +- python/sglang/srt/models/qwen3_5.py | 72 +- python/sglang/srt/models/qwen3_moe.py | 2 + python/sglang/srt/models/qwen3_next.py | 79 +- python/sglang/srt/models/qwen3_vl.py | 13 + python/sglang/srt/models/qwen3_vl_moe.py | 15 +- python/sglang/srt/models/transformers.py | 1641 +++++++++++++++-- python/sglang/srt/models/utils.py | 205 +- python/sglang/srt/models/whisper.py | 123 +- 
.../multimodal/processors/base_processor.py | 50 +- .../srt/multimodal/processors/internvl.py | 37 +- .../srt/multimodal/processors/kimi_k25.py | 1 + .../srt/multimodal/processors/kimi_vl.py | 1 + .../sglang/srt/multimodal/processors/llava.py | 35 +- .../srt/multimodal/processors/minicpm.py | 27 +- .../multimodal/processors/nano_nemotron_vl.py | 3 + .../srt/multimodal/processors/pixtral.py | 1 + .../srt/multimodal/processors/qwen_audio.py | 2 +- .../srt/multimodal/processors/qwen_vl.py | 3 +- .../processors/transformers_auto.py | 215 +++ .../srt/multimodal/processors/whisper.py | 29 +- .../srt/multimodal/vit_cuda_graph_runner.py | 8 +- .../srt/observability/metrics_collector.py | 43 + .../observability/request_metrics_exporter.py | 3 +- .../observability/scheduler_metrics_mixin.py | 187 +- python/sglang/srt/parser/reasoning_parser.py | 1 + python/sglang/srt/ray/engine.py | 20 +- python/sglang/srt/ray/http_server.py | 19 +- .../srt/sampling/penaltylib/__init__.py | 2 + .../srt/sampling/penaltylib/orchestrator.py | 62 +- .../sampling/penaltylib/repetition_penalty.py | 78 + .../srt/sampling/sampling_batch_info.py | 27 +- python/sglang/srt/server_args.py | 290 ++- .../srt/speculative/cpp_ngram/ngram_corpus.py | 48 +- .../cpp_ngram/ngram_corpus_binding.cpp | 43 - python/sglang/srt/speculative/eagle_info.py | 16 +- .../sglang/srt/speculative/eagle_info_v2.py | 8 +- .../sglang/srt/speculative/eagle_worker_v2.py | 8 +- python/sglang/srt/speculative/ngram_info.py | 21 +- python/sglang/srt/speculative/ngram_worker.py | 7 +- python/sglang/srt/utils/common.py | 335 ++-- .../srt/utils/cuda_ipc_transport_utils.py | 225 ++- .../sglang/srt/utils/hf_transformers_utils.py | 60 + python/sglang/srt/utils/numa_utils.py | 172 +- python/sglang/srt/utils/request_logger.py | 14 +- python/sglang/srt/utils/runai_utils.py | 134 ++ python/sglang/srt/utils/watchdog.py | 64 +- python/sglang/test/accuracy_test_runner.py | 297 ++- .../sglang/test/ascend/test_ascend_utils.py | 18 +- 
.../test/bench_one_batch_server_internal.py | 4 +- python/sglang/test/few_shot_gsm8k.py | 12 + python/sglang/test/few_shot_gsm8k_engine.py | 14 + python/sglang/test/kits/cache_hit_kit.py | 2 +- python/sglang/test/kits/eval_accuracy_kit.py | 30 +- python/sglang/test/kl_test_utils.py | 4 +- python/sglang/test/lora_utils.py | 5 +- python/sglang/test/nightly_utils.py | 2 +- python/sglang/test/run_combined_tests.py | 4 +- python/sglang/test/run_eval.py | 56 +- python/sglang/test/runners.py | 21 +- .../server_fixtures/disaggregation_fixture.py | 1 + python/sglang/test/simple_eval_common.py | 74 +- python/sglang/test/simple_eval_gpqa.py | 5 +- python/sglang/test/simple_eval_math.py | 5 +- python/sglang/test/simple_eval_mgsm.py | 2 +- python/sglang/test/simple_eval_mmlu.py | 5 +- python/sglang/test/simple_eval_mmmu_vlm.py | 22 +- python/sglang/test/test_mm_utils.py | 50 + python/sglang/test/test_utils.py | 44 +- python/sglang/test/vlm_utils.py | 2 +- python/sglang/utils.py | 61 +- scripts/ci/check_workflow_job_names.py | 58 + scripts/ci/cuda/cache_nvidia_wheels.sh | 24 + .../cuda/ci_download_flashinfer_jit_cache.sh | 2 - scripts/ci/cuda/ci_install_dependency.sh | 31 +- .../diffusion/generate_diffusion_dashboard.py | 15 +- scripts/ci/utils/query_job_status.py | 937 +++++++++- scripts/ci/utils/slash_command_handler.py | 51 +- scripts/ci_monitor/ci_failures_analysis.py | 71 +- scripts/playground/bench_speculative.py | 2 +- sgl-kernel/CMakeLists.txt | 2 - .../bench_amd_deterministic_allreduce.py | 13 +- sgl-kernel/benchmark/bench_fp4_gemm.py | 300 ++- .../benchmark/bench_nvfp4_scaled_gemm.py | 192 -- sgl-kernel/cmake/flashmla.cmake | 58 +- sgl-kernel/csrc/attention/cascade.cu | 55 - sgl-kernel/csrc/common_extension.cc | 10 - sgl-kernel/csrc/common_extension_musa.cc | 3 - sgl-kernel/csrc/cpu/common.h | 6 +- sgl-kernel/csrc/cpu/gemm.cpp | 84 +- sgl-kernel/csrc/cpu/gemm.h | 30 + sgl-kernel/csrc/cpu/gemm_fp8.cpp | 469 ++++- sgl-kernel/csrc/cpu/rope.cpp | 270 +++ 
sgl-kernel/csrc/cpu/topk.cpp | 8 +- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 20 + sgl-kernel/csrc/cpu/vec.h | 43 + sgl-kernel/csrc/elementwise/cast.cu | 172 -- sgl-kernel/csrc/gemm/marlin/marlin_template.h | 25 +- sgl-kernel/include/sgl_kernel_ops.h | 16 - sgl-kernel/pyproject.toml | 2 +- sgl-kernel/pyproject_cpu.toml | 2 +- sgl-kernel/pyproject_musa.toml | 4 +- sgl-kernel/pyproject_rocm.toml | 2 +- sgl-kernel/python/sgl_kernel/__init__.py | 6 - .../python/sgl_kernel/_fa4_interface.py | 940 ---------- sgl-kernel/python/sgl_kernel/attention.py | 19 - sgl-kernel/python/sgl_kernel/elementwise.py | 16 - sgl-kernel/python/sgl_kernel/flash_mla.py | 9 + sgl-kernel/python/sgl_kernel/sampling.py | 73 - sgl-kernel/python/sgl_kernel/version.py | 2 +- ...test_amd_deterministic_custom_allreduce.py | 39 +- sgl-kernel/tests/test_hadamard.py | 86 - sgl-kernel/tests/test_merge_state.py | 143 -- sgl-kernel/tests/test_merge_state_v2.py | 34 +- sgl-model-gateway/README.md | 2 +- .../python/src/sglang_router/mini_lb.py | 6 +- .../e2e_test/infra/simple_eval_common.py | 2 +- .../e2e_test/infra/simple_eval_mmlu.py | 5 +- .../src/routers/http/pd_router.rs | 2 +- sgl-model-gateway/src/routers/http/router.rs | 2 +- sgl-model-gateway/src/server.rs | 2 + .../tests/api/api_endpoints_test.rs | 4 +- sgl-model-gateway/tests/common/mock_worker.rs | 2 +- .../tests/common/tls_mock_worker.rs | 2 +- .../tests/routing/test_pd_routing.rs | 2 +- .../ep/test_moe_deepep_eval_accuracy_large.py | 18 +- test/manual/ep/test_mooncake_expert_backup.py | 21 +- test/manual/ep/test_nixl_ep.py | 22 +- test/manual/hicache/test_pp_with_hicache.py | 26 +- .../test_mooncake_transfer_engine_init.py | 430 +++++ test/manual/lora/test_lora_qwen3_vl.py | 233 --- test/manual/models/test_falcon_h1_models.py | 66 +- test/manual/models/test_grok_models.py | 16 +- test/manual/models/test_kimi_k2_models.py | 22 +- test/manual/models/test_llama4_models.py | 17 +- .../models/test_mistral_large3_basic.py | 21 +- 
test/manual/models/test_mtp_models.py | 18 +- test/manual/models/test_unsloth_models.py | 98 +- ...est_disaggregation_piecewise_cuda_graph.py | 22 +- test/manual/test_async_mm_data_processor.py | 365 ---- .../test_cross_node_scheduler_info_sync.py | 204 ++ test/manual/test_mla_tp.py | 38 +- .../test_torch_flex_attention_backend.py | 18 +- test/manual/test_tracing.py | 314 ---- test/manual/test_whisper_cuda_graph.py | 161 ++ .../test_deepseek_v3_cutedsl_4gpu.py | 35 +- .../4-gpu-models/test_qwen35_models.py | 4 +- .../8-gpu-models/test_deepseek_v32_basic.py | 40 +- .../8-gpu-models/test_deepseek_v32_mtp.py | 86 +- .../8-gpu-models/test_deepseek_v3_basic.py | 21 +- .../8-gpu-models/test_deepseek_v3_mtp.py | 24 +- test/registered/8-gpu-models/test_kimi_k25.py | 27 +- .../8-gpu-models/test_mimo_models.py | 2 +- .../8-gpu-models/test_ring_2_5_1t.py | 2 + .../mi30x/test_deepseek_v32_mtp_eval_amd.py | 2 +- .../mi35x/test_deepseek_v32_mtp_eval_mi35x.py | 2 +- .../mi35x/test_glm47_fp8_eval_mi35x.py | 59 + .../perf/mi30x/test_minimax_m25_perf_amd.py | 140 ++ .../perf/mi35x/test_minimax_m25_perf_mi35x.py | 146 ++ .../amd/test_deepseek_r1_mxfp4_8gpu.py | 2 +- test/registered/amd/test_deepseek_v32_mtp.py | 4 +- test/registered/amd/test_deepseek_v3_mtp.py | 2 +- .../amd/test_deepseek_v3_mtp_kv_fp8.py | 2 +- test/registered/amd/test_moriep_small.py | 6 +- .../amd/test_qwen3_coder_next_8gpu.py | 2 +- .../HiCache/test_npu_hicache_mha.py} | 4 + .../HiCache/test_npu_hicache_mla.py} | 4 + .../backends/test_npu_sampling_backend.py} | 4 + .../dllm/test_npu_llada2_mini.py} | 4 + .../test_npu_compile_graph_tp1_bf16.py} | 4 + .../test_npu_graph_tp1_bf16.py} | 4 + .../test_npu_graph_tp2_bf16.py} | 4 + .../test_npu_piecewise_graph_prefill.py} | 4 + .../expert_parallelism/test_npu_deepep.py} | 4 + .../parameter/test_npu_warmups.py | 2 +- .../quant/test_npu_autoround_dense.py} | 7 +- .../quant/test_npu_autoround_moe.py} | 9 +- .../quant/test_npu_gptq_moe.py} | 9 +- 
.../quant/test_npu_w4a4_quantization.py} | 4 + .../quant/test_npu_w8a8_quantization.py} | 4 + .../test_npu_mla_fia_w8a8int8.py} | 4 + .../runtime_opts/test_npu_mla_w8a8int8.py} | 4 + .../runtime_opts/test_npu_tp1_bf16.py} | 4 + .../runtime_opts/test_npu_tp2_bf16.py} | 4 + .../runtime_opts/test_npu_tp2_fia_bf16.py} | 4 + .../runtime_opts/test_npu_tp4_bf16.py} | 4 + .../test_npu_bge_large_en_v1_5.py | 2 +- .../ascend/llm_models/test_npu_afm_4_5b.py | 2 +- .../llm_models/test_npu_c4ai_command_r_v01.py | 2 +- .../ascend/llm_models/test_npu_exaone_3.py | 2 +- .../test_npu_granite_3_0_3b_a800m.py | 2 +- .../llm_models/test_npu_granite_3_1_8b.py | 2 +- .../ascend/llm_models/test_npu_grok_2.py | 2 +- .../ascend/llm_models/test_npu_ling_lite.py | 2 +- .../ascend/llm_models/test_npu_mimo_7b_rl.py | 2 +- .../llm_models/test_npu_persimmon_8b_chat.py | 2 +- .../ascend/llm_models/test_npu_smollm_1_7b.py | 2 +- .../llm_models/test_npu_stablelm_2_1_6b.py | 2 +- .../test_npu_gemma_2_27b_v0_2.py | 2 +- .../test_npu_internlm2_7b_reward.py | 4 +- .../test_npu_llama_3_1_8b_v0_2.py | 2 +- .../ascend/vlm_models/test_ascend_glm_4_5v.py | 33 + .../test_npu_qwen3_vl_4b_instruct.py | 2 +- .../attention/test_chunk_gated_delta_rule.py | 2 +- test/registered/attention/test_fa3.py | 22 +- .../attention/test_flash_attention_4.py | 20 +- .../attention/test_hybrid_attn_backend.py | 27 +- test/registered/attention/test_local_attn.py | 19 +- .../test_deepseek_r1_fp8_trtllm_backend.py | 20 +- .../test_deepseek_v3_fp4_cutlass_moe.py | 21 +- ...test_flashinfer_trtllm_gen_attn_backend.py | 18 +- .../test_flashinfer_trtllm_gen_moe_backend.py | 50 +- .../backends/test_qwen3_fp4_trtllm_gen_moe.py | 17 +- .../test_bench_serving_functionality.py | 3 +- test/registered/core/test_srt_endpoint.py | 4 +- .../cp/test_deepseek_v32_cp_single_node.py | 42 +- .../test_engine_dumper_comparator_e2e.py | 32 +- .../test_disaggregation_basic.py | 77 +- .../test_disaggregation_decode_offload.py | 8 +- 
.../distributed/test_data_parallelism.py | 14 +- .../test_disaggregation_aarch64.py | 20 +- .../test_disaggregation_different_tp.py | 74 +- .../test_disaggregation_dp_attention.py | 63 +- .../test_disaggregation_hybrid_attention.py | 58 +- .../distributed/test_disaggregation_pp.py | 50 +- .../distributed/test_dp_attention.py | 32 +- .../distributed/test_dp_attention_large.py | 23 +- .../test_load_weights_from_remote_instance.py | 4 +- .../distributed/test_pp_single_node.py | 135 +- test/registered/dllm/test_llada2_mini.py | 20 +- test/registered/dllm/test_llada2_mini_amd.py | 18 +- .../embedding/test_embedding_models.py | 2 +- test/registered/ep/test_deepep_large.py | 60 +- test/registered/ep/test_deepep_small.py | 194 +- test/registered/ep/test_mooncake_ep_small.py | 23 +- .../eval/test_eval_accuracy_large.py | 2 +- .../eval/test_moe_eval_accuracy_large.py | 62 - .../function_call/test_kimik2_detector.py | 2 +- test/registered/gb300/test_deepseek_v32.py | 79 + .../gb300/test_deepseek_v32_nvfp4.py | 82 + test/registered/gb300/test_glm5_fp8.py | 68 + test/registered/gb300/test_glm5_nvfp4.py | 71 + test/registered/gb300/test_kimi_k25.py | 58 + test/registered/gb300/test_kimi_k25_nvfp4.py | 61 + test/registered/gb300/test_qwen35_fp8.py | 75 + test/registered/gb300/test_qwen35_nvfp4.py | 79 + .../test_hicache_storage_file_backend.py | 25 +- test/registered/kernels/test_nsa_indexer.py | 2 +- test/registered/language/test_srt_backend.py | 1 + .../lora/test_fused_moe_lora_kernel.py | 2 +- .../test_lora_gpt_oss_20b_logprob_diff.py | 151 ++ .../test_lora_moe_vllm_sgl_logprob_diff.py | 4 +- ...wen3_30b_a3b_instruct_2507_logprob_diff.py | 151 ++ .../lora/test_lora_qwen3_8b_logprob_diff.py | 202 ++ ..._qwen3_vl_30b_a3b_instruct_logprob_diff.py | 151 ++ test/registered/lora/test_lora_tp.py | 5 +- test/registered/mla/test_flashmla.py | 38 +- test/registered/mla/test_mla_deepseek_v3.py | 76 +- test/registered/mla/test_mla_flashinfer.py | 40 +- .../mla/test_mla_int8_deepseek_v3.py 
| 88 +- .../model_loading/test_runai_model_loader.py | 50 + .../models/test_compressed_tensors_models.py | 20 +- .../models/test_gpt_oss_models_pcg.py | 72 - .../models/test_kimi_linear_models.py | 20 +- .../models/test_kimi_linear_models_pcg.py | 69 - .../models/test_ministral4_models.py | 32 + .../models/test_nvidia_nemotron_3_nano.py | 2 +- .../models/test_nvidia_nemotron_nano_v2.py | 2 +- .../models/test_qwen3_next_models_pcg.py | 29 - test/registered/models/test_qwen_models.py | 34 +- .../models/test_transformers_backend_eval.py | 43 + .../models/test_transformers_models.py | 23 +- test/registered/moe/test_cutedsl_moe.py | 2 +- test/registered/moe/test_glm4_moe_models.py | 18 +- test/registered/moe/test_moe_ep.py | 40 +- test/registered/moe/test_triton_fused_moe.py | 2 +- .../test_metrics.py | 72 +- .../test_priority_metrics.py | 0 test/registered/observability/test_tracing.py | 794 ++++++++ .../test_tracing_disaggregation.py | 237 +++ .../features/test_openai_server_ebnf.py | 2 +- .../function_call/test_tool_choice.py | 10 +- .../test_openai_server_ignore_eos.py | 2 +- .../test_request_length_validation.py | 17 + ...test_piecewise_cuda_graph_support_1_gpu.py | 19 +- .../quant/test_deepseek_v32_fp4_4gpu.py | 40 +- .../quant/test_deepseek_v32_fp4_mtp_4gpu.py | 44 +- .../quant/test_deepseek_v3_fp4_4gpu.py | 59 +- .../quant/test_fp8_blockwise_gemm.py | 48 +- test/registered/quant/test_fp8_gemm_sm120.py | 86 + test/registered/quant/test_fp8_kernel.py | 2 +- test/registered/quant/test_fp8kv_triton.py | 20 +- test/registered/quant/test_int4fp8_moe.py | 17 +- test/registered/quant/test_int8_kernel.py | 2 +- test/registered/quant/test_modelopt_fp8.py | 18 +- test/registered/quant/test_nvfp4_gemm.py | 20 +- .../registered/quant/test_nvfp4_gemm_sm120.py | 71 + .../quant/test_nvfp4_marlin_fallback.py | 788 ++++++++ .../quant/test_quant_config_parsing.py | 2 +- test/registered/quant/test_quantization.py | 2 +- test/registered/quant/test_torchao.py | 2 +- 
.../registered/quant/test_w4a8_deepseek_v3.py | 78 +- .../quant/test_w8a8_quantization.py | 18 +- ...est_pause_generation_tensor_consistency.py | 212 +++ .../rl/test_return_routed_experts.py | 41 +- .../test_fused_temperature_softmax.py | 268 +++ .../sampling/test_pytorch_sampling_backend.py | 2 +- test/registered/scheduler/test_abort.py | 4 +- .../scheduler/test_chunked_prefill.py | 2 +- .../scheduler/test_priority_scheduling.py | 3 + .../scheduler/test_retract_decode.py | 2 +- .../sessions/test_session_control.py | 4 +- .../eagle/test_deepseek_v3_fp4_mtp_small.py | 23 +- .../spec/eagle/test_eagle3_basic.py | 2 +- .../spec/eagle/test_eagle_dp_attention.py | 26 +- .../spec/eagle/test_eagle_infer_a.py | 266 +-- .../spec/eagle/test_eagle_infer_b.py | 100 +- .../spec/eagle/test_eagle_infer_beta.py | 20 +- .../test_eagle_infer_beta_dp_attention.py | 28 +- ...est_eagle_infer_beta_dp_attention_large.py | 26 +- .../test_standalone_speculative_decoding.py | 43 +- .../spec/utils/test_build_eagle_tree.py | 2 +- .../spec/utils/test_ngram_corpus.py | 65 +- test/registered/unit/README.md | 28 +- .../test_function_call_parser.py | 157 +- .../function_call/test_glm47_moe_detector.py | 2 +- .../test_json_schema_constraint.py | 65 +- .../function_call/test_parallel_tool_calls.py | 2 +- .../function_call/test_unknown_tool_name.py | 2 +- .../unit/managers/test_prefill_adder.py | 2 +- .../managers/test_scheduler_flush_cache.py | 8 +- .../test_scheduler_pause_generation.py | 134 ++ .../unit/mem_cache/test_evict_policy.py | 2 +- .../unit/mem_cache/test_mamba_unittest.py | 2 + .../unit/mem_cache/test_nsa_pool_host_unit.py | 2 +- .../unit/mem_cache/test_radix_cache_unit.py | 2 +- .../unit/model_loader/test_modelopt_loader.py | 6 +- test/registered/unit/models/test_llava.py | 91 + .../test_request_metrics_exporter.py | 3 +- .../unit/sampling/test_sampling_batch_info.py | 24 +- .../unit/server_args/test_server_args.py | 2 +- test/registered/unit/test_runai_utils.py | 57 + 
.../unit/utils/test_json_response.py | 2 +- .../unit/utils/test_subprocess_watchdog.py | 137 ++ test/registered/utils/test_log_utils.py | 2 +- test/registered/utils/test_network_address.py | 2 +- test/registered/utils/test_numa_utils.py | 311 ++++ test/registered/utils/test_request_logger.py | 3 +- test/registered/utils/test_socket_utils.py | 2 +- test/registered/vlm/test_patch_embed_perf.py | 2 +- test/registered/vlm/test_vlm_tp4.py | 82 + test/run_suite.py | 7 + test/srt/cpu/test_rope.py | 23 + test/srt/run_suite.py | 40 +- 694 files changed, 34665 insertions(+), 11916 deletions(-) create mode 100644 .github/audit_permission.py create mode 100644 .github/workflows/full-test-npu.yml create mode 100644 .github/workflows/release-docker-runtime.yml rename .github/workflows/{rerun-ut.yml => rerun-test.yml} (88%) create mode 100644 .github/workflows/trivy-scan-dev.yml create mode 100644 benchmark/kernels/bench_fused_temperature_softmax.py create mode 100644 docs/advanced_features/object_storage.md create mode 100644 docs/diffusion/performance/ring_sp_performance.md rename docs/platforms/{ => ascend}/ascend_contribution_guide.md (89%) rename docs/platforms/{ => ascend}/ascend_npu.md (99%) rename docs/platforms/{ => ascend}/ascend_npu_best_practice.md (87%) rename docs/platforms/{ => ascend}/ascend_npu_deepseek_example.md (99%) rename docs/platforms/{ => ascend}/ascend_npu_environment_variables.md (100%) rename docs/platforms/{ => ascend}/ascend_npu_glm5_examples.md (98%) create mode 100644 docs/platforms/ascend/ascend_npu_quantization.md rename docs/platforms/{ => ascend}/ascend_npu_qwen3_5_examples.md (98%) rename docs/platforms/{ => ascend}/ascend_npu_qwen3_examples.md (56%) rename docs/platforms/{ => ascend}/ascend_npu_support.rst (86%) rename docs/platforms/{ => ascend}/ascend_npu_support_features.md (99%) rename docs/platforms/{ => ascend}/ascend_npu_support_models.md (100%) rename docs/platforms/{ => ascend}/mindspore_backend.md (100%) delete mode 100644 
docs/platforms/ascend_npu_quantization.md create mode 100644 docs/platforms/ascend_npu_ring_sp_performance.md create mode 100644 python/sglang/benchmark/datasets/longbench_v2.py delete mode 100644 python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py delete mode 100644 python/sglang/jit_kernel/benchmark/bench_rmsnorm.py create mode 100644 python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_common.cuh create mode 100644 python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm100.cuh create mode 100644 python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm120.cuh rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/ngram.cpp (67%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/ngram.h (99%) create mode 100644 python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/param.h (78%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/queue.h (100%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/result.cpp (100%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/result.h (100%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/trie.cpp (86%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/trie.h (92%) create mode 100644 python/sglang/jit_kernel/ngram_corpus.py create mode 100644 python/sglang/jit_kernel/tests/test_hicache.py delete mode 100644 python/sglang/jit_kernel/tests/test_norm_jit.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py create mode 100644 python/sglang/multimodal_gen/runtime/utils/model_overlay.py create mode 100644 python/sglang/multimodal_gen/test/server/accuracy_config.py create mode 100644 python/sglang/multimodal_gen/test/server/accuracy_hooks.py create mode 100644 
python/sglang/multimodal_gen/test/server/accuracy_utils.py create mode 100644 python/sglang/multimodal_gen/test/server/component_accuracy.py create mode 100644 python/sglang/multimodal_gen/test/server/test_accuracy_1_gpu_a.py create mode 100644 python/sglang/multimodal_gen/test/server/test_accuracy_1_gpu_b.py create mode 100644 python/sglang/multimodal_gen/test/server/test_accuracy_2_gpu_a.py create mode 100644 python/sglang/multimodal_gen/test/server/test_accuracy_2_gpu_b.py create mode 100644 python/sglang/multimodal_gen/test/server/test_server_c.py create mode 100644 python/sglang/multimodal_gen/test/unit/test_input_validation.py create mode 100644 python/sglang/multimodal_gen/test/unit/test_resolve_prompts.py create mode 100644 python/sglang/srt/disaggregation/common/staging_buffer.py create mode 100644 python/sglang/srt/disaggregation/common/staging_handler.py create mode 100644 python/sglang/srt/entrypoints/engine_info_bootstrap_server.py create mode 100644 python/sglang/srt/layers/attention/fla/chunk_fwd.py create mode 100644 python/sglang/srt/layers/attention/fla/chunk_intra.py create mode 100644 python/sglang/srt/layers/attention/fla/chunk_intra_token_parallel.py create mode 100644 python/sglang/srt/layers/fused_sampling.py create mode 100644 python/sglang/srt/layers/quantization/marlin_utils_fp4.py delete mode 100644 python/sglang/srt/managers/async_mm_data_processor.py create mode 100644 python/sglang/srt/multimodal/processors/transformers_auto.py create mode 100644 python/sglang/srt/sampling/penaltylib/repetition_penalty.py delete mode 100644 python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp create mode 100644 python/sglang/srt/utils/runai_utils.py create mode 100644 python/sglang/test/test_mm_utils.py create mode 100755 scripts/ci/check_workflow_job_names.py create mode 100755 scripts/ci/cuda/cache_nvidia_wheels.sh delete mode 100644 sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py delete mode 100644 sgl-kernel/csrc/attention/cascade.cu 
delete mode 100644 sgl-kernel/csrc/elementwise/cast.cu delete mode 100644 sgl-kernel/python/sgl_kernel/_fa4_interface.py delete mode 100644 sgl-kernel/tests/test_hadamard.py delete mode 100644 sgl-kernel/tests/test_merge_state.py create mode 100755 test/manual/kv_transfer/test_mooncake_transfer_engine_init.py delete mode 100644 test/manual/lora/test_lora_qwen3_vl.py delete mode 100644 test/manual/test_async_mm_data_processor.py create mode 100755 test/manual/test_cross_node_scheduler_info_sync.py delete mode 100644 test/manual/test_tracing.py create mode 100644 test/manual/test_whisper_cuda_graph.py create mode 100644 test/registered/amd/accuracy/mi35x/test_glm47_fp8_eval_mi35x.py create mode 100644 test/registered/amd/perf/mi30x/test_minimax_m25_perf_amd.py create mode 100644 test/registered/amd/perf/mi35x/test_minimax_m25_perf_mi35x.py rename test/{srt/ascend/test_ascend_hicache_mha.py => registered/ascend/basic_function/HiCache/test_npu_hicache_mha.py} (91%) rename test/{srt/ascend/test_ascend_hicache_mla.py => registered/ascend/basic_function/HiCache/test_npu_hicache_mla.py} (91%) rename test/{srt/ascend/test_ascend_sampling_backend.py => registered/ascend/basic_function/backends/test_npu_sampling_backend.py} (92%) rename test/{srt/ascend/test_llada2_mini_ascend.py => registered/ascend/basic_function/dllm/test_npu_llada2_mini.py} (92%) rename test/{srt/ascend/test_ascend_compile_graph_tp1_bf16.py => registered/ascend/basic_function/optimization_debug/test_npu_compile_graph_tp1_bf16.py} (92%) rename test/{srt/ascend/test_ascend_graph_tp1_bf16.py => registered/ascend/basic_function/optimization_debug/test_npu_graph_tp1_bf16.py} (91%) rename test/{srt/ascend/test_ascend_graph_tp2_bf16.py => registered/ascend/basic_function/optimization_debug/test_npu_graph_tp2_bf16.py} (91%) rename test/{srt/ascend/test_ascend_piecewise_graph_prefill.py => registered/ascend/basic_function/optimization_debug/test_npu_piecewise_graph_prefill.py} (92%) rename 
test/{srt/ascend/test_ascend_deepep.py => registered/ascend/basic_function/parallel_strategy/expert_parallelism/test_npu_deepep.py} (93%) rename test/{srt/ascend/test_ascend_autoround_dense.py => registered/ascend/basic_function/quant/test_npu_autoround_dense.py} (86%) rename test/{srt/ascend/test_ascend_autoround_moe.py => registered/ascend/basic_function/quant/test_npu_autoround_moe.py} (85%) rename test/{srt/ascend/test_ascend_gptq_moe.py => registered/ascend/basic_function/quant/test_npu_gptq_moe.py} (86%) rename test/{srt/ascend/test_ascend_w4a4_quantization.py => registered/ascend/basic_function/quant/test_npu_w4a4_quantization.py} (93%) rename test/{srt/ascend/test_ascend_w8a8_quantization.py => registered/ascend/basic_function/quant/test_npu_w8a8_quantization.py} (93%) rename test/{srt/ascend/test_ascend_mla_fia_w8a8int8.py => registered/ascend/basic_function/runtime_opts/test_npu_mla_fia_w8a8int8.py} (92%) rename test/{srt/ascend/test_ascend_mla_w8a8int8.py => registered/ascend/basic_function/runtime_opts/test_npu_mla_w8a8int8.py} (91%) rename test/{srt/ascend/test_ascend_tp1_bf16.py => registered/ascend/basic_function/runtime_opts/test_npu_tp1_bf16.py} (91%) rename test/{srt/ascend/test_ascend_tp2_bf16.py => registered/ascend/basic_function/runtime_opts/test_npu_tp2_bf16.py} (91%) rename test/{srt/ascend/test_ascend_tp2_fia_bf16.py => registered/ascend/basic_function/runtime_opts/test_npu_tp2_fia_bf16.py} (92%) rename test/{srt/ascend/test_ascend_tp4_bf16.py => registered/ascend/basic_function/runtime_opts/test_npu_tp4_bf16.py} (91%) create mode 100644 test/registered/ascend/vlm_models/test_ascend_glm_4_5v.py delete mode 100644 test/registered/eval/test_moe_eval_accuracy_large.py create mode 100644 test/registered/gb300/test_deepseek_v32.py create mode 100644 test/registered/gb300/test_deepseek_v32_nvfp4.py create mode 100644 test/registered/gb300/test_glm5_fp8.py create mode 100644 test/registered/gb300/test_glm5_nvfp4.py create mode 100644 
test/registered/gb300/test_kimi_k25.py create mode 100644 test/registered/gb300/test_kimi_k25_nvfp4.py create mode 100644 test/registered/gb300/test_qwen35_fp8.py create mode 100644 test/registered/gb300/test_qwen35_nvfp4.py create mode 100644 test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py create mode 100644 test/registered/lora/test_lora_qwen3_30b_a3b_instruct_2507_logprob_diff.py create mode 100644 test/registered/lora/test_lora_qwen3_8b_logprob_diff.py create mode 100644 test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py create mode 100644 test/registered/model_loading/test_runai_model_loader.py delete mode 100644 test/registered/models/test_gpt_oss_models_pcg.py delete mode 100644 test/registered/models/test_kimi_linear_models_pcg.py create mode 100644 test/registered/models/test_ministral4_models.py delete mode 100644 test/registered/models/test_qwen3_next_models_pcg.py create mode 100644 test/registered/models/test_transformers_backend_eval.py rename test/registered/{metrics => observability}/test_metrics.py (76%) rename test/registered/{metrics => observability}/test_priority_metrics.py (100%) create mode 100644 test/registered/observability/test_tracing.py create mode 100644 test/registered/observability/test_tracing_disaggregation.py create mode 100644 test/registered/quant/test_fp8_gemm_sm120.py create mode 100644 test/registered/quant/test_nvfp4_gemm_sm120.py create mode 100644 test/registered/quant/test_nvfp4_marlin_fallback.py create mode 100644 test/registered/rl/test_pause_generation_tensor_consistency.py create mode 100644 test/registered/sampling/test_fused_temperature_softmax.py create mode 100644 test/registered/unit/managers/test_scheduler_pause_generation.py create mode 100644 test/registered/unit/models/test_llava.py create mode 100644 test/registered/unit/test_runai_utils.py create mode 100644 test/registered/unit/utils/test_subprocess_watchdog.py create mode 100644 test/registered/utils/test_numa_utils.py 
create mode 100644 test/registered/vlm/test_vlm_tp4.py diff --git a/.claude/skills/ci-workflow-guide/SKILL.md b/.claude/skills/ci-workflow-guide/SKILL.md index 6e2697742e21..430f5a3069e3 100644 --- a/.claude/skills/ci-workflow-guide/SKILL.md +++ b/.claude/skills/ci-workflow-guide/SKILL.md @@ -381,6 +381,6 @@ group: pr-test-{event_name}-{branch}-{pr_sha}-{stage} | `/rerun-failed-ci` | Reruns failed jobs in the latest workflow run | | `/tag-and-rerun-ci` | Adds label + reruns | | `/rerun-stage ` | Dispatches `pr-test.yml` with `target_stage=` | -| `/rerun-ut ` | Reruns a specific test file via `rerun-ut.yml` | +| `/rerun-test ` | Reruns a specific test file via `rerun-test.yml` | Handled by `scripts/ci/utils/slash_command_handler.py` → `.github/workflows/slash-command-handler.yml`. diff --git a/.claude/skills/write-sglang-test/SKILL.md b/.claude/skills/write-sglang-test/SKILL.md index d24524468320..93bf5b78690c 100644 --- a/.claude/skills/write-sglang-test/SKILL.md +++ b/.claude/skills/write-sglang-test/SKILL.md @@ -92,9 +92,22 @@ Defined in `python/sglang/test/test_utils.py`: | `stage-c-test-large-8-gpu-amd` | `linux-mi325-8gpu-sglang` | 8-GPU MI325 scaling and integration | | `stage-c-test-large-8-gpu-amd-mi35x` | `linux-mi35x-gpu-8` | 8-GPU MI35x scaling (2 partitions) | + +### Per-commit (Ascend NPU) + +| Suite | Runner (label) | Description | +| --- | --- | --- | +| `per-commit-1-npu-a2` | `linux-aarch64-a2-1` | 1-NPU LLM CI machine | +| `per-commit-2-npu-a2` | `linux-aarch64-a2-2` | 2-NPU LLM CI machine | +| `per-commit-4-npu-a3` | `linux-aarch64-a3-4` | 4-NPU LLM CI machine | +| `per-commit-16-npu-a3` | `linux-aarch64-a3-16` | 16-NPU LLM CI machine | +| `multimodal-gen-test-1-npu-a3` | `linux-aarch64-a3-2` | 1-NPU multimodal CI machine | +| `multimodal-gen-test-2-npu-a3` | `linux-aarch64-a3-16` | 2-NPU multimodal CI machine | +| `multimodal-gen-test-8-npu-a3` | `linux-aarch64-a3-16` | 8-NPU multimodal CI machine | + #### Nightly -Nightly suites are listed in 
`NIGHTLY_SUITES` in [`test/run_suite.py`](../../../test/run_suite.py). They run via `nightly-test-nvidia.yml` and `nightly-test-amd.yml`, not `pr-test.yml`. Examples: +Nightly suites are listed in `NIGHTLY_SUITES` in [`test/run_suite.py`](../../../test/run_suite.py). They run via `nightly-test-nvidia.yml`, `nightly-test-amd.yml`, and `nightly-test-npu.yml`, not `pr-test.yml`. Examples: @@ -103,6 +116,11 @@ Nightly suites are listed in `NIGHTLY_SUITES` in [`test/run_suite.py`](../../../ - `nightly-eval-vlm-2-gpu` (CUDA) - `nightly-amd` (AMD) - `nightly-amd-8-gpu-mi35x` (AMD) +- `nightly-1-npu-a3` (NPU) +- `nightly-2-npu-a3` (NPU) +- `nightly-4-npu-a3` (NPU) +- `nightly-8-npu-a3` (NPU) +- `nightly-16-npu-a3` (NPU) > **Note**: Multimodal diffusion uses `python/sglang/multimodal_gen/test/run_suite.py`, not `test/run_suite.py`. @@ -154,7 +172,7 @@ if __name__ == "__main__": unittest.main() ``` -Use `unittest.mock.patch` / `MagicMock` to mock dependencies and isolate the logic under test. If the module fails to import on CPU CI (e.g., imports `torch` or CUDA ops at module level), use `sys.modules` stubs to make the import succeed. See existing tests in `test/registered/unit/` for examples. +Use `unittest.mock.patch` / `MagicMock` to mock dependencies and isolate the logic under test. If the module transitively imports GPU-only packages (e.g. `sgl_kernel`), they can be stubbed so the test runs on CPU CI. See `test/registered/unit/README.md` for details and examples. **Quality bar** — test real logic (validation boundaries, state transitions, error paths, branching, etc.). Skip tests that just verify Python itself works (e.g., "does calling an abstract method raise `NotImplementedError`?", "does a dataclass store the field I assigned?"). Consolidate repetitive patterns into parameterized tests. No production code changes in test PRs. 
diff --git a/.github/CI_PERMISSIONS.json b/.github/CI_PERMISSIONS.json index 27dbcb618efc..22412fab6564 100644 --- a/.github/CI_PERMISSIONS.json +++ b/.github/CI_PERMISSIONS.json @@ -2,1212 +2,1282 @@ "1pikachu": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Alcanderian": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "AniZpZ": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "BBuf": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "BHZ-BER": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ByronHsu": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "CaoE": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "CatherineSue": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "Chen-0210": { 
"can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ClawSeven": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ConnorLi96": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "DarkSharpness": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "Edwardf0t1": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "FlamingoPg": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "FrankLeeeee": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Fridge003": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "HaiShaw": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": 
true + "reason": "custom override" }, "HanHan009527": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "HandH1998": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Hanrui-Wang": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" + }, + "Hexq0210": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" }, "HydraQYH": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "JeremieMelo": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Johnsonms": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "JustinTong0323": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "Kangyan-Zhou": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - 
"can_rerun_stage": true + "reason": "top contributor" }, "LorrinWWW": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" + }, + "Makcum888e": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" }, "MingxuZh": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Oasis-Git": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" }, "Prozac614": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" }, "Qiaolin-Yu": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "Qihang-Zhang": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Ratish1": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "RubiaCx": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 
60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ShangmingCai": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "Shunkangz": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "SimonCqk": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "TianQiLin666666": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Ubospica": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Valentine233": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "Xia-Weiwen": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "XiaotongJiang": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "XucSh": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, 
"cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "YAMY1234": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "Ying1123": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ZailiWang": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ZhengWG": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ZhengdQin": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "acelyc111": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "adarshxs": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "airMeng": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "alisonshao": { "can_tag_run_ci_label": true, 
"can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "alphabetc1": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "amysaq2023": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "attack204": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ayrnb": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "azhurkevich": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "b8zhong": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" + }, + "bingxche": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "blzheng": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "byjiang1996": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - 
"cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "cctry": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ch-wan": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" + }, + "chenxu214": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "chunyuan-w": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "cicirori": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "cyb70289": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "dongjiyingdjy": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "dougyster": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "elfiegg": { "can_tag_run_ci_label": true, 
"can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "fy1214": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "fzyzcjy": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "gaopengff": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, - "gongwei-130": { + "glenliu21": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" + }, + "gongwei-130": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "gongy": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "guapisolo": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "guoyuhong": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "hanming-lu": { "can_tag_run_ci_label": true, 
"can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "harrisonlimh": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "harvenstar": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "hebiao064": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "hlu1": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "hnyls2002": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "huaiyuzh": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "huangtingwei9988": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "hubertlu-tw": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - 
"reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "hyhieu": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "hzh0425": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "iforgetmyname": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "ishandhanani": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "ispobock": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "jason-fxz": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "jasperjiaguo": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "jhinpan": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "jianan-gu": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, 
"cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "jinleic": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "jinmingyi1998": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "kaixih": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "kevin85421": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "key4ng": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "kkHuang-amd": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "kpham-sgl": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "kssteven418": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": 
"custom override" }, "kushanam": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "lanking520": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "lifuhuang": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" + }, + "liupeng374": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "liusy58": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "liz-badada": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "merrymercy": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" + }, + "michaelzhang-ai": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "mickqian": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "mingfeima": { "can_tag_run_ci_label": true, 
"can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "minleminzui": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "mmangkad": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "narutolhy": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "netanel-haber": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "nvcastet": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ocss884": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "pansicheng": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "pavanimajety": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom 
override" }, "pdasgup": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ping1jing2": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" + }, + "ppraneth": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" }, "pranavm-nvidia": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "pyc96": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "qingquansong": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "qywu": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "rainj-me": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "ravi03071991": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "rkooo567": { 
"can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, - "saienduri": { + "roikoren755": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" + }, + "saienduri": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "samuellees": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "scottjlee": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "sglang-bot": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "sglang-npu-bot": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "shaharmor98": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "shanyu-sys": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": 
"custom override" }, "shuaills": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "sleepcoo": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "slin1237": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "stmatengss": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "strgrb": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "sufeng-buaa": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "sundar24295s": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "sunjiweiswift": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "sunxxuns": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom 
override", - "can_rerun_stage": true + "reason": "custom override" }, "thecodingwizard": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "timmy-feng": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "trevor-m": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "vincentzed": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "wenscarl": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "whybeyoung": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "wisclmy0611": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "xiezhq-hermann": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "xutizhou": { "can_tag_run_ci_label": true, 
"can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "xyjixyjixyji": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "yanbing-j": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "yangsijia-serena": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" + }, + "yctseng0211": { + "can_tag_run_ci_label": true, + "can_rerun_failed_ci": true, + "can_rerun_stage": true, + "cooldown_interval_minutes": 0, + "reason": "top contributor" }, "yeahdongcn": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "top contributor" }, "yhyang201": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "yilian49": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "yinghai": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "yingluosanqian": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": 
true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "yizhang2077": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "ykcombat": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "ynwang007": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "yuan-luo": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "yundai424": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "yushengsu-thu": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "yyihuang": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "yzh119": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": 
"custom override" }, "zhaochenyang20": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "zhijian-liu": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "zhuzilin": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, - "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "can_rerun_stage": true, + "cooldown_interval_minutes": 60, + "reason": "custom override" }, "zhyncs": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "zminglei": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "top contributor", - "can_rerun_stage": true + "reason": "top contributor" }, "zyksir": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 60, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" }, "zyzshishui": { "can_tag_run_ci_label": true, "can_rerun_failed_ci": true, + "can_rerun_stage": true, "cooldown_interval_minutes": 0, - "reason": "custom override", - "can_rerun_stage": true + "reason": "custom override" } } diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index bd279bf31eb3..b1cd1617d53a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,8 +9,8 @@ /python/sglang/multimodal_gen/runtime/layers @mickqian @yhyang201 @BBuf @yingluosanqian @ping1jing2 /python/sglang/multimodal_gen/runtime/models/dits @mickqian @yhyang201 @BBuf @yingluosanqian @ping1jing2 
/python/sglang/srt/batch_invariant_ops @Fridge003 @hebiao064 +/python/sglang/srt/compilation @hebiao064 @Oasis-Git /python/sglang/srt/constrained @hnyls2002 @DarkSharpness -/python/sglang/srt/compilation @hebiao064 /python/sglang/srt/disaggregation @ByronHsu @hnyls2002 @ShangmingCai /python/sglang/srt/disaggregation/ascend @ping1jing2 @iforgetmyname /python/sglang/srt/distributed @yizhang2077 @merrymercy @ch-wan @@ -40,6 +40,7 @@ /python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @hebiao064 /python/sglang/srt/models/deepseek_common @Fridge003 @ispobock @fzyzcjy @ch-wan /python/sglang/srt/models/deepseek_v2.py @fzyzcjy @zhyncs @ispobock @ch-wan @merrymercy @Fridge003 +/python/sglang/srt/models/transformers.py @adarshxs /python/sglang/srt/multimodal @mickqian @JustinTong0323 @yhyang201 @yuan-luo /python/sglang/srt/observability @merrymercy @fzyzcjy @sufeng-buaa /python/sglang/srt/ray @Qiaolin-Yu @xyuzh @@ -49,6 +50,7 @@ /sgl-model-gateway/benches @slin1237 /sgl-model-gateway/bindings/python @CatherineSue @key4ng @slin1237 /sgl-model-gateway/e2e_test @CatherineSue @key4ng +/sgl-model-gateway/examples/wasm @slin1237 /sgl-model-gateway/src/config @slin1237 /sgl-model-gateway/src/core @slin1237 /sgl-model-gateway/src/data_connector @key4ng @@ -62,6 +64,5 @@ /sgl-model-gateway/src/tokenizer @slin1237 @CatherineSue /sgl-model-gateway/src/tool_parser @slin1237 @CatherineSue /sgl-model-gateway/src/wasm @slin1237 -/sgl-model-gateway/examples/wasm @slin1237 /test/srt/ascend @ping1jing2 @iforgetmyname /test/srt/test_modelopt* @Edwardf0t1 diff --git a/.github/MAINTAINER.md b/.github/MAINTAINER.md index 6b39d364b9e8..58b71196c948 100644 --- a/.github/MAINTAINER.md +++ b/.github/MAINTAINER.md @@ -37,34 +37,118 @@ __Note__: The permissions to trigger CI tests are defined separately according t - **Ideal case:** For each modified file, one Codeowner has approved the PR. The PR has also passed the required CI tests. Then, anyone with write permission can merge the PR. 
- **Exception:** In cases where it is difficult to meet all requirements (due to flaky CI or slow responses), a Merge Oncall can bypass branch protection to merge the PR. -If you meet any issues during the merge, you can discuss in [slack channels](https://slack.sglang.io/): #dev, #pull-request, and #ci-cd-build-release. +If you meet any issues during the merge, you can discuss in [slack channels](https://slack.sglang.io/): #pull-request, #ci-cd-build-release, #dev. ## The List of Merge Oncalls and Reviewers +This section lists the oncalls for each module or feature. The format is @github-username (Slack username). -TODO: fill in the list. +### Scheduler +[@merrymercy](https://github.com/merrymercy) (Lianmin Zheng), [@hnyls2002](https://github.com/hnyls2002) (Liangsheng Yin), [@cctry](https://github.com/cctry) (Shiyang Chen) + +related files +- python/sglang/srt/managers +- python/sglang/srt/model_executor + +### Diffusion +[@mickqian](https://github.com/mickqian) (Mick), [@BBuf](https://github.com/BBuf) (BBuf) + +related files +- python/sglang/multimodal_gen + +### PD disaggregation +[@ByronHsu](https://github.com/ByronHsu) (Byron Hsu), [@cctry](https://github.com/cctry) (Shiyang Chen), [@ShangmingCai](https://github.com/ShangmingCai) (Shangming Cai) + +related files +- python/sglang/srt/disaggregation + +### KV Cache +[@ispobock](https://github.com/ispobock) (Ke Bao), [@xiezhq-hermann](https://github.com/xiezhq-hermann) (Zhiqiang Xie) + +related files +- python/sglang/srt/mem_cache + +### Parallelism +[@ch-wan](https://github.com/ch-wan) (Cheng Wan), [@fzyzcjy](https://github.com/fzyzcjy) (Tom) + +related files +- python/sglang/srt/eplb +- python/sglang/srt/distributed +- python/sglang/srt/layers/dp_attention.py + +### Kernel +[@BBuf](https://github.com/BBuf) (BBuf) + +related files +- python/sglang/jit_kernel +- sgl-kernel + +### Speculative decoding +[@hnyls2002](https://github.com/hnyls2002) (Liangsheng Yin), [@Qiaolin-Yu](https://github.com/Qiaolin-Yu) 
(Qiaolin Yu) + +related files +- python/sglang/srt/speculative + +### NV and model-specific optimizations +[@Fridge003](https://github.com/Fridge003) (Baizhou Zhang), [@ishandhanani](https://github.com/ishandhanani) (Ishan Dhanani), [@Qiaolin-Yu](https://github.com/Qiaolin-Yu) (Qiaolin Yu) + +related files +- python/sglang/srt/models +- python/sglang/srt/layers/attention + +### AMD optimizations +[@HaiShaw](https://github.com/HaiShaw) (Henry HAI) + +### NPU optimizations +[@iforgetmyname](https://github.com/iforgetmyname) (Even Zhou) + +related files +- python/sglang/srt/hardware_backend/npu + +### CI, Release, Package +[@Kangyan-Zhou](https://github.com/Kangyan-Zhou) (Kangyan Zhou), [@Fridge003](https://github.com/Fridge003) (Baizhou Zhang) + +related files +- .github/workflows + +### Router, API +[@slin1237](https://github.com/slin1237) (Simo Lin) + +related files +- sgl-model-gateway +- python/sglang/srt/grpc +- python/sglang/srt/entrypoints + +### Other Notes Now we have many Merge Oncalls mainly because the CI is flaky and the CODEOWNERS is too coarse-grained. In the future, we hope the CI can be improved and we only need bypass rarely. After that, most Merge Oncalls can be converted back to Write and CODEOWNERS. -This list is based on the current situation. If you or someone you know would like to take on more responsibility and are qualified, please ping @Lianmin Zheng and @Ying Sheng in the Slack channel. They will start a nomination and internal review process. +This list is based on the current situation. If you or someone you know would like to take on more responsibility and are qualified, please ping [Lianmin Zheng](https://github.com/merrymercy) and [Ying Sheng](https://github.com/Ying1123) in the Slack channel. They will start a nomination and internal review process. ## The List of CI Oncalls -The format is @github-username (Slack username). +This section lists the oncalls for each hardware platform. The format is @github-username (Slack username). 
### NVIDIA GPUs -@merrymercy (Lianmin Zheng), @Kangyan-Zhou (Kangyan Zhou), @ch-wan (Cheng Wan), @HanHan009527 (hanhan), @ishandhanani (Ishan Dhanani), @key4ng (Keyang Ru), @slin1237 (Simo Lin), @ShangmingCai (Shangming Cai) +[@Kangyan-Zhou](https://github.com/Kangyan-Zhou) (Kangyan Zhou), [@ch-wan](https://github.com/ch-wan) (Cheng Wan), [@HanHan009527](https://github.com/HanHan009527) (hanhan), [@ishandhanani](https://github.com/ishandhanani) (Ishan Dhanani), [@ShangmingCai](https://github.com/ShangmingCai) (Shangming Cai), [@alisonshao](https://github.com/alisonshao) (Alison Shao). ### AMD GPUs -@saienduri (Sai Enduri), @HaiShaw (Henry HAI) +[@saienduri](https://github.com/saienduri) (Sai Enduri), [@HaiShaw](https://github.com/HaiShaw) (Henry HAI) ### Intel CPU and XPU -@mingfeima (Mingfei Ma), @DiweiSun (Diwei Sun) +[@mingfeima](https://github.com/mingfeima) (Mingfei Ma), [@DiweiSun](https://github.com/DiweiSun) (Diwei Sun) ### Ascend NPUs -@iforgetmyname (Even Zhou) +[@iforgetmyname](https://github.com/iforgetmyname) (Even Zhou) + +This list is based on the current situation. If you or someone you know would like to donate machines for CI, they can serve as the CI oncalls for their machines. Please ping [Lianmin Zheng](https://github.com/merrymercy) and [Ying Sheng](https://github.com/Ying1123) in the Slack channel. They will start a nomination and internal review process. + +## CI Maintenance Mode +When the CI is unhealthy (e.g., the scheduled pr-test on `main` is broken for consecutive runs), the project enters **CI Maintenance Mode** by opening [issue #21065](https://github.com/sgl-project/sglang/issues/21065). While active: +- All PR CI runs are paused. Resources are allocated to PRs that fix the CI. +- **Merging non-CI-fix PRs is prohibited.** Only PRs that fix the CI may be merged. In severe cases, merge permissions may be revoked. -This list is based on the current situation. 
If you or someone you know would like to donate machines for CI, they can serve as the CI oncalls for their machines. Please ping @Lianmin Zheng and @Ying Sheng in the Slack channel. They will start a nomination and internal review process. +Maintenance mode ends when `pr-test.yml` is all green on `main` and the issue is closed. ## Suspending Permissions -If the merge oncall bypasses checks to merge a PR that breaks the `main` branch, or if they repeatedly break the CI due to various reasons, their privileges will be suspended for at least three days, depending on the severity of the incident. +If a Merge Oncall bypasses checks to merge a PR that breaks the `main` branch, merges a non-CI-fix PR during CI Maintenance Mode, or repeatedly breaks the CI due to various reasons, their privileges will be suspended for at least two days, depending on the severity of the incident. diff --git a/.github/actions/check-maintenance/action.yml b/.github/actions/check-maintenance/action.yml index 94a0b20d5606..595283dcdfae 100644 --- a/.github/actions/check-maintenance/action.yml +++ b/.github/actions/check-maintenance/action.yml @@ -1,5 +1,5 @@ name: Check Maintenance Mode -description: Blocks CI when maintenance mode is active (issue #21065 is open), unless the PR has the bypass-maintenance label, or env SGLANG_PR_TEST_BYPASS_MAINTENANCE_ON_MAIN=true (PR Test workflow on main only). +description: Blocks CI when maintenance mode is active (issue #21065 is open), unless the PR has the bypass-maintenance label, or env SGLANG_PR_TEST_BYPASS_MAINTENANCE_ON_MAIN=true (PR Test workflow on main only). Merging non-CI-fix PRs is prohibited during maintenance mode; in severe cases, merge permissions may be revoked. inputs: github-token: @@ -46,10 +46,12 @@ runs: "## āš ļø CI Maintenance Mode is Active" \ "The CI infrastructure is currently under maintenance." \ "All PR CI runs are paused until maintenance is complete." 
\ + "**Merging non-CI-fix PRs is prohibited during maintenance mode.** In severe cases, merge permissions may be revoked." \ "You might also experience unexpected failures during this period." \ "The team is working on the issue and will update the status as soon as possible." \ "" \ "What should you do?" \ + "- **Do NOT merge non-CI-fix PRs** until maintenance mode is lifted" \ "- Check back later (~12 hours)" \ "- Follow CI Maintenance Mode issue: https://github.com/$REPO/issues/$MAINTENANCE_ISSUE for status updates") diff --git a/.github/audit_permission.py b/.github/audit_permission.py new file mode 100644 index 000000000000..35c19f9b56a1 --- /dev/null +++ b/.github/audit_permission.py @@ -0,0 +1,411 @@ +""" +Audit GitHub repository collaborators with elevated access. + +This script will: +1. Fetch all collaborators with write permission to this repo. +2. Show their github username, Nickname and the role (e.g., admin, maintain, + custom org role, write, triage). +3. Show their last activity related to this repo (last commit, last issue, + last pull request). Put the data in YYYY-MM-DD format. Add a column "last activity date" to the CSV, before the above three breakdown columns. +4. Show activity on other repos: repos touched via public events in the last 90 days (Push, PR, Issues, etc.). Sort the repos by the number of activities. +5. Write results to a CSV sorted by the roles (admin, maintain, custom org role, write, triage) and the last activity date (most recent first). + +Usage: + export GH_TOKEN="your_github_token" + python3 audit_permission.py [--output path] [--repo owner/name] + +Requires: requests, and a token with permission to list collaborators (push+ +access to the repo). 
+""" + +from __future__ import annotations + +import argparse +import csv +import os +import sys +import time +from collections import Counter +from datetime import datetime, timedelta, timezone +from typing import Any + +try: + import requests +except ImportError: + requests = None # type: ignore + +DEFAULT_OWNER = "sgl-project" +DEFAULT_NAME = "sglang" + +HEADERS: dict[str, str] = {} + + +def _request( + method: str, + url: str, + *, + params: dict[str, Any] | None = None, + max_retries: int = 3, +) -> requests.Response: + if requests is None: + raise RuntimeError("Install the requests package: pip install requests") + for attempt in range(max_retries): + r = requests.request(method, url, headers=HEADERS, params=params, timeout=60) + if r.status_code == 403 and "rate limit" in (r.text or "").lower(): + reset = r.headers.get("X-RateLimit-Reset") + wait = 60 + if reset: + try: + wait = max(1, int(reset) - int(time.time()) + 2) + except ValueError: + pass + print(f"Rate limited; sleeping {wait}s...", file=sys.stderr) + time.sleep(min(wait, 3600)) + continue + return r + return r + + +def paginate_list(url: str, params: dict[str, Any] | None = None) -> list[Any]: + out: list[Any] = [] + next_url: str | None = url + next_params = params + while next_url: + r = _request("GET", next_url, params=next_params) + next_params = None + if r.status_code != 200: + print( + f"Error {r.status_code} GET {next_url}: {r.text[:500]}", + file=sys.stderr, + ) + break + data = r.json() + if isinstance(data, list): + out.extend(data) + else: + break + next_url = None + link = r.headers.get("Link", "") + for part in link.split(", "): + if 'rel="next"' in part: + start = part.find("<") + 1 + end = part.find(">") + if start > 0 and end > start: + next_url = part[start:end] + break + return out + + +def collaborator_role(collab: dict[str, Any]) -> str: + role_name = collab.get("role_name") + if isinstance(role_name, str) and role_name.strip(): + return role_name.strip() + perms = 
collab.get("permissions") or {} + if perms.get("admin"): + return "admin" + if perms.get("maintain"): + return "maintain" + if perms.get("push"): + return "write" + if perms.get("triage"): + return "triage" + return "read" + + +def has_write_plus(collab: dict[str, Any]) -> bool: + perms = collab.get("permissions") or {} + return bool( + perms.get("admin") + or perms.get("maintain") + or perms.get("push") + or perms.get("triage") + ) + + +def role_sort_tier(collab: dict[str, Any]) -> int: + """Sort order: admin (0), maintain (1), custom org role (2), write (3), triage (4).""" + rn = collab.get("role_name") + if isinstance(rn, str) and rn.strip(): + k = rn.strip().lower() + if k == "admin": + return 0 + if k == "maintain": + return 1 + if k == "write": + return 3 + if k == "triage": + return 4 + if k == "read": + return 5 + return 2 + perms = collab.get("permissions") or {} + if perms.get("admin"): + return 0 + if perms.get("maintain"): + return 1 + if perms.get("push"): + return 3 + if perms.get("triage"): + return 4 + return 5 + + +def fetch_display_name(login: str) -> str: + url = f"https://api.github.com/users/{login}" + r = _request("GET", url) + if r.status_code != 200: + return "" + data = r.json() + if not isinstance(data, dict): + return "" + n = data.get("name") + return n.strip() if isinstance(n, str) else "" + + +def parse_github_ts(s: str) -> datetime | None: + if not s: + return None + s = s.replace("Z", "+00:00") + try: + return datetime.fromisoformat(s) + except ValueError: + return None + + +def iso_timestamp_to_ymd(iso: str | None) -> str: + if not iso: + return "" + p = parse_github_ts(iso) + if not p: + return "" + return p.date().isoformat() + + +def max_date_ymd(*iso_dates: str | None) -> str: + best: datetime | None = None + for d in iso_dates: + p = parse_github_ts(d or "") + if p and (best is None or p > best): + best = p + return best.date().isoformat() if best else "" + + +def parse_ymd(s: str) -> datetime | None: + if not s: + return None 
+ try: + return datetime.strptime(s, "%Y-%m-%d").replace(tzinfo=timezone.utc) + except ValueError: + return None + + +def last_commit_date(owner: str, repo: str, login: str) -> str | None: + url = f"https://api.github.com/repos/{owner}/{repo}/commits" + r = _request("GET", url, params={"author": login, "per_page": 1}) + if r.status_code != 200: + return None + data = r.json() + if not isinstance(data, list) or not data: + return None + commit = data[0].get("commit") or {} + c = commit.get("committer") or commit.get("author") or {} + d = c.get("date") + return d if isinstance(d, str) else None + + +def search_repo_item( + owner: str, repo: str, login: str, kind: str +) -> dict[str, Any] | None: + q = f"repo:{owner}/{repo} is:{kind} author:{login}" + url = "https://api.github.com/search/issues" + r = _request( + "GET", + url, + params={"q": q, "sort": "updated", "order": "desc", "per_page": 1}, + ) + if r.status_code != 200: + return None + payload = r.json() + items = payload.get("items") + if not items: + return None + return items[0] if isinstance(items[0], dict) else None + + +def last_issue_pr_dates( + owner: str, repo: str, login: str +) -> tuple[str | None, str | None]: + issue = search_repo_item(owner, repo, login, "issue") + pr = search_repo_item(owner, repo, login, "pr") + issue_dt = None + pr_dt = None + if issue: + issue_dt = issue.get("updated_at") or issue.get("created_at") + if not isinstance(issue_dt, str): + issue_dt = None + if pr: + pr_dt = pr.get("updated_at") or pr.get("created_at") + if not isinstance(pr_dt, str): + pr_dt = None + return issue_dt, pr_dt + + +def other_repos_activity_column( + login: str, owner: str, repo: str, days: int = 90 +) -> str: + """Repos other than this one touched in the window, sorted by event count (desc).""" + cutoff = datetime.now(timezone.utc) - timedelta(days=days) + full = f"{owner}/{repo}" + counts: Counter[str] = Counter() + url: str | None = f"https://api.github.com/users/{login}/events/public" + params: 
dict[str, Any] = {"per_page": 100} + + while url: + r = _request("GET", url, params=params) + params = {} + if r.status_code != 200: + break + events = r.json() + if not isinstance(events, list): + break + oldest_in_page: datetime | None = None + for ev in events: + if not isinstance(ev, dict): + continue + created = parse_github_ts(ev.get("created_at") or "") + if created: + if oldest_in_page is None or created < oldest_in_page: + oldest_in_page = created + if created and created < cutoff: + continue + rinfo = ev.get("repo") + name = None + if isinstance(rinfo, dict): + name = rinfo.get("name") + if isinstance(name, str) and name and name != full: + counts[name] += 1 + next_url = None + link = r.headers.get("Link", "") + for part in link.split(", "): + if 'rel="next"' in part: + s, e = part.find("<") + 1, part.find(">") + if s > 0 and e > s: + next_url = part[s:e] + break + if oldest_in_page and oldest_in_page < cutoff: + break + url = next_url + if not events: + break + + ordered = sorted(counts.items(), key=lambda x: (-x[1], x[0])) + return ";".join(f"{n}:{c}" for n, c in ordered) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Audit repo collaborator permissions.") + parser.add_argument( + "--repo", + default=f"{DEFAULT_OWNER}/{DEFAULT_NAME}", + help=f"owner/name (default: {DEFAULT_OWNER}/{DEFAULT_NAME})", + ) + parser.add_argument( + "--output", + "-o", + default=os.path.join(os.path.dirname(__file__), "permission_audit.csv"), + help="Output CSV path", + ) + parser.add_argument( + "--events-days", + type=int, + default=90, + help="Window for other-repo activity via public events", + ) + args = parser.parse_args() + + if "/" not in args.repo: + print("Error: --repo must be owner/name", file=sys.stderr) + sys.exit(1) + owner, name = args.repo.split("/", 1) + + gh_token = os.getenv("GH_TOKEN") + if not gh_token: + print("Error: GH_TOKEN environment variable is not set.", file=sys.stderr) + sys.exit(1) + + global HEADERS + HEADERS = { + 
"Authorization": f"Bearer {gh_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + + collab_url = f"https://api.github.com/repos/{owner}/{name}/collaborators" + print(f"Fetching collaborators for {owner}/{name}...", file=sys.stderr) + collaborators = paginate_list( + collab_url, params={"per_page": 100, "affiliation": "all"} + ) + + rows: list[dict[str, Any]] = [] + elevated = [c for c in collaborators if isinstance(c, dict) and has_write_plus(c)] + print( + f"Found {len(elevated)} collaborators with admin/maintain/write/triage.", + file=sys.stderr, + ) + + for i, col in enumerate(elevated, start=1): + login = col.get("login") + if not isinstance(login, str): + continue + print(f" [{i}/{len(elevated)}] {login}", file=sys.stderr) + + role = collaborator_role(col) + nickname = fetch_display_name(login) + cd = last_commit_date(owner, name, login) + issue_dt, pr_dt = last_issue_pr_dates(owner, name, login) + last_act_ymd = max_date_ymd(cd, issue_dt, pr_dt) + others = other_repos_activity_column(login, owner, name, days=args.events_days) + rows.append( + { + "_role_tier": role_sort_tier(col), + "github_username": login, + "nickname": nickname, + "role": role, + "last_activity_date": last_act_ymd, + "last_commit_date": iso_timestamp_to_ymd(cd), + "last_issue_date": iso_timestamp_to_ymd(issue_dt), + "last_pr_date": iso_timestamp_to_ymd(pr_dt), + "other_repos_90d": others, + } + ) + + def sort_key(r: dict[str, Any]) -> tuple[int, float]: + tier = r["_role_tier"] + act = parse_ymd(r.get("last_activity_date") or "") + ts = act.timestamp() if act else 0.0 + return (tier, -ts) + + rows.sort(key=sort_key) + + fieldnames = [ + "github_username", + "nickname", + "role", + "last_activity_date", + "last_commit_date", + "last_issue_date", + "last_pr_date", + "other_repos_90d", + ] + for r in rows: + del r["_role_tier"] + with open(args.output, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + 
w.writeheader() + w.writerows(rows) + + print(f"Wrote {len(rows)} rows to {args.output}", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 45db320d57df..a2338baf30d9 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -12,7 +12,7 @@ -## Benchmarking and Profiling +## Speed Tests and Profiling @@ -24,10 +24,10 @@ - [ ] Provide accuracy and speed benchmark results according to [Test the accuracy](https://docs.sglang.io/developer_guide/contribution_guide.html#test-the-accuracy) and [Benchmark the speed](https://docs.sglang.io/developer_guide/contribution_guide.html#benchmark-the-speed). - [ ] Follow the SGLang code style [guidance](https://docs.sglang.io/developer_guide/contribution_guide.html#code-style-guidance). -## Review Process +## Review and Merge Process -1. Ping Merge Oncalls to start the PR flow. See the [PR Merge Process](https://github.com/sgl-project/sglang/blob/main/.github/MAINTAINER.md#pull-request-merge-process). +1. Ping Merge Oncalls to start the process. See the [PR Merge Process](https://github.com/sgl-project/sglang/blob/main/.github/MAINTAINER.md#pull-request-merge-process). 2. Get approvals from [CODEOWNERS](https://github.com/sgl-project/sglang/blob/main/.github/CODEOWNERS) and other reviewers. 3. Trigger CI tests with [comments](https://docs.sglang.io/developer_guide/contribution_guide.html#how-to-trigger-ci-tests) or contact authorized users to do so. - - `/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci` -4. After green CI and required approvals, ask Merge Oncalls to merge. + - Common commands include `/tag-and-rerun-ci`, `/tag-run-ci-label`, `/rerun-failed-ci` +4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR. 
diff --git a/.github/update_ci_permission.py b/.github/update_ci_permission.py index bbf695149022..106532ede44d 100644 --- a/.github/update_ci_permission.py +++ b/.github/update_ci_permission.py @@ -22,7 +22,7 @@ Permissions are assigned according to the following rules: -1. Add the top 50 contributors from the last 90 days with full permissions, no cooldown, and the reason "top contributor". +1. Add the top 50 contributors from the last 120 days with full permissions, no cooldown, and the reason "top contributor". 2. Load all users from the existing `CI_PERMISSIONS.json` file and update their entries as follows: - If a user is already covered by rule 1, skip that user. - If the old reason of a user is "top contributor" but they are not in the current top contributors list, change their configuration to: @@ -117,7 +117,7 @@ def get_write_access_users(): return writers -def get_top_contributors(days=90, limit=50): +def get_top_contributors(days, limit): """Fetches top contributors based on commit count in the last N days.""" print(f"Fetching commits from the last {days} days...") since_date = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat() @@ -132,7 +132,7 @@ def get_top_contributors(days=90, limit=50): author_counts[commit["author"]["login"]] += 1 top_users = [user for user, _ in author_counts.most_common(limit)] - print(f"Found {len(top_users)} active contributors in the last {days} days.") + print(f"Found {len(top_users)} top contributors in the last {days} days.") return set(top_users) @@ -193,7 +193,7 @@ def main(): print(f"Warning: Could not fetch collaborators (check token scope). 
Error: {e}") write_access_users = set() - top_contributors = get_top_contributors(days=90, limit=50) + top_contributors = get_top_contributors(days=120, limit=50) old_permissions = load_existing_permissions() new_permissions = {} @@ -203,6 +203,7 @@ def main(): new_permissions[user] = { "can_tag_run_ci_label": True, "can_rerun_failed_ci": True, + "can_rerun_stage": True, "cooldown_interval_minutes": 0, "reason": "top contributor", } @@ -220,6 +221,7 @@ def main(): new_permissions[user] = { "can_tag_run_ci_label": True, "can_rerun_failed_ci": True, + "can_rerun_stage": True, "cooldown_interval_minutes": 60, "reason": "custom override", } diff --git a/.github/workflows/amd-ci-job-monitor.yml b/.github/workflows/amd-ci-job-monitor.yml index 87a1954705ce..cbb8798b110a 100644 --- a/.github/workflows/amd-ci-job-monitor.yml +++ b/.github/workflows/amd-ci-job-monitor.yml @@ -20,13 +20,54 @@ on: type: string jobs: + fetch-actions-data: + name: Fetch Actions Snapshot + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: pip install tabulate + + - name: Select workflows for snapshot + id: select-workflows + run: | + if [[ -n "${{ inputs.job_filter }}" ]]; then + echo "workflows=pr-test-amd.yml" >> "$GITHUB_OUTPUT" + else + echo "workflows=pr-test-amd.yml,nightly-test-amd.yml,pr-test-amd-rocm720.yml,nightly-test-amd-rocm720.yml" >> "$GITHUB_OUTPUT" + fi + + - name: Fetch Actions data snapshot + timeout-minutes: 30 + run: | + python scripts/ci/utils/query_job_status.py \ + --repo ${{ github.repository }} \ + --workflow "${{ steps.select-workflows.outputs.workflows }}" \ + --hours ${{ inputs.hours || '24' }} \ + --dump-data-file actions-job-snapshot.json + + - name: Upload Actions data snapshot + uses: actions/upload-artifact@v4 + with: + name: actions-job-snapshot + path: 
actions-job-snapshot.json + if-no-files-found: error + # Single job filter mode custom-report: name: Custom Job Report if: ${{ inputs.job_filter }} + needs: fetch-actions-data runs-on: ubuntu-latest - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - name: Checkout code uses: actions/checkout@v4 @@ -39,6 +80,12 @@ jobs: - name: Install dependencies run: pip install tabulate + - name: Download Actions data snapshot + uses: actions/download-artifact@v4 + with: + name: actions-job-snapshot + path: ci-data + - name: Generate Custom Job Report timeout-minutes: 30 run: | @@ -47,6 +94,7 @@ jobs: --job "${{ inputs.job_filter }}" \ --workflow "pr-test-amd.yml" \ --hours ${{ inputs.hours || '24' }} \ + --input-data-file ci-data/actions-job-snapshot.json \ --summary # Parse workflow files to get job names dynamically @@ -57,6 +105,8 @@ jobs: outputs: pr_jobs: ${{ steps.parse.outputs.pr_jobs }} nightly_jobs: ${{ steps.parse.outputs.nightly_jobs }} + pr_rocm720_jobs: ${{ steps.parse.outputs.pr_rocm720_jobs }} + nightly_rocm720_jobs: ${{ steps.parse.outputs.nightly_rocm720_jobs }} steps: - name: Checkout code uses: actions/checkout@v4 @@ -80,18 +130,32 @@ jobs: echo "nightly_jobs=$nightly_jobs" >> $GITHUB_OUTPUT echo "Nightly jobs: $nightly_jobs" + # Parse pr-test-amd-rocm720.yml (exclude utility jobs) + # Excluded: call-gate, check-changes, pr-test-amd-finish, cancel, check-all-jobs + pr_rocm720_jobs=$(yq -r '.jobs | keys | .[]' .github/workflows/pr-test-amd-rocm720.yml | \ + grep -v -E '^(call-gate|check-changes|pr-test-amd-finish|cancel|check-all-jobs)$' | \ + jq -R -s -c 'split("\n") | map(select(length > 0))') + echo "pr_rocm720_jobs=$pr_rocm720_jobs" >> $GITHUB_OUTPUT + echo "PR ROCm 7.2 jobs: $pr_rocm720_jobs" + + # Parse nightly-test-amd-rocm720.yml (exclude utility jobs) + # Excluded: check-all-jobs + nightly_rocm720_jobs=$(yq -r '.jobs | keys | .[]' .github/workflows/nightly-test-amd-rocm720.yml | \ + grep -v -E '^(check-all-jobs)$' | \ + jq -R -s -c 'split("\n") | 
map(select(length > 0))') + echo "nightly_rocm720_jobs=$nightly_rocm720_jobs" >> $GITHUB_OUTPUT + echo "Nightly ROCm 7.2 jobs: $nightly_rocm720_jobs" + # PR CI reports using dynamic matrix pr-ci-reports: name: PR - ${{ matrix.job_name }} - needs: parse-workflows + needs: [parse-workflows, fetch-actions-data] if: ${{ !inputs.job_filter }} runs-on: ubuntu-latest strategy: fail-fast: false matrix: job_name: ${{ fromJson(needs.parse-workflows.outputs.pr_jobs) }} - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - name: Checkout code uses: actions/checkout@v4 @@ -104,6 +168,12 @@ jobs: - name: Install dependencies run: pip install tabulate + - name: Download Actions data snapshot + uses: actions/download-artifact@v4 + with: + name: actions-job-snapshot + path: ci-data + - name: Generate Report timeout-minutes: 15 run: | @@ -112,20 +182,19 @@ jobs: --job "${{ matrix.job_name }}" \ --workflow "pr-test-amd.yml" \ --hours ${{ inputs.hours || '24' }} \ + --input-data-file ci-data/actions-job-snapshot.json \ --summary # Nightly AMD test reports using dynamic matrix nightly-reports: name: Nightly - ${{ matrix.job_name }} - needs: parse-workflows + needs: [parse-workflows, fetch-actions-data] if: ${{ !inputs.job_filter }} runs-on: ubuntu-latest strategy: fail-fast: false matrix: job_name: ${{ fromJson(needs.parse-workflows.outputs.nightly_jobs) }} - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - name: Checkout code uses: actions/checkout@v4 @@ -138,6 +207,12 @@ jobs: - name: Install dependencies run: pip install tabulate + - name: Download Actions data snapshot + uses: actions/download-artifact@v4 + with: + name: actions-job-snapshot + path: ci-data + - name: Generate Nightly Report timeout-minutes: 15 run: | @@ -146,4 +221,118 @@ jobs: --job "${{ matrix.job_name }}" \ --workflow "nightly-test-amd.yml" \ --hours ${{ inputs.hours || '24' }} \ + --input-data-file ci-data/actions-job-snapshot.json \ + --summary + + # PR ROCm 7.2 CI reports using dynamic matrix + 
pr-rocm720-ci-reports: + name: PR ROCm720 - ${{ matrix.job_name }} + needs: [parse-workflows, fetch-actions-data] + if: ${{ !inputs.job_filter }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + job_name: ${{ fromJson(needs.parse-workflows.outputs.pr_rocm720_jobs) }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: pip install tabulate + + - name: Download Actions data snapshot + uses: actions/download-artifact@v4 + with: + name: actions-job-snapshot + path: ci-data + + - name: Generate PR ROCm 7.2 Report + timeout-minutes: 15 + run: | + python scripts/ci/utils/query_job_status.py \ + --repo ${{ github.repository }} \ + --job "${{ matrix.job_name }}" \ + --workflow "pr-test-amd-rocm720.yml" \ + --hours ${{ inputs.hours || '24' }} \ + --input-data-file ci-data/actions-job-snapshot.json \ + --summary + + # Nightly ROCm 7.2 reports using dynamic matrix + nightly-rocm720-reports: + name: Nightly ROCm720 - ${{ matrix.job_name }} + needs: [parse-workflows, fetch-actions-data] + if: ${{ !inputs.job_filter }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + job_name: ${{ fromJson(needs.parse-workflows.outputs.nightly_rocm720_jobs) }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: pip install tabulate + + - name: Download Actions data snapshot + uses: actions/download-artifact@v4 + with: + name: actions-job-snapshot + path: ci-data + + - name: Generate Nightly ROCm 7.2 Report + timeout-minutes: 15 + run: | + python scripts/ci/utils/query_job_status.py \ + --repo ${{ github.repository }} \ + --job "${{ matrix.job_name }}" \ + --workflow "nightly-test-amd-rocm720.yml" \ + --hours ${{ inputs.hours || '24' }} \ + --input-data-file 
ci-data/actions-job-snapshot.json \ + --summary + + # Runner fleet report - cross-workflow runner analytics in a single pass + runner-fleet-report: + name: Runner Fleet Report + if: ${{ !inputs.job_filter }} + needs: fetch-actions-data + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: pip install tabulate + + - name: Download Actions data snapshot + uses: actions/download-artifact@v4 + with: + name: actions-job-snapshot + path: ci-data + + - name: Generate Runner Fleet Report + timeout-minutes: 30 + run: | + python scripts/ci/utils/query_job_status.py \ + --repo ${{ github.repository }} \ + --runner-report \ + --workflow "pr-test-amd.yml,nightly-test-amd.yml,pr-test-amd-rocm720.yml,nightly-test-amd-rocm720.yml" \ + --hours ${{ inputs.hours || '24' }} \ + --input-data-file ci-data/actions-job-snapshot.json \ --summary diff --git a/.github/workflows/auto-tune.yml b/.github/workflows/auto-tune.yml index 0afc79bb7c8c..16ad5d23b177 100644 --- a/.github/workflows/auto-tune.yml +++ b/.github/workflows/auto-tune.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: jobs: - lint: + auto-tune-lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/bot-bump-kernel-version-to-sglang.yml b/.github/workflows/bot-bump-kernel-version-to-sglang.yml index 817889846a8d..b26192aba1ac 100644 --- a/.github/workflows/bot-bump-kernel-version-to-sglang.yml +++ b/.github/workflows/bot-bump-kernel-version-to-sglang.yml @@ -58,43 +58,3 @@ jobs: GH_TOKEN: ${{ secrets.GH_PAT_FOR_PULL_REQUEST }} run: | bash scripts/release/commit_and_pr_kernel_to_sglang.sh "$KERNEL_VERSION" "$BRANCH_NAME" - - run-nightly-tests-nvidia: - needs: bump-kernel-version-to-sglang - if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true' - uses: ./.github/workflows/nightly-test-nvidia.yml - with: - ref: ${{ 
needs.bump-kernel-version-to-sglang.outputs.branch_name }} - secrets: inherit - - run-nightly-tests-amd: - needs: bump-kernel-version-to-sglang - if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true' - uses: ./.github/workflows/nightly-test-amd.yml - with: - ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }} - secrets: inherit - - run-nightly-tests-npu: - needs: bump-kernel-version-to-sglang - if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true' - uses: ./.github/workflows/nightly-test-npu.yml - with: - ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }} - secrets: inherit - - run-pr-tests-xeon: - needs: bump-kernel-version-to-sglang - if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true' - uses: ./.github/workflows/pr-test-xeon.yml - with: - ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }} - secrets: inherit - - run-pr-tests-xpu: - needs: bump-kernel-version-to-sglang - if: needs.bump-kernel-version-to-sglang.outputs.needs_sync == 'true' - uses: ./.github/workflows/pr-test-xpu.yml - with: - ref: ${{ needs.bump-kernel-version-to-sglang.outputs.branch_name }} - secrets: inherit diff --git a/.github/workflows/full-test-npu.yml b/.github/workflows/full-test-npu.yml new file mode 100644 index 000000000000..47355f2c2233 --- /dev/null +++ b/.github/workflows/full-test-npu.yml @@ -0,0 +1,355 @@ +name: Full Test (NPU) + +on: +# pull_request: +# branches: +# - main +# paths: +# - ".github/workflows/full-test-npu.yml" + workflow_dispatch: + inputs: + ref: + description: 'Git ref (branch, tag, or SHA) to test. If not provided, uses the default branch.' + required: false + type: string + default: '' + job_filter: + description: 'Select which job to run (leave empty or "all" to run all jobs)' + required: false + type: string + default: 'all' + image_a3: + description: 'The a3 running docker image of the test task.' 
+ required: false + type: string + default: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11' + skip_install_flag: + description: 'Indicates whether to skip the installation of sglang, defaulting to false.' + required: false + type: string + default: 'false' + +concurrency: + group: full-test-npu-${{ inputs.ref || github.ref }} + cancel-in-progress: ${{ github.event_name != 'workflow_call' }} + +jobs: + set-image-config: + runs-on: ubuntu-latest + outputs: + ref: ${{ steps.set-vars.outputs.ref }} + job_filter: ${{ steps.set-vars.outputs.job_filter }} + image_a3: ${{ steps.set-vars.outputs.image_a3 }} + skip_install_flag: ${{ steps.set-vars.outputs.skip_install_flag }} + steps: + # When triggered by PR, no inputs parameters are used. The latest community code is tested by default. + - name: Set image config + id: set-vars + run: | + if [ -z "${{ inputs.ref }}" ]; then + echo "ref=" >> $GITHUB_OUTPUT + else + echo "ref=${{ inputs.ref }}" >> $GITHUB_OUTPUT + fi + + if [ -z "${{ inputs.job_filter }}" ]; then + echo "job_filter=all" >> $GITHUB_OUTPUT + else + echo "job_filter=${{ inputs.job_filter }}" >> $GITHUB_OUTPUT + fi + + if [ -z "${{ inputs.image_a3 }}" ]; then + echo "image_a3=swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11" >> $GITHUB_OUTPUT + else + echo "image_a3=${{ inputs.image_a3 }}" >> $GITHUB_OUTPUT + fi + + if [ -z "${{ inputs.skip_install_flag }}" ]; then + echo "skip_install_flag=false" >> $GITHUB_OUTPUT + else + echo "skip_install_flag=${{ inputs.skip_install_flag }}" >> $GITHUB_OUTPUT + fi + + nighly-test-npu: + needs: [set-image-config] + name: nightly-test-npu + if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} + uses: ./.github/workflows/nightly-test-npu.yml + with: + ref: ${{ needs.set-image-config.outputs.ref }} + job_filter: ${{ needs.set-image-config.outputs.job_filter }} + image_a3: ${{ 
needs.set-image-config.outputs.image_a3 }} + skip_install_flag: ${{ needs.set-image-config.outputs.skip_install_flag }} + secrets: inherit + + full-1-npu-a3: + needs: [set-image-config] + if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} + runs-on: linux-aarch64-a3-2 + container: + image: ${{ needs.set-image-config.outputs.image_a3 }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ needs.set-image-config.outputs.ref || github.ref }} + + - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" + run: | + # speed up by using infra cache services + CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" + sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list + pip config set global.index-url http://${CACHING_URL}/pypi/simple + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + + # copy required file from our daily cache + cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp + + - name: Print Log Information + run: | + bash scripts/ci/npu/npu_log_print.sh + + - name: Run test + timeout-minutes: 240 + env: + SGLANG_USE_MODELSCOPE: true + SGLANG_IS_IN_CI: true + HF_ENDPOINT: https://hf-mirror.com + TORCH_EXTENSIONS_DIR: /tmp/torch_extensions + PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" + STREAMS_PER_DEVICE: 32 + run: | + pip install sglang_router + hf download lmms-lab/MMMU --repo-type dataset + pip install sentence_transformers torchaudio==2.8.0 + pip install 
protobuf==6.31.1 zss pre-commit wandb>=0.16.0 tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap + pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 sacrebleu>=1.5.0 pytablewriter black==24.1.0 isort==5.13.2 peft>=0.2.0 accelerate>=0.29.1 + pip install jsonlines httpx==0.25.0 evaluate>=0.4.0 datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv + git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git + cd ./lmms-eval + nohup pip install . > lmmslog.txt 2>&1 & + sleep 120 + export PYTHONPATH=$PYTHONPATH:$(pwd) + cd ../ + cd test + python3 run_suite.py --hw npu --suite full-1-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 + + full-2-npu-a3: + needs: [set-image-config] + if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} + runs-on: linux-aarch64-a3-2 + container: + image: ${{ needs.set-image-config.outputs.image_a3 }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ needs.set-image-config.outputs.ref || github.ref }} + + - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" + run: | + # speed up by using infra cache services + CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" + sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list + pip config set global.index-url http://${CACHING_URL}/pypi/simple + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + + # copy required file from our daily cache + cp 
~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp + + - name: Print Log Information + run: | + bash scripts/ci/npu/npu_log_print.sh + + - name: Run test + timeout-minutes: 240 + env: + SGLANG_USE_MODELSCOPE: true + SGLANG_IS_IN_CI: true + HF_ENDPOINT: https://hf-mirror.com + TORCH_EXTENSIONS_DIR: /tmp/torch_extensions + PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" + STREAMS_PER_DEVICE: 32 + run: | + pip install sglang_router + hf download lmms-lab/MMMU --repo-type dataset + pip install sentence_transformers torchaudio==2.8.0 + pip install protobuf==6.31.1 zss pre-commit wandb>=0.16.0 tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap + pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 sacrebleu>=1.5.0 pytablewriter black==24.1.0 isort==5.13.2 peft>=0.2.0 accelerate>=0.29.1 + pip install jsonlines httpx==0.25.0 evaluate>=0.4.0 datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv + git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git + cd ./lmms-eval + nohup pip install . 
> lmmslog.txt 2>&1 & + sleep 120 + export PYTHONPATH=$PYTHONPATH:$(pwd) + cd ../ + cd test + python3 run_suite.py --hw npu --suite full-2-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 + + full-4-npu-a3: + needs: [set-image-config] + if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} + runs-on: linux-aarch64-a3-4 + container: + image: ${{ needs.set-image-config.outputs.image_a3 }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ needs.set-image-config.outputs.ref || github.ref }} + + - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" + run: | + # speed up by using infra cache services + CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" + sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list + pip config set global.index-url http://${CACHING_URL}/pypi/simple + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + + # copy required file from our daily cache + cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp + + - name: Print Log Information + run: | + bash scripts/ci/npu/npu_log_print.sh + + - name: Run test + timeout-minutes: 240 + env: + SGLANG_USE_MODELSCOPE: true + SGLANG_IS_IN_CI: true + HF_ENDPOINT: https://hf-mirror.com + TORCH_EXTENSIONS_DIR: /tmp/torch_extensions + PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" + STREAMS_PER_DEVICE: 32 + run: | + pip install sglang_router + hf download lmms-lab/MMMU --repo-type dataset + pip install 
sentence_transformers torchaudio==2.8.0 + pip install protobuf==6.31.1 zss pre-commit wandb>=0.16.0 tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap + pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 sacrebleu>=1.5.0 pytablewriter black==24.1.0 isort==5.13.2 peft>=0.2.0 accelerate>=0.29.1 + pip install jsonlines httpx==0.25.0 evaluate>=0.4.0 datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv + git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git + cd ./lmms-eval + nohup pip install . > lmmslog.txt 2>&1 & + sleep 120 + export PYTHONPATH=$PYTHONPATH:$(pwd) + cd ../ + cd test + python3 run_suite.py --hw npu --suite full-4-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 + + full-16-npu-a3: + needs: [set-image-config] + if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} + runs-on: linux-aarch64-a3-16 + container: + image: ${{ needs.set-image-config.outputs.image_a3 }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ needs.set-image-config.outputs.ref || github.ref }} + + - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" + run: | + # speed up by using infra cache services + CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" + sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list + pip config set global.index-url http://${CACHING_URL}/pypi/simple + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + + # copy required file from our daily cache + cp 
~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp + + - name: Print Log Information + run: | + bash scripts/ci/npu/npu_log_print.sh + + - name: Run test + timeout-minutes: 240 + env: + SGLANG_USE_MODELSCOPE: true + SGLANG_IS_IN_CI: true + HF_ENDPOINT: https://hf-mirror.com + TORCH_EXTENSIONS_DIR: /tmp/torch_extensions + PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" + STREAMS_PER_DEVICE: 32 + run: | + pip install sglang_router + hf download lmms-lab/MMMU --repo-type dataset + pip install sentence_transformers torchaudio==2.8.0 + pip install protobuf==6.31.1 zss pre-commit wandb>=0.16.0 tenacity==8.3.0 loguru openpyxl latex2sympy2 zstandard transformers-stream-generator tqdm-multiprocess pycocoevalcap + pip install yt-dlp sentencepiece==0.1.99 nltk av ftfy sqlitedict==2.1.0 sacrebleu>=1.5.0 pytablewriter black==24.1.0 isort==5.13.2 peft>=0.2.0 accelerate>=0.29.1 + pip install jsonlines httpx==0.25.0 evaluate>=0.4.0 datasets==2.16.1 numexpr xgrammar==0.1.25 numpy==1.26.4 dotenv + git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git + cd ./lmms-eval + nohup pip install . 
> lmmslog.txt 2>&1 & + sleep 120 + export PYTHONPATH=$PYTHONPATH:$(pwd) + cd ../ + cd test + python3 run_suite.py --hw npu --suite full-16-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 + + check-all-jobs: + if: github.repository == 'sgl-project/sglang' && always() + needs: + - nighly-test-npu + - full-1-npu-a3 + - full-2-npu-a3 + - full-4-npu-a3 + - full-16-npu-a3 + runs-on: ubuntu-latest + container: + image: docker.m.daocloud.io/ubuntu:22.04 + steps: + - name: Check if any job failed + run: | + if [[ "${{ contains(needs.*.result, 'failure') }}" == "true" ]]; then + echo "One or more nightly test jobs failed" + exit 1 + fi + if [[ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]]; then + echo "One or more nightly test jobs were cancelled" + exit 1 + fi + echo "All nightly test jobs passed" diff --git a/.github/workflows/nightly-test-amd-rocm720.yml b/.github/workflows/nightly-test-amd-rocm720.yml index d38a4f10f7fc..14929952ebd6 100644 --- a/.github/workflows/nightly-test-amd-rocm720.yml +++ b/.github/workflows/nightly-test-amd-rocm720.yml @@ -61,6 +61,7 @@ on: - nightly-8-gpu-mi35x-qwen3-235b-mxfp4-rocm720 - nightly-8-gpu-mi35x-qwen35-rocm720 - nightly-8-gpu-mi35x-glm5-rocm720 + - nightly-8-gpu-mi35x-glm47-fp8-rocm720 - nightly-8-gpu-mi35x-minimax-m25-rocm720 job_filter: description: 'Or type comma-separated job names (overrides dropdown if non-empty)' @@ -684,7 +685,7 @@ jobs: echo "$(> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} - # 8-GPU MiniMax-M2.5 (Accuracy) ROCm 7.2 + # 8-GPU MiniMax-M2.5 (Accuracy + Performance combined) ROCm 7.2 nightly-8-gpu-minimax-m25-rocm720: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-minimax-m25-rocm720,')) runs-on: linux-mi325-8gpu-sglang @@ -715,6 +716,18 @@ jobs: echo 
"$(> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} + - name: Performance Test ROCm 7.2 (8-GPU MiniMax-M2.5) + timeout-minutes: 120 + continue-on-error: true # Perf test failure doesn't fail the job if accuracy passed + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e SGLANG_USE_AITER=1 \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-minimax-m25 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? + echo "$(> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + # ============================================== MI30x ROCm 7.2 Diffusion Tests ============================================== # 1-GPU Z-Image-Turbo (Diffusion T2I) ROCm 7.2 nightly-1-gpu-zimage-turbo-rocm720: @@ -1272,7 +1285,40 @@ jobs: echo "$(> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} - # MI35x 8-GPU MiniMax-M2.5 (Accuracy) ROCm 7.2 + # MI35x 8-GPU GLM-4.7-FP8 (Accuracy) ROCm 7.2 + nightly-8-gpu-mi35x-glm47-fp8-rocm720: + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-glm47-fp8-rocm720,')) + runs-on: linux-mi35x-gpu-8 + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.ref || github.ref }} + + - name: Setup docker (ROCm 7.2) + run: | + touch github_summary.md + bash scripts/ci/amd/amd_ci_start_container.sh --rocm-version rocm720 + env: + GITHUB_WORKSPACE: ${{ github.workspace }} + + - name: Install dependencies + run: | + bash scripts/ci/amd/amd_ci_install_dependency.sh + # Install tabulate for run_suite.py (missing in MI35x container) + bash scripts/ci/amd/amd_ci_exec.sh pip install tabulate + + - name: 
Accuracy Test MI35x ROCm 7.2 (8-GPU GLM-4.7-FP8) + timeout-minutes: 120 + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd --suite nightly-amd-8-gpu-mi35x-glm47-fp8 --nightly --timeout-per-file 3600 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? + echo "$(> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + + # MI35x 8-GPU MiniMax-M2.5 (Accuracy + Performance combined) ROCm 7.2 nightly-8-gpu-mi35x-minimax-m25-rocm720: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-minimax-m25-rocm720,')) runs-on: linux-mi35x-gpu-8 @@ -1305,6 +1351,18 @@ jobs: echo "$(> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} + - name: Performance Test MI35x ROCm 7.2 (8-GPU MiniMax-M2.5) + timeout-minutes: 120 + continue-on-error: true # Perf test failure doesn't fail the job if accuracy passed + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e SGLANG_USE_AITER=1 \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-minimax-m25 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? 
+ echo "$(> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + # MI35x 8-GPU DeepSeek-V3.2 Performance Test (MTP) ROCm 7.2 nightly-perf-8-gpu-mi35x-deepseek-v32-mtp-rocm720: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-8-gpu-mi35x-deepseek-v32-mtp-rocm720,')) @@ -1382,6 +1440,7 @@ jobs: - nightly-8-gpu-mi35x-qwen3-235b-mxfp4-rocm720 - nightly-8-gpu-mi35x-qwen35-rocm720 - nightly-8-gpu-mi35x-glm5-rocm720 + - nightly-8-gpu-mi35x-glm47-fp8-rocm720 - nightly-8-gpu-mi35x-minimax-m25-rocm720 runs-on: ubuntu-latest steps: diff --git a/.github/workflows/nightly-test-amd.yml b/.github/workflows/nightly-test-amd.yml index 5443df894df5..64cca74d7e0f 100644 --- a/.github/workflows/nightly-test-amd.yml +++ b/.github/workflows/nightly-test-amd.yml @@ -687,7 +687,7 @@ jobs: echo "$(> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} - # 8-GPU MiniMax-M2.5 (Accuracy) + # 8-GPU MiniMax-M2.5 (Accuracy + Performance combined) nightly-8-gpu-minimax-m25: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-minimax-m25,')) runs-on: linux-mi325-8gpu-sglang @@ -718,6 +718,18 @@ jobs: echo "$(> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} + - name: Performance Test (8-GPU MiniMax-M2.5) + timeout-minutes: 120 + continue-on-error: true # Perf test failure doesn't fail the job if accuracy passed + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e SGLANG_USE_AITER=1 \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd 
--suite nightly-perf-8-gpu-minimax-m25 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? + echo "$(> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + # ============================================== MI30x Diffusion Tests ============================================== # 1-GPU Z-Image-Turbo (Diffusion T2I) nightly-1-gpu-zimage-turbo: @@ -1278,7 +1290,7 @@ jobs: echo "$(> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} - # MI35x 8-GPU MiniMax-M2.5 (Accuracy) + # MI35x 8-GPU MiniMax-M2.5 (Accuracy + Performance combined) nightly-8-gpu-mi35x-minimax-m25: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-minimax-m25,')) runs-on: linux-mi35x-gpu-8 @@ -1311,6 +1323,18 @@ jobs: echo "$(> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} + - name: Performance Test MI35x (8-GPU MiniMax-M2.5) + timeout-minutes: 120 + continue-on-error: true # Perf test failure doesn't fail the job if accuracy passed + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e SGLANG_USE_AITER=1 \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-minimax-m25 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? 
+ echo "$(> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + # MI35x 8-GPU DeepSeek-V3.2 Performance Test (MTP) nightly-perf-8-gpu-mi35x-deepseek-v32-mtp: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-perf-8-gpu-mi35x-deepseek-v32-mtp,')) diff --git a/.github/workflows/nightly-test-npu.yml b/.github/workflows/nightly-test-npu.yml index fa19ab1a42b6..7503d9a05099 100644 --- a/.github/workflows/nightly-test-npu.yml +++ b/.github/workflows/nightly-test-npu.yml @@ -2,7 +2,7 @@ name: Nightly Test (NPU) on: schedule: - - cron: '0 17 * * *' # Execute at 1:00 a.m. Beijing Time every day + - cron: '0 18 * * *' # Execute at 2:00 a.m. Beijing Time every day pull_request: branches: - main @@ -21,13 +21,61 @@ on: required: false type: string default: 'all' + image_a3: + description: 'The a3 running docker image of the test task.' + required: false + type: string + default: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11' + skip_install_flag: + description: 'Indicates whether to skip the installation of sglang, defaulting to false.' + required: false + type: string + default: 'false' + concurrency: group: nightly-test-npu-${{ inputs.ref || github.ref }} cancel-in-progress: ${{ github.event_name != 'workflow_call' }} jobs: + set-image-config: + runs-on: ubuntu-latest + outputs: + ref: ${{ steps.set-vars.outputs.ref }} + job_filter: ${{ steps.set-vars.outputs.job_filter }} + image_a3: ${{ steps.set-vars.outputs.image_a3 }} + skip_install_flag: ${{ steps.set-vars.outputs.skip_install_flag }} + steps: + # When triggered by PR, no inputs parameters are used. The latest community code is tested by default. 
+ - name: Set image config + id: set-vars + run: | + if [ -z "${{ inputs.ref }}" ]; then + echo "ref=" >> $GITHUB_OUTPUT + else + echo "ref=${{ inputs.ref }}" >> $GITHUB_OUTPUT + fi + + if [ -z "${{ inputs.job_filter }}" ]; then + echo "job_filter=all" >> $GITHUB_OUTPUT + else + echo "job_filter=${{ inputs.job_filter }}" >> $GITHUB_OUTPUT + fi + + if [ -z "${{ inputs.image_a3 }}" ]; then + echo "image_a3=swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11" >> $GITHUB_OUTPUT + else + echo "image_a3=${{ inputs.image_a3 }}" >> $GITHUB_OUTPUT + fi + + if [ -z "${{ inputs.skip_install_flag }}" ]; then + echo "skip_install_flag=false" >> $GITHUB_OUTPUT + else + echo "skip_install_flag=${{ inputs.skip_install_flag }}" >> $GITHUB_OUTPUT + fi + nightly-1-npu-a3: + needs: [set-image-config] if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} runs-on: linux-aarch64-a3-2 strategy: @@ -35,26 +83,33 @@ jobs: matrix: part: [0, 1] container: - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11 + image: ${{ needs.set-image-config.outputs.image_a3 }} steps: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ inputs.ref || github.ref }} + ref: ${{ needs.set-image-config.outputs.ref || github.ref }} - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" run: | # speed up by using infra cache services CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list pip config set global.index-url http://${CACHING_URL}/pypi/simple - pip config set global.extra-index-url "https://pypi.tuna.tsinghua.edu.cn/simple" - pip config set global.trusted-host 
"${CACHING_URL} pypi.tuna.tsinghua.edu.cn" - bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Print Log Information run: | @@ -86,6 +141,7 @@ jobs: python3 run_suite.py --hw npu --suite nightly-1-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 nightly-2-npu-a3: + needs: [set-image-config] if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} runs-on: linux-aarch64-a3-2 strategy: @@ -93,26 +149,33 @@ jobs: matrix: part: [0] container: - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11 + image: ${{ needs.set-image-config.outputs.image_a3 }} steps: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ inputs.ref || github.ref }} + ref: ${{ needs.set-image-config.outputs.ref || github.ref }} - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" run: | # speed up by using infra cache services CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list pip config set global.index-url 
http://${CACHING_URL}/pypi/simple - pip config set global.extra-index-url "https://pypi.tuna.tsinghua.edu.cn/simple" - pip config set global.trusted-host "${CACHING_URL} pypi.tuna.tsinghua.edu.cn" - bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Print Log Information run: | @@ -143,6 +206,7 @@ jobs: python3 run_suite.py --hw npu --suite nightly-2-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 1 nightly-4-npu-a3: + needs: [set-image-config] if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} runs-on: linux-aarch64-a3-4 strategy: @@ -150,25 +214,33 @@ jobs: matrix: part: [0] container: - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11 + image: ${{ needs.set-image-config.outputs.image_a3 }} steps: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ inputs.ref || github.ref }} + ref: ${{ needs.set-image-config.outputs.ref || github.ref }} - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" run: | # speed up by using infra cache services 
CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list - pip config set global.extra-index-url "https://pypi.tuna.tsinghua.edu.cn/simple" - pip config set global.trusted-host "${CACHING_URL} pypi.tuna.tsinghua.edu.cn" - bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + pip config set global.index-url http://${CACHING_URL}/pypi/simple + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Print Log Information run: | @@ -200,6 +272,7 @@ jobs: python3 run_suite.py --hw npu --suite nightly-4-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 1 nightly-8-npu-a3: + needs: [set-image-config] if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} runs-on: linux-aarch64-a3-8 strategy: @@ -207,26 +280,33 @@ jobs: matrix: part: [0] container: - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11 + image: ${{ needs.set-image-config.outputs.image_a3 }} steps: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ inputs.ref || github.ref }} + ref: ${{ needs.set-image-config.outputs.ref || github.ref }} - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: 
"http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" run: | # speed up by using infra cache services CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list pip config set global.index-url http://${CACHING_URL}/pypi/simple - pip config set global.extra-index-url "https://pypi.tuna.tsinghua.edu.cn/simple" - pip config set global.trusted-host "${CACHING_URL} pypi.tuna.tsinghua.edu.cn" - bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Print Log Information run: | @@ -258,6 +338,7 @@ jobs: python3 run_suite.py --hw npu --suite nightly-8-npu-a3 --nightly --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 1 nightly-16-npu-a3: + needs: [set-image-config] if: ${{ (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') }} runs-on: linux-aarch64-a3-16 strategy: @@ -265,26 +346,33 @@ jobs: matrix: part: [0, 1] container: - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11 + image: ${{ needs.set-image-config.outputs.image_a3 }} steps: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ inputs.ref || github.ref }} + ref: ${{ 
needs.set-image-config.outputs.ref || github.ref }} - name: Install dependencies + env: + TORCH_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/whl/cpu" + PYPI_CACHE_URL: "http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple" + GITHUB_PROXY_URL: "https://gh-proxy.test.osinfra.cn/" run: | # speed up by using infra cache services CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local" sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list pip config set global.index-url http://${CACHING_URL}/pypi/simple - pip config set global.extra-index-url "https://pypi.tuna.tsinghua.edu.cn/simple" - pip config set global.trusted-host "${CACHING_URL} pypi.tuna.tsinghua.edu.cn" - bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + pip config set global.trusted-host "${CACHING_URL}" + + if [ ${{ needs.set-image-config.outputs.skip_install_flag }} != "true" ];then + bash scripts/ci/npu/npu_ci_install_dependency.sh a3 + fi + # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Print Log Information run: | diff --git a/.github/workflows/nightly-test-nvidia.yml b/.github/workflows/nightly-test-nvidia.yml index 45afc714c966..e99d873c2753 100644 --- a/.github/workflows/nightly-test-nvidia.yml +++ b/.github/workflows/nightly-test-nvidia.yml @@ -510,7 +510,7 @@ jobs: GITHUB_RUN_ID: ${{ github.run_id }} GPU_CONFIG: "1-gpu-h100" - timeout-minutes: 60 + timeout-minutes: 90 run: | cd python python3 sglang/multimodal_gen/test/run_suite.py \ @@ -568,7 +568,7 @@ jobs: GITHUB_RUN_ID: ${{ github.run_id }} GPU_CONFIG: "2-gpu-h100" - 
timeout-minutes: 60 + timeout-minutes: 90 run: | cd python python3 sglang/multimodal_gen/test/run_suite.py \ diff --git a/.github/workflows/pr-test-amd-rocm720.yml b/.github/workflows/pr-test-amd-rocm720.yml index 9842758100a2..24fb80ed3aa7 100644 --- a/.github/workflows/pr-test-amd-rocm720.yml +++ b/.github/workflows/pr-test-amd-rocm720.yml @@ -139,21 +139,21 @@ jobs: with: filters: | main_package: - - "python/sglang/!(multimodal_gen)/**" + - "python/sglang/!(multimodal_gen)/**/!(*.md)" - "python/pyproject_rocm.toml" - "python/pyproject_other.toml" - "scripts/ci/amd/*" - "scripts/ci/utils/*" - - "test/**" + - "test/**/!(*.md)" - ".github/workflows/pr-test-amd-rocm720.yml" sgl_kernel: - - "sgl-kernel/**" + - "sgl-kernel/**/*.!(md|txt)" - ".github/workflows/pr-test-amd-rocm720.yml" jit_kernel: - "python/sglang/jit_kernel/**" - ".github/workflows/pr-test-amd-rocm720.yml" multimodal_gen: - - "python/sglang/multimodal_gen/**" + - "python/sglang/multimodal_gen/**/*.!(md|ipynb)" - "python/sglang/cli/**" - "python/sglang/jit_kernel/diffusion/**" - "python/sglang/jit_kernel/tests/diffusion/**" diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index eba5078d18e3..2afa3cd3716a 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -137,21 +137,21 @@ jobs: with: filters: | main_package: - - "python/sglang/!(multimodal_gen)/**" + - "python/sglang/!(multimodal_gen)/**/!(*.md)" - "python/pyproject_rocm.toml" - "python/pyproject_other.toml" - "scripts/ci/amd/*" - "scripts/ci/utils/*" - - "test/**" + - "test/**/!(*.md)" - ".github/workflows/pr-test-amd.yml" sgl_kernel: - - "sgl-kernel/**" + - "sgl-kernel/**/*.!(md|txt)" - ".github/workflows/pr-test-amd.yml" jit_kernel: - "python/sglang/jit_kernel/**" - ".github/workflows/pr-test-amd.yml" multimodal_gen: - - "python/sglang/multimodal_gen/**" + - "python/sglang/multimodal_gen/**/*.!(md|ipynb)" - "python/sglang/cli/**" - "python/sglang/jit_kernel/diffusion/**" - 
"python/sglang/jit_kernel/tests/diffusion/**" diff --git a/.github/workflows/pr-test-multimodal-gen.yml b/.github/workflows/pr-test-multimodal-gen.yml index 5cdc72275e1f..a8705d24e74b 100644 --- a/.github/workflows/pr-test-multimodal-gen.yml +++ b/.github/workflows/pr-test-multimodal-gen.yml @@ -9,6 +9,9 @@ on: sgl_kernel: required: true type: string + b200_runner: + required: true + type: string continue_on_error: required: false type: string @@ -156,6 +159,154 @@ jobs: with: artifact-suffix: ${{ matrix.part }} + multimodal-gen-component-accuracy-1-gpu: + if: | + (inputs.target_stage == 'multimodal-gen-component-accuracy-1-gpu') || + ( + !inputs.target_stage && + ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == 'true') || (inputs.caller_needs_failure != 'true' && !cancelled())) && + inputs.multimodal_gen == 'true' + ) + runs-on: 1-gpu-h100 + timeout-minutes: 240 + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }} + + - uses: ./.github/actions/check-stage-health + + - uses: ./.github/actions/check-maintenance + + - name: Download artifacts + if: inputs.sgl_kernel == 'true' + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-python3.10-cuda12.9 + + - name: Install dependencies + timeout-minutes: 20 + run: | + CUSTOM_BUILD_SGL_KERNEL=${{inputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh diffusion + + - name: Run diffusion component accuracy tests (1-GPU) + timeout-minutes: 240 + env: + RUNAI_STREAMER_MEMORY_LIMIT: 0 + run: | + cd python + if [ "${{ inputs.continue_on_error }}" = "true" ]; then + exit_code=0 + python3 -m pytest -s -v sglang/multimodal_gen/test/server/test_accuracy_1_gpu_a.py || exit_code=$? + python3 -m pytest -s -v sglang/multimodal_gen/test/server/test_accuracy_1_gpu_b.py || exit_code=$? 
+ exit $exit_code + fi + python3 -m pytest -s -v sglang/multimodal_gen/test/server/test_accuracy_1_gpu_a.py + python3 -m pytest -s -v sglang/multimodal_gen/test/server/test_accuracy_1_gpu_b.py + + - uses: ./.github/actions/upload-cuda-coredumps + if: always() + + multimodal-gen-component-accuracy-2-gpu: + if: | + (inputs.target_stage == 'multimodal-gen-component-accuracy-2-gpu') || + ( + !inputs.target_stage && + ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == 'true') || (inputs.caller_needs_failure != 'true' && !cancelled())) && + inputs.multimodal_gen == 'true' + ) + runs-on: 2-gpu-h100 + timeout-minutes: 240 + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }} + + - uses: ./.github/actions/check-stage-health + + - uses: ./.github/actions/check-maintenance + + - name: Download artifacts + if: inputs.sgl_kernel == 'true' + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-python3.10-cuda12.9 + + - name: Install dependencies + timeout-minutes: 20 + run: | + CUSTOM_BUILD_SGL_KERNEL=${{inputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh diffusion + + - name: Run diffusion component accuracy tests (2-GPU) + timeout-minutes: 240 + env: + RUNAI_STREAMER_MEMORY_LIMIT: 0 + run: | + cd python + if [ "${{ inputs.continue_on_error }}" = "true" ]; then + exit_code=0 + torchrun --nproc_per_node=2 -m pytest -s -v sglang/multimodal_gen/test/server/test_accuracy_2_gpu_a.py || exit_code=$? + torchrun --nproc_per_node=2 -m pytest -s -v sglang/multimodal_gen/test/server/test_accuracy_2_gpu_b.py || exit_code=$? 
+ exit $exit_code + fi + torchrun --nproc_per_node=2 -m pytest -s -v sglang/multimodal_gen/test/server/test_accuracy_2_gpu_a.py + torchrun --nproc_per_node=2 -m pytest -s -v sglang/multimodal_gen/test/server/test_accuracy_2_gpu_b.py + + - uses: ./.github/actions/upload-cuda-coredumps + if: always() + + multimodal-gen-test-1-b200: + if: | + (inputs.target_stage == 'multimodal-gen-test-1-b200') || + ( + !inputs.target_stage && + ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == 'true') || (inputs.caller_needs_failure != 'true' && !cancelled())) && + inputs.multimodal_gen == 'true' + ) + runs-on: ${{ inputs.b200_runner }} + timeout-minutes: 240 + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }} + + + - uses: ./.github/actions/check-maintenance + + - name: Download artifacts + if: inputs.sgl_kernel == 'true' + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-python3.10-cuda12.9 + + - name: Install dependencies + timeout-minutes: 20 + run: | + CUSTOM_BUILD_SGL_KERNEL=${{inputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh diffusion + + - name: Run diffusion server tests + timeout-minutes: 240 + env: + RUNAI_STREAMER_MEMORY_LIMIT: 0 + CONTINUE_ON_ERROR_FLAG: ${{ inputs.continue_on_error == 'true' && '--continue-on-error' || '' }} + run: | + cd python + python3 sglang/multimodal_gen/test/run_suite.py \ + --suite 1-gpu-b200 \ + $CONTINUE_ON_ERROR_FLAG + + - uses: ./.github/actions/upload-cuda-coredumps + if: always() + multimodal-gen-unit-test: if: | (inputs.target_stage == 'multimodal-gen-unit-test') || diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml index 38b49f234713..c1e6ac5ab057 100644 --- a/.github/workflows/pr-test-npu.yml +++ b/.github/workflows/pr-test-npu.yml @@ -57,13 +57,13 @@ jobs: with: filters: | main_package: - - "python/sglang/!(multimodal_gen)/**" + 
- "python/sglang/!(multimodal_gen)/**/!(*.md)" - "python/pyproject_npu.toml" - "scripts/ci/npu/npu_ci_install_dependency.sh" - "test/srt/ascend/**" - ".github/workflows/pr-test-npu.yml" multimodal_gen: - - "python/sglang/multimodal_gen/**" + - "python/sglang/multimodal_gen/**/*.!(md|ipynb)" - "python/sglang/srt/**" - "python/pyproject_npu.toml" - "scripts/ci/npu/npu_ci_install_dependency.sh" @@ -76,7 +76,7 @@ jobs: uses: ./.github/workflows/pr-gate.yml secrets: inherit - per-commit-1-npu-a2: + stage-b-test-1-npu-a2: needs: [check-changes, pr-gate] if: needs.check-changes.outputs.main_package == 'true' runs-on: linux-aarch64-a2-1 @@ -111,21 +111,8 @@ jobs: bash scripts/ci/npu/npu_ci_install_dependency.sh 910b # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl - - - name: Run registered test - timeout-minutes: 240 - env: - SGLANG_USE_MODELSCOPE: true - SGLANG_IS_IN_CI: true - HF_ENDPOINT: https://hf-mirror.com - TORCH_EXTENSIONS_DIR: /tmp/torch_extensions - PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" - STREAMS_PER_DEVICE: 32 - run: | - cd test - python3 run_suite.py --hw npu --suite per-commit-1-npu-a2 --continue-on-error --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Run test timeout-minutes: 60 @@ -137,10 +124,10 @@ jobs: PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" STREAMS_PER_DEVICE: 32 run: | - cd test/srt - python3 run_suite.py --suite per-commit-1-npu-a2 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 + cd test + python3 run_suite.py --hw npu --suite stage-b-test-1-npu-a2 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 - 
per-commit-2-npu-a2: + stage-b-test-2-npu-a2: needs: [check-changes, pr-gate] if: needs.check-changes.outputs.main_package == 'true' runs-on: linux-aarch64-a2-2 @@ -175,8 +162,8 @@ jobs: bash scripts/ci/npu/npu_ci_install_dependency.sh 910b # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Run test timeout-minutes: 60 @@ -188,10 +175,10 @@ jobs: PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" STREAMS_PER_DEVICE: 32 run: | - cd test/srt - python3 run_suite.py --suite per-commit-2-npu-a2 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 + cd test + python3 run_suite.py --hw npu --suite stage-b-test-2-npu-a2 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 - per-commit-4-npu-a3: + stage-b-test-4-npu-a3: needs: [check-changes, pr-gate] if: needs.check-changes.outputs.main_package == 'true' runs-on: linux-aarch64-a3-4 @@ -222,8 +209,8 @@ jobs: bash scripts/ci/npu/npu_ci_install_dependency.sh a3 # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Run test timeout-minutes: 60 @@ -235,10 +222,11 @@ jobs: PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" STREAMS_PER_DEVICE: 32 run: | - cd test/srt - python3 run_suite.py --suite per-commit-4-npu-a3 --timeout-per-file 3600 + cd test + python3 run_suite.py --hw npu --suite stage-b-test-4-npu-a3 
--timeout-per-file 3600 - per-commit-16-npu-a3: + + stage-b-test-16-npu-a3: needs: [check-changes, pr-gate] if: needs.check-changes.outputs.main_package == 'true' runs-on: linux-aarch64-a3-16 @@ -269,8 +257,8 @@ jobs: bash scripts/ci/npu/npu_ci_install_dependency.sh a3 # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Run test timeout-minutes: 60 @@ -282,8 +270,8 @@ jobs: PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True" STREAMS_PER_DEVICE: 32 run: | - cd test/srt - python3 run_suite.py --suite per-commit-16-npu-a3 --timeout-per-file 3600 + cd test + python3 run_suite.py --hw npu --suite stage-b-test-16-npu-a3 --timeout-per-file 3600 multimodal-gen-test-1-npu-a3: needs: [check-changes, pr-gate] @@ -314,8 +302,8 @@ jobs: bash scripts/ci/npu/npu_ci_install_dependency.sh a3 diffusion # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Run test timeout-minutes: 60 @@ -360,8 +348,8 @@ jobs: bash scripts/ci/npu/npu_ci_install_dependency.sh a3 diffusion # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L 
https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Run test timeout-minutes: 60 @@ -380,7 +368,7 @@ jobs: multimodal-gen-test-8-npu-a3: needs: [check-changes, pr-gate] if: needs.check-changes.outputs.multimodal_gen == 'true' - runs-on: linux-aarch64-a3-16 + runs-on: linux-aarch64-a3-8 container: image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.5.0-a3-ubuntu22.04-py3.11 steps: @@ -406,8 +394,8 @@ jobs: bash scripts/ci/npu/npu_ci_install_dependency.sh a3 diffusion # copy required file from our daily cache cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp - # copy download through proxy - curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl + # copy gsm8k dataset + cp ~/.cache/modelscope/hub/datasets/tmp/test.jsonl /tmp - name: Run test timeout-minutes: 60 @@ -421,3 +409,45 @@ jobs: run: | cd python python3 sglang/multimodal_gen/test/run_suite.py --suite 8-npu + + pr-test-npu-finish: + needs: + [ + check-changes, + + stage-b-test-1-npu-a2, + stage-b-test-2-npu-a2, + stage-b-test-4-npu-a3, + stage-b-test-16-npu-a3, + + multimodal-gen-test-1-npu-a3, + multimodal-gen-test-2-npu-a3, + multimodal-gen-test-8-npu-a3, + ] + if: always() + runs-on: ubuntu-latest + steps: + - name: Check all dependent job statuses + run: | + # Convert the 'needs' context to a JSON string + json_needs='${{ toJson(needs) }}' + + # Get a list of all job names from the JSON keys + job_names=$(echo "$json_needs" | jq -r 'keys_unsorted[]') + + for job in $job_names; do + # For each job, extract its result + result=$(echo "$json_needs" | jq -r --arg j "$job" '.[$j].result') + + # Print the job name and its result + echo "$job: $result" + + 
# Check for failure or cancellation and exit if found + if [[ "$result" == "failure" || "$result" == "cancelled" ]]; then + echo "The above jobs failed." + exit 1 + fi + done + # If the loop completes, all jobs were successful + echo "All jobs completed successfully" + exit 0 diff --git a/.github/workflows/pr-test-xeon.yml b/.github/workflows/pr-test-xeon.yml index 021a1308593c..109c0d4b06ef 100644 --- a/.github/workflows/pr-test-xeon.yml +++ b/.github/workflows/pr-test-xeon.yml @@ -55,10 +55,10 @@ jobs: with: filters: | main_package: - - "python/sglang/!(multimodal_gen)/**" + - "python/sglang/!(multimodal_gen)/**/!(*.md)" - "python/pyproject_cpu.toml" - - "test/**" - - "sgl-kernel/**" + - "test/**/!(*.md)" + - "sgl-kernel/**/*.!(md|txt)" - ".github/workflows/pr-test-xeon.yml" - "docker/xeon.Dockerfile" diff --git a/.github/workflows/pr-test-xpu.yml b/.github/workflows/pr-test-xpu.yml index e1a4a5766253..abf9760fd33f 100644 --- a/.github/workflows/pr-test-xpu.yml +++ b/.github/workflows/pr-test-xpu.yml @@ -54,10 +54,10 @@ jobs: with: filters: | main_package: - - "python/sglang/!(multimodal_gen)/**" + - "python/sglang/!(multimodal_gen)/**/!(*.md)" - "python/pyproject_xpu.toml" - - "test/**" - - "sgl-kernel/**" + - "test/**/!(*.md)" + - "sgl-kernel/**/*.!(md|txt)" - ".github/workflows/pr-test-xpu.yml" - "docker/xpu.Dockerfile" diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 30808b01ce8f..1b4302954820 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -88,6 +88,8 @@ jobs: jit_kernel: ${{ steps.filter-api.outputs.jit_kernel || steps.filter.outputs.jit_kernel || steps.run-mode.outputs.run_all_tests }} multimodal_gen: ${{ steps.filter-api.outputs.multimodal_gen || steps.filter.outputs.multimodal_gen || steps.run-mode.outputs.run_all_tests }} max_parallel: ${{ steps.set-parallel.outputs.max_parallel }} + max_parallel_small: ${{ steps.set-parallel.outputs.max_parallel_small }} + max_parallel_2gpu: ${{ 
steps.set-parallel.outputs.max_parallel_2gpu }} b200_runner: ${{ steps.set-runner.outputs.b200_runner }} enable_retry: ${{ steps.set-retry.outputs.enable_retry }} continue_on_error: ${{ steps.set-continue-on-error.outputs.continue_on_error }} @@ -125,10 +127,10 @@ jobs: - ".github/workflows/pr-gate.yml" - ".github/actions/**" - "python/pyproject.toml" - - "python/sglang/!(multimodal_gen)/**" + - "python/sglang/!(multimodal_gen)/**/!(*.md)" - "scripts/ci/cuda/*" - "scripts/ci/utils/*" - - "test/**" + - "test/**/!(*.md)" multimodal_gen: - ".github/workflows/pr-test.yml" - ".github/workflows/pr-test-multimodal-gen.yml" @@ -145,7 +147,7 @@ jobs: - "python/sglang/jit_kernel/**" sgl_kernel: - ".github/workflows/pr-test-sgl-kernel.yml" - - "sgl-kernel/**" + - "sgl-kernel/**/*.!(md|txt)" # For /rerun-stage (workflow_dispatch with target_stage), dorny/paths-filter doesn't work # correctly because it falls back to "last commit" detection which breaks for merge commits. @@ -216,13 +218,14 @@ jobs: env: GH_TOKEN: ${{ github.token }} run: | - # Scheduled runs and high-priority PRs get full parallelism + # Determine if this run gets full parallelism (scheduled / high priority) + FULL=false if [[ "${{ github.event_name }}" == "schedule" ]]; then - echo "max_parallel=14" >> $GITHUB_OUTPUT - echo "Scheduled run detected, setting max_parallel to 14" + FULL=true + echo "Scheduled run detected, using full parallelism" elif [[ "${{ github.event_name }}" == "pull_request" && "${{ contains(github.event.pull_request.labels.*.name, 'high priority') }}" == "true" ]]; then - echo "max_parallel=14" >> $GITHUB_OUTPUT - echo "High priority PR detected, setting max_parallel to 14" + FULL=true + echo "High priority PR detected, using full parallelism" elif [[ -n "${{ inputs.target_stage }}" ]]; then # /rerun-stage (workflow_dispatch): query PR labels via GitHub API # Try SHA lookup first (fork PRs), fallback to branch name (non-fork PRs) @@ -238,16 +241,26 @@ jobs: fi echo "PR labels: 
${LABELS:-"(none)"}" if echo "$LABELS" | grep -Fxq "high priority"; then - echo "max_parallel=14" >> $GITHUB_OUTPUT - echo "High priority PR detected via API (/rerun-stage), setting max_parallel to 14" - else - echo "max_parallel=3" >> $GITHUB_OUTPUT - echo "Using default max_parallel of 3 (/rerun-stage, no high priority label)" + FULL=true + echo "High priority PR detected via API (/rerun-stage), using full parallelism" fi + fi + + # Set max-parallel for each runner type + # 1-gpu-h100: 14 partitions, 1-gpu-5090: 8 partitions, 2-gpu-h100: 4 partitions + if [[ "$FULL" == "true" ]]; then + LEVEL=full + echo "max_parallel=14" >> $GITHUB_OUTPUT + echo "max_parallel_small=8" >> $GITHUB_OUTPUT + echo "max_parallel_2gpu=4" >> $GITHUB_OUTPUT else + LEVEL=low echo "max_parallel=3" >> $GITHUB_OUTPUT - echo "Using default max_parallel of 3" + echo "max_parallel_small=3" >> $GITHUB_OUTPUT + echo "max_parallel_2gpu=2" >> $GITHUB_OUTPUT fi + echo "parallel_level=$LEVEL" >> $GITHUB_OUTPUT + echo "Parallelism level: $LEVEL" - name: Set B200 runner tag id: set-runner @@ -314,7 +327,7 @@ jobs: echo "| multimodal_gen | ${{ steps.filter-api.outputs.multimodal_gen || steps.filter.outputs.multimodal_gen || steps.run-mode.outputs.run_all_tests }} |" echo "| target_stage | ${{ inputs.target_stage || '(none)' }} |" echo "| detection_method | ${{ inputs.target_stage && 'GitHub API' || 'dorny/paths-filter' }} |" - echo "| max_parallel | ${{ steps.set-parallel.outputs.max_parallel }} |" + echo "| max_parallel | ${{ steps.set-parallel.outputs.parallel_level }} (h100=${{ steps.set-parallel.outputs.max_parallel }}, 5090=${{ steps.set-parallel.outputs.max_parallel_small }}, 2gpu=${{ steps.set-parallel.outputs.max_parallel_2gpu }}) |" echo "| b200_runner | ${{ steps.set-runner.outputs.b200_runner }} |" echo "| enable_retry | ${{ steps.set-retry.outputs.enable_retry }} |" echo "| continue_on_error | ${{ steps.set-continue-on-error.outputs.continue_on_error }} |" @@ -661,7 +674,7 @@ jobs: 
timeout-minutes: 240 strategy: fail-fast: false - max-parallel: 8 + max-parallel: ${{ fromJson(needs.check-changes.outputs.max_parallel_small) }} matrix: partition: [0, 1, 2, 3, 4, 5, 6, 7] steps: @@ -687,9 +700,6 @@ jobs: run: | source /etc/profile.d/sglang-ci.sh CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh - git clone https://github.com/merrymercy/human-eval.git - cd human-eval - pip install -e . --no-build-isolation - name: Run test timeout-minutes: 30 @@ -777,6 +787,7 @@ jobs: timeout-minutes: 240 strategy: fail-fast: false + max-parallel: ${{ fromJson(needs.check-changes.outputs.max_parallel_2gpu) }} matrix: partition: [0, 1, 2, 3] steps: @@ -801,9 +812,6 @@ jobs: timeout-minutes: 20 run: | CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_dependency.sh - git clone https://github.com/merrymercy/human-eval.git - cd human-eval - pip install -e . --no-build-isolation - name: Run test timeout-minutes: 30 @@ -882,6 +890,9 @@ jobs: ( inputs.target_stage == 'multimodal-gen-test-1-gpu' || inputs.target_stage == 'multimodal-gen-test-2-gpu' || + inputs.target_stage == 'multimodal-gen-component-accuracy-1-gpu' || + inputs.target_stage == 'multimodal-gen-component-accuracy-2-gpu' || + inputs.target_stage == 'multimodal-gen-test-1-b200' || inputs.target_stage == 'multimodal-gen-unit-test' || ( !inputs.target_stage && @@ -893,6 +904,7 @@ jobs: with: multimodal_gen: ${{ needs.check-changes.outputs.multimodal_gen }} sgl_kernel: ${{ needs.check-changes.outputs.sgl_kernel }} + b200_runner: ${{ needs.check-changes.outputs.b200_runner }} continue_on_error: ${{ needs.check-changes.outputs.continue_on_error }} pr_head_sha: ${{ inputs.pr_head_sha || '' }} git_ref: ${{ inputs.git_ref || '' }} diff --git a/.github/workflows/release-docker-npu-nightly.yml b/.github/workflows/release-docker-npu-nightly.yml index 1dc729cfdd47..8866ae2a2776 100644 --- 
a/.github/workflows/release-docker-npu-nightly.yml +++ b/.github/workflows/release-docker-npu-nightly.yml @@ -8,7 +8,7 @@ on: - 'docker/npu.Dockerfile' workflow_dispatch: schedule: - - cron: "0 0 * * *" + - cron: "0 16 * * *" # Execute at 0:00 a.m. Beijing Time every day concurrency: group: ${{ github.workflow }}-${{ github.sha }} diff --git a/.github/workflows/release-docker-runtime.yml b/.github/workflows/release-docker-runtime.yml new file mode 100644 index 000000000000..9232c094f0f2 --- /dev/null +++ b/.github/workflows/release-docker-runtime.yml @@ -0,0 +1,309 @@ +name: Release Docker Runtime Images +# +# This workflow builds and publishes runtime Docker images (production-optimized, ~50% smaller): +# - lmsysorg/sglang:v{version}-runtime, lmsysorg/sglang:latest-runtime +# - lmsysorg/sglang:v{version}-cu130-runtime, lmsysorg/sglang:latest-cu130-runtime +# +on: + push: + tags: + - "v[0-9]+.*" + workflow_dispatch: + inputs: + version: + description: "Version to build (without v prefix, e.g., 0.5.7)" + required: true + +jobs: + publish-x86: + if: github.repository == 'sgl-project/sglang' + environment: "prod" + strategy: + matrix: + variant: + - cuda_version: "12.9.1" + build_type: "all" + grace_blackwell: 0 + runs-on: x64-docker-build-node + steps: + - name: Delete huge unnecessary tools folder + run: rm -rf /opt/hostedtoolcache + + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Free disk space + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + docker-images: false + android: true + dotnet: true + haskell: true + large-packages: true + swap-storage: false + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Get version from tag + id: version + run: | + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + VERSION="${{ 
github.event.inputs.version }}" + else + # Extract version from tag (e.g., v0.5.7 -> 0.5.7) + VERSION="${GITHUB_REF_NAME#v}" + fi + + # Validate version format + if [ -z "$VERSION" ]; then + echo "::error::Version is empty" + exit 1 + fi + if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+'; then + echo "::error::Invalid version format: $VERSION (expected: X.Y.Z)" + exit 1 + fi + + echo "version=${VERSION}" >> $GITHUB_OUTPUT + + - name: Build and Push AMD64 Runtime + run: | + version=${{ steps.version.outputs.version }} + + docker buildx build \ + --target runtime \ + --platform linux/amd64 \ + --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \ + -f docker/Dockerfile \ + --build-arg CUDA_VERSION=${{ matrix.variant.cuda_version }} \ + --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \ + --build-arg GRACE_BLACKWELL=${{ matrix.variant.grace_blackwell }} \ + --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \ + --build-arg SGL_VERSION=${version} \ + --metadata-file /tmp/metadata-cu129-runtime.json \ + --no-cache \ + . + + DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu129-runtime.json'))['containerimage.digest'])") + echo "Pushed digest: ${DIGEST}" + echo "${DIGEST}" > /tmp/digest-cu129-amd64-runtime.txt + + - name: Build and Push AMD64 Runtime (CUDA 13) + run: | + version=${{ steps.version.outputs.version }} + + docker buildx build \ + --target runtime \ + --platform linux/amd64 \ + --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \ + -f docker/Dockerfile \ + --build-arg CUDA_VERSION=13.0.1 \ + --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \ + --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \ + --build-arg GRACE_BLACKWELL=0 \ + --build-arg SGL_VERSION=${version} \ + --metadata-file /tmp/metadata-cu130-runtime.json \ + --no-cache \ + . 
+ + DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu130-runtime.json'))['containerimage.digest'])") + echo "Pushed digest: ${DIGEST}" + echo "${DIGEST}" > /tmp/digest-cu130-amd64-runtime.txt + + - name: Upload digests + uses: actions/upload-artifact@v4 + with: + name: digests-amd64 + path: /tmp/digest-*.txt + retention-days: 1 + + publish-arm64: + if: github.repository == 'sgl-project/sglang' + environment: "prod" + strategy: + matrix: + variant: + - cuda_version: "12.9.1" + build_type: "all" + grace_blackwell: 1 + runs-on: arm-docker-build-node + steps: + - name: Delete huge unnecessary tools folder + run: rm -rf /opt/hostedtoolcache + + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Get version from tag + id: version + run: | + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + VERSION="${{ github.event.inputs.version }}" + else + # Extract version from tag (e.g., v0.5.7 -> 0.5.7) + VERSION="${GITHUB_REF_NAME#v}" + fi + + # Validate version format + if [ -z "$VERSION" ]; then + echo "::error::Version is empty" + exit 1 + fi + if ! 
echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+'; then + echo "::error::Invalid version format: $VERSION (expected: X.Y.Z)" + exit 1 + fi + + echo "version=${VERSION}" >> $GITHUB_OUTPUT + + - name: Build and Push ARM64 Runtime + run: | + version=${{ steps.version.outputs.version }} + + docker buildx build \ + --target runtime \ + --platform linux/arm64 \ + --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \ + -f docker/Dockerfile \ + --build-arg CUDA_VERSION=${{ matrix.variant.cuda_version }} \ + --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \ + --build-arg GRACE_BLACKWELL=${{ matrix.variant.grace_blackwell }} \ + --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \ + --build-arg SGL_VERSION=${version} \ + --metadata-file /tmp/metadata-cu129-runtime.json \ + --no-cache \ + . + + DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu129-runtime.json'))['containerimage.digest'])") + echo "Pushed digest: ${DIGEST}" + echo "${DIGEST}" > /tmp/digest-cu129-arm64-runtime.txt + + - name: Build and Push ARM64 Runtime (CUDA 13) + run: | + version=${{ steps.version.outputs.version }} + + docker buildx build \ + --target runtime \ + --platform linux/arm64 \ + --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \ + -f docker/Dockerfile \ + --build-arg CUDA_VERSION=13.0.1 \ + --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \ + --build-arg GRACE_BLACKWELL=1 \ + --build-arg SGL_VERSION=${version} \ + --metadata-file /tmp/metadata-cu130-runtime.json \ + --no-cache \ + . 
+ + DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu130-runtime.json'))['containerimage.digest'])") + echo "Pushed digest: ${DIGEST}" + echo "${DIGEST}" > /tmp/digest-cu130-arm64-runtime.txt + + - name: Upload digests + uses: actions/upload-artifact@v4 + with: + name: digests-arm64 + path: /tmp/digest-*.txt + retention-days: 1 + + create-manifests: + runs-on: ubuntu-22.04 + needs: [publish-x86, publish-arm64] + if: github.repository == 'sgl-project/sglang' + environment: "prod" + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Get version from tag + id: version + run: | + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + VERSION="${{ github.event.inputs.version }}" + else + # Extract version from tag (e.g., v0.5.7 -> 0.5.7) + VERSION="${GITHUB_REF_NAME#v}" + fi + + # Validate version format + if [ -z "$VERSION" ]; then + echo "::error::Version is empty" + exit 1 + fi + if ! 
echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+'; then + echo "::error::Invalid version format: $VERSION (expected: X.Y.Z)" + exit 1 + fi + + echo "version=${VERSION}" >> $GITHUB_OUTPUT + + - name: Download amd64 digests + uses: actions/download-artifact@v4 + with: + name: digests-amd64 + path: /tmp/digests/amd64 + + - name: Download arm64 digests + uses: actions/download-artifact@v4 + with: + name: digests-arm64 + path: /tmp/digests/arm64 + + - name: Create multi-arch manifests + run: | + version=${{ steps.version.outputs.version }} + + CU129_AMD64_RT=$(cat /tmp/digests/amd64/digest-cu129-amd64-runtime.txt) + CU130_AMD64_RT=$(cat /tmp/digests/amd64/digest-cu130-amd64-runtime.txt) + CU129_ARM64_RT=$(cat /tmp/digests/arm64/digest-cu129-arm64-runtime.txt) + CU130_ARM64_RT=$(cat /tmp/digests/arm64/digest-cu130-arm64-runtime.txt) + + # Create versioned runtime manifest + docker buildx imagetools create \ + -t lmsysorg/sglang:v${version}-runtime \ + lmsysorg/sglang@${CU129_AMD64_RT} \ + lmsysorg/sglang@${CU129_ARM64_RT} + + # Create latest runtime manifest + docker buildx imagetools create \ + -t lmsysorg/sglang:latest-runtime \ + lmsysorg/sglang@${CU129_AMD64_RT} \ + lmsysorg/sglang@${CU129_ARM64_RT} + + # Create versioned CUDA 13 runtime manifest + docker buildx imagetools create \ + -t lmsysorg/sglang:v${version}-cu130-runtime \ + lmsysorg/sglang@${CU130_AMD64_RT} \ + lmsysorg/sglang@${CU130_ARM64_RT} + + # Create latest CUDA 13 runtime manifest + docker buildx imagetools create \ + -t lmsysorg/sglang:latest-cu130-runtime \ + lmsysorg/sglang@${CU130_AMD64_RT} \ + lmsysorg/sglang@${CU130_ARM64_RT} diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 2f503d22ea41..cebf9cde4215 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -1,15 +1,9 @@ name: Release Docker Images # -# This workflow builds and publishes both framework and runtime Docker images: -# -# Framework images (full 
development environment): +# This workflow builds and publishes framework Docker images (full development environment): # - lmsysorg/sglang:v{version}, lmsysorg/sglang:latest # - lmsysorg/sglang:v{version}-cu130, lmsysorg/sglang:latest-cu130 # -# Runtime images (production-optimized, ~50% smaller): -# - lmsysorg/sglang:v{version}-runtime, lmsysorg/sglang:latest-runtime -# - lmsysorg/sglang:v{version}-cu130-runtime, lmsysorg/sglang:latest-cu130-runtime -# on: push: tags: @@ -24,6 +18,9 @@ jobs: publish-x86: if: github.repository == 'sgl-project/sglang' environment: "prod" + outputs: + digest-cu129: ${{ steps.build-cu129.outputs.digest }} + digest-cu130: ${{ steps.build-cu130.outputs.digest }} strategy: matrix: variant: @@ -81,6 +78,7 @@ jobs: echo "version=${VERSION}" >> $GITHUB_OUTPUT - name: Build AMD64 Framework + id: build-cu129 run: | version=${{ steps.version.outputs.version }} @@ -100,31 +98,10 @@ jobs: DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu129-framework.json'))['containerimage.digest'])") echo "Pushed digest: ${DIGEST}" - echo "${DIGEST}" > /tmp/digest-cu129-amd64-framework.txt - - - name: Build and Push AMD64 Runtime - run: | - version=${{ steps.version.outputs.version }} - - docker buildx build \ - --target runtime \ - --platform linux/amd64 \ - --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \ - -f docker/Dockerfile \ - --build-arg CUDA_VERSION=${{ matrix.variant.cuda_version }} \ - --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \ - --build-arg GRACE_BLACKWELL=${{ matrix.variant.grace_blackwell }} \ - --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \ - --build-arg SGL_VERSION=${version} \ - --metadata-file /tmp/metadata-cu129-runtime.json \ - --no-cache \ - . 
- - DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu129-runtime.json'))['containerimage.digest'])") - echo "Pushed digest: ${DIGEST}" - echo "${DIGEST}" > /tmp/digest-cu129-amd64-runtime.txt + echo "digest=${DIGEST}" >> $GITHUB_OUTPUT - name: Build and Push AMD64 Framework (CUDA 13) + id: build-cu130 run: | version=${{ steps.version.outputs.version }} @@ -144,40 +121,14 @@ jobs: DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu130-framework.json'))['containerimage.digest'])") echo "Pushed digest: ${DIGEST}" - echo "${DIGEST}" > /tmp/digest-cu130-amd64-framework.txt - - - name: Build and Push AMD64 Runtime (CUDA 13) - run: | - version=${{ steps.version.outputs.version }} - - docker buildx build \ - --target runtime \ - --platform linux/amd64 \ - --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \ - -f docker/Dockerfile \ - --build-arg CUDA_VERSION=13.0.1 \ - --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \ - --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \ - --build-arg GRACE_BLACKWELL=0 \ - --build-arg SGL_VERSION=${version} \ - --metadata-file /tmp/metadata-cu130-runtime.json \ - --no-cache \ - . 
- - DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu130-runtime.json'))['containerimage.digest'])") - echo "Pushed digest: ${DIGEST}" - echo "${DIGEST}" > /tmp/digest-cu130-amd64-runtime.txt - - - name: Upload digests - uses: actions/upload-artifact@v4 - with: - name: digests-amd64 - path: /tmp/digest-*.txt - retention-days: 1 + echo "digest=${DIGEST}" >> $GITHUB_OUTPUT publish-arm64: if: github.repository == 'sgl-project/sglang' environment: "prod" + outputs: + digest-cu129: ${{ steps.build-cu129.outputs.digest }} + digest-cu130: ${{ steps.build-cu130.outputs.digest }} strategy: matrix: variant: @@ -224,6 +175,7 @@ jobs: echo "version=${VERSION}" >> $GITHUB_OUTPUT - name: Build ARM64 Framework + id: build-cu129 run: | version=${{ steps.version.outputs.version }} @@ -243,31 +195,10 @@ jobs: DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu129-framework.json'))['containerimage.digest'])") echo "Pushed digest: ${DIGEST}" - echo "${DIGEST}" > /tmp/digest-cu129-arm64-framework.txt - - - name: Build and Push ARM64 Runtime - run: | - version=${{ steps.version.outputs.version }} - - docker buildx build \ - --target runtime \ - --platform linux/arm64 \ - --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \ - -f docker/Dockerfile \ - --build-arg CUDA_VERSION=${{ matrix.variant.cuda_version }} \ - --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \ - --build-arg GRACE_BLACKWELL=${{ matrix.variant.grace_blackwell }} \ - --build-arg INSTALL_FLASHINFER_JIT_CACHE=1 \ - --build-arg SGL_VERSION=${version} \ - --metadata-file /tmp/metadata-cu129-runtime.json \ - --no-cache \ - . 
- - DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu129-runtime.json'))['containerimage.digest'])") - echo "Pushed digest: ${DIGEST}" - echo "${DIGEST}" > /tmp/digest-cu129-arm64-runtime.txt + echo "digest=${DIGEST}" >> $GITHUB_OUTPUT - name: Build and Push ARM64 Framework (CUDA 13) + id: build-cu130 run: | version=${{ steps.version.outputs.version }} @@ -287,35 +218,7 @@ jobs: DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu130-framework.json'))['containerimage.digest'])") echo "Pushed digest: ${DIGEST}" - echo "${DIGEST}" > /tmp/digest-cu130-arm64-framework.txt - - - name: Build and Push ARM64 Runtime (CUDA 13) - run: | - version=${{ steps.version.outputs.version }} - - docker buildx build \ - --target runtime \ - --platform linux/arm64 \ - --output type=image,name=lmsysorg/sglang,push-by-digest=true,name-canonical=true,push=true \ - -f docker/Dockerfile \ - --build-arg CUDA_VERSION=13.0.1 \ - --build-arg BUILD_TYPE=${{ matrix.variant.build_type }} \ - --build-arg GRACE_BLACKWELL=1 \ - --build-arg SGL_VERSION=${version} \ - --metadata-file /tmp/metadata-cu130-runtime.json \ - --no-cache \ - . 
- - DIGEST=$(python3 -c "import json; print(json.load(open('/tmp/metadata-cu130-runtime.json'))['containerimage.digest'])") - echo "Pushed digest: ${DIGEST}" - echo "${DIGEST}" > /tmp/digest-cu130-arm64-runtime.txt - - - name: Upload digests - uses: actions/upload-artifact@v4 - with: - name: digests-arm64 - path: /tmp/digest-*.txt - retention-days: 1 + echo "digest=${DIGEST}" >> $GITHUB_OUTPUT create-manifests: runs-on: ubuntu-22.04 @@ -357,31 +260,14 @@ jobs: echo "version=${VERSION}" >> $GITHUB_OUTPUT - - name: Download amd64 digests - uses: actions/download-artifact@v4 - with: - name: digests-amd64 - path: /tmp/digests/amd64 - - - name: Download arm64 digests - uses: actions/download-artifact@v4 - with: - name: digests-arm64 - path: /tmp/digests/arm64 - - name: Create multi-arch manifests run: | version=${{ steps.version.outputs.version }} - # Load all digests - CU129_AMD64_FW=$(cat /tmp/digests/amd64/digest-cu129-amd64-framework.txt) - CU129_AMD64_RT=$(cat /tmp/digests/amd64/digest-cu129-amd64-runtime.txt) - CU130_AMD64_FW=$(cat /tmp/digests/amd64/digest-cu130-amd64-framework.txt) - CU130_AMD64_RT=$(cat /tmp/digests/amd64/digest-cu130-amd64-runtime.txt) - CU129_ARM64_FW=$(cat /tmp/digests/arm64/digest-cu129-arm64-framework.txt) - CU129_ARM64_RT=$(cat /tmp/digests/arm64/digest-cu129-arm64-runtime.txt) - CU130_ARM64_FW=$(cat /tmp/digests/arm64/digest-cu130-arm64-framework.txt) - CU130_ARM64_RT=$(cat /tmp/digests/arm64/digest-cu130-arm64-runtime.txt) + CU129_AMD64_FW=${{ needs.publish-x86.outputs.digest-cu129 }} + CU130_AMD64_FW=${{ needs.publish-x86.outputs.digest-cu130 }} + CU129_ARM64_FW=${{ needs.publish-arm64.outputs.digest-cu129 }} + CU130_ARM64_FW=${{ needs.publish-arm64.outputs.digest-cu130 }} # Create versioned framework manifest (default) docker buildx imagetools create \ @@ -395,18 +281,6 @@ jobs: lmsysorg/sglang@${CU129_AMD64_FW} \ lmsysorg/sglang@${CU129_ARM64_FW} - # Create versioned runtime manifest - docker buildx imagetools create \ - -t 
lmsysorg/sglang:v${version}-runtime \ - lmsysorg/sglang@${CU129_AMD64_RT} \ - lmsysorg/sglang@${CU129_ARM64_RT} - - # Create latest runtime manifest - docker buildx imagetools create \ - -t lmsysorg/sglang:latest-runtime \ - lmsysorg/sglang@${CU129_AMD64_RT} \ - lmsysorg/sglang@${CU129_ARM64_RT} - # Create versioned CUDA 13 framework manifest docker buildx imagetools create \ -t lmsysorg/sglang:v${version}-cu130 \ @@ -418,15 +292,3 @@ jobs: -t lmsysorg/sglang:latest-cu130 \ lmsysorg/sglang@${CU130_AMD64_FW} \ lmsysorg/sglang@${CU130_ARM64_FW} - - # Create versioned CUDA 13 runtime manifest - docker buildx imagetools create \ - -t lmsysorg/sglang:v${version}-cu130-runtime \ - lmsysorg/sglang@${CU130_AMD64_RT} \ - lmsysorg/sglang@${CU130_ARM64_RT} - - # Create latest CUDA 13 runtime manifest - docker buildx imagetools create \ - -t lmsysorg/sglang:latest-cu130-runtime \ - lmsysorg/sglang@${CU130_AMD64_RT} \ - lmsysorg/sglang@${CU130_ARM64_RT} diff --git a/.github/workflows/release-pypi-nightly.yml b/.github/workflows/release-pypi-nightly.yml index edc058bed1af..9a588c14cc63 100644 --- a/.github/workflows/release-pypi-nightly.yml +++ b/.github/workflows/release-pypi-nightly.yml @@ -61,11 +61,15 @@ jobs: HASH="g$(git rev-parse --short HEAD)" BUILD_DATE=$(date -u +%Y%m%d) - # Increment patch version for nightlies (e.g., v0.5.8 -> 0.5.9) + # Increment patch version for nightlies (e.g., v0.5.9 -> 0.5.10) + # Must always increment so nightly > latest tag per PEP 440 ordering: + # X.Y.Z.devN < X.Y.Z.rcN < X.Y.Z < X.Y.(Z+1).devN VERSION=${TAG#v} # Remove 'v' prefix MAJOR=$(echo "$VERSION" | cut -d. -f1) MINOR=$(echo "$VERSION" | cut -d. -f2) - PATCH=$(echo "$VERSION" | cut -d. -f3) + PATCH_RAW=$(echo "$VERSION" | cut -d. -f3) + # Strip pre-release suffixes (rc0, post1, etc.) 
to get numeric patch + PATCH=$(echo "$PATCH_RAW" | sed 's/[^0-9].*//') NEXT_PATCH=$((PATCH + 1)) NEXT_VERSION="${MAJOR}.${MINOR}.${NEXT_PATCH}" diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml index 03f4abf8ff8e..bdf0b0360ee2 100644 --- a/.github/workflows/release-whl-kernel.yml +++ b/.github/workflows/release-whl-kernel.yml @@ -36,7 +36,7 @@ jobs: build-cu129-matrix: if: | github.repository == 'sgl-project/sglang' && - (github.event.inputs.target == 'all' || github.event.inputs.target == 'cu129') + (github.event_name == 'push' || github.event.inputs.target == 'all' || github.event.inputs.target == 'cu129') strategy: matrix: python-version: ["3.10"] @@ -135,7 +135,7 @@ jobs: build-cu130-matrix: if: | github.repository == 'sgl-project/sglang' && - (github.event.inputs.target == 'all' || github.event.inputs.target == 'cu130') + (github.event_name == 'push' || github.event.inputs.target == 'all' || github.event.inputs.target == 'cu130') strategy: matrix: python-version: ["3.10"] @@ -227,7 +227,7 @@ jobs: build-rocm-matrix: if: | github.repository == 'sgl-project/sglang' && - (github.event.inputs.target == 'all' || github.event.inputs.target == 'rocm700' || github.event.inputs.target == 'rocm720') + (github.event_name == 'push' || github.event.inputs.target == 'all' || github.event.inputs.target == 'rocm700' || github.event.inputs.target == 'rocm720') runs-on: amd-docker-scale strategy: matrix: @@ -362,7 +362,7 @@ jobs: build-musa43: if: | github.repository == 'sgl-project/sglang' && - (github.event.inputs.target == 'all' || github.event.inputs.target == 'musa43') + (github.event_name == 'push' || github.event.inputs.target == 'all' || github.event.inputs.target == 'musa43') runs-on: kernel-build-node-musa strategy: matrix: diff --git a/.github/workflows/rerun-ut.yml b/.github/workflows/rerun-test.yml similarity index 88% rename from .github/workflows/rerun-ut.yml rename to .github/workflows/rerun-test.yml index 
8cb6fc1a10dc..41531ed360c2 100644 --- a/.github/workflows/rerun-ut.yml +++ b/.github/workflows/rerun-test.yml @@ -1,5 +1,5 @@ -name: Rerun UT -run-name: ${{ inputs.pr_head_sha && format('[rerun-ut] {0} {1}', inputs.test_command, inputs.pr_head_sha) || format('[rerun-ut] {0}', inputs.test_command) }} +name: Rerun Test +run-name: ${{ inputs.pr_head_sha && format('[rerun-test] {0} {1}', inputs.test_command, inputs.pr_head_sha) || format('[rerun-test] {0}', inputs.test_command) }} on: workflow_dispatch: @@ -23,7 +23,7 @@ on: - 8-gpu-h20 - 8-gpu-b200 pr_head_sha: - description: "PR head SHA to checkout (for /rerun-ut on fork PRs)" + description: "PR head SHA to checkout (for /rerun-test on fork PRs)" required: false type: string default: "" @@ -44,7 +44,7 @@ permissions: issues: read jobs: - rerun-ut-cuda: + rerun-test-cuda: runs-on: ${{ inputs.runner_label }} timeout-minutes: 120 env: diff --git a/.github/workflows/slash-command-handler.yml b/.github/workflows/slash-command-handler.yml index 9411e0798e72..53a552a46081 100644 --- a/.github/workflows/slash-command-handler.yml +++ b/.github/workflows/slash-command-handler.yml @@ -20,7 +20,7 @@ jobs: contains(github.event.comment.body, '/rerun-failed-ci') || contains(github.event.comment.body, '/tag-and-rerun-ci') || contains(github.event.comment.body, '/rerun-stage') || - contains(github.event.comment.body, '/rerun-ut')) + contains(github.event.comment.body, '/rerun-test')) runs-on: ubuntu-latest steps: @@ -49,14 +49,34 @@ jobs: fi echo "is_fork=$IS_FORK" >> $GITHUB_OUTPUT echo "ref=$(echo "$PR_DATA" | jq -r '.headRefName')" >> $GITHUB_OUTPUT + echo "pr_ref=refs/pull/${{ github.event.issue.number }}/head" >> $GITHUB_OUTPUT echo "PR owner: $HEAD_OWNER, Repo owner: $REPO_OWNER, Is fork: $IS_FORK" + - name: Check commenter permission for fork PRs + id: perm + if: steps.pr.outputs.is_fork == 'true' + shell: bash + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + PERM=$(gh api repos/${{ github.repository 
}}/collaborators/${{ github.event.comment.user.login }}/permission --jq '.permission') || { + PERM="none" + echo "::warning::Failed to check commenter permission, defaulting to none" + } + if [[ "$PERM" == "admin" || "$PERM" == "maintain" || "$PERM" == "write" ]]; then + echo "safe_to_checkout_pr=true" >> $GITHUB_OUTPUT + else + echo "safe_to_checkout_pr=false" >> $GITHUB_OUTPUT + fi + echo "Commenter ${{ github.event.comment.user.login }} permission: $PERM" + - name: Checkout code uses: actions/checkout@v4 with: - # For non-fork PRs, checkout PR branch to allow testing handler changes - # For fork PRs, stay on main for security (don't run untrusted code with elevated permissions) - ref: ${{ steps.pr.outputs.is_fork == 'false' && steps.pr.outputs.ref || '' }} + # For non-fork PRs: checkout PR branch by name + # For fork PRs with trusted commenter: checkout via refs/pull/N/head + # For fork PRs with untrusted commenter: stay on main for security + ref: ${{ steps.pr.outputs.is_fork == 'false' && steps.pr.outputs.ref || (steps.perm.outputs.safe_to_checkout_pr == 'true' && steps.pr.outputs.pr_ref || '') }} - name: Set up Python uses: actions/setup-python@v5 diff --git a/.github/workflows/trivy-scan-dev.yml b/.github/workflows/trivy-scan-dev.yml new file mode 100644 index 000000000000..1f73a65c4216 --- /dev/null +++ b/.github/workflows/trivy-scan-dev.yml @@ -0,0 +1,85 @@ +name: Trivy Scan Dev Docker Images + +on: + # Run daily after nightly dev builds (which run at midnight UTC) + schedule: + - cron: "0 6 * * *" + workflow_dispatch: + inputs: + tag: + description: "Image tag to scan (e.g., dev, dev-cu13, latest)" + required: false + default: "" + +jobs: + scan: + if: github.repository == 'sgl-project/sglang' + runs-on: x64-docker-build-node + timeout-minutes: 45 + permissions: + contents: read + security-events: write + strategy: + fail-fast: false + matrix: + tag: ${{ inputs.tag && fromJSON(format('["{0}"]', inputs.tag)) || fromJSON('["dev", "dev-cu13"]') }} + steps: + 
- name: Checkout repository + uses: actions/checkout@v4 + + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@v0.35.0 + with: + image-ref: 'docker.io/lmsysorg/sglang:${{ matrix.tag }}' + scanners: 'vuln' + format: 'sarif' + output: 'trivy-results-${{ matrix.tag }}.sarif' + severity: 'CRITICAL,HIGH' + ignore-unfixed: true + skip-dirs: 'usr/local/go,opt/nvidia' + + - name: Upload Trivy scan results to GitHub Security + uses: github/codeql-action/upload-sarif@v4 + if: always() && hashFiles(format('trivy-results-{0}.sarif', matrix.tag)) != '' + with: + sarif_file: 'trivy-results-${{ matrix.tag }}.sarif' + category: 'trivy-${{ matrix.tag }}' + + - name: Run Trivy (table output for logs) + if: success() + uses: aquasecurity/trivy-action@v0.35.0 + with: + image-ref: 'docker.io/lmsysorg/sglang:${{ matrix.tag }}' + scanners: 'vuln' + format: 'table' + severity: 'CRITICAL,HIGH' + ignore-unfixed: true + skip-dirs: 'usr/local/go,opt/nvidia' + + - name: Scan summary + if: always() + run: | + IMAGE="docker.io/lmsysorg/sglang:${{ matrix.tag }}" + SARIF="trivy-results-${{ matrix.tag }}.sarif" + + echo "## Trivy Scan: \`${{ matrix.tag }}\`" >> "$GITHUB_STEP_SUMMARY" + + if [ ! 
-f "${SARIF}" ]; then + echo "**Status:** Scan failed — no SARIF output produced" >> "$GITHUB_STEP_SUMMARY" + exit 0 + fi + + VULN_COUNT=$(python3 -c " + import json + data = json.load(open('${SARIF}')) + print(sum(len(run.get('results', [])) for run in data.get('runs', []))) + ") + + echo "- **Image**: \`${IMAGE}\`" >> "$GITHUB_STEP_SUMMARY" + echo "- **Findings**: ${VULN_COUNT}" >> "$GITHUB_STEP_SUMMARY" + + if [ "${VULN_COUNT}" = "0" ]; then + echo "- **Result**: No CRITICAL/HIGH unfixed vulnerabilities found" >> "$GITHUB_STEP_SUMMARY" + else + echo "- **Result**: Found ${VULN_COUNT} finding(s) — check the Security tab for details" >> "$GITHUB_STEP_SUMMARY" + fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a14997b9cffd..850ecd409b5c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,6 +81,12 @@ repos: language: system files: ^\.github/CI_PERMISSIONS\.json$ pass_filenames: false + - id: check-workflow-job-names + name: check for duplicate workflow job names + entry: python3 scripts/ci/check_workflow_job_names.py + language: system + files: ^\.github/workflows/.*\.yml$ + pass_filenames: false - repo: https://github.com/lycheeverse/lychee.git rev: lychee-v0.22.0 hooks: diff --git a/3rdparty/amd/wheel/sglang/pyproject.toml b/3rdparty/amd/wheel/sglang/pyproject.toml index a96f7049df9c..bac322ab38a5 100644 --- a/3rdparty/amd/wheel/sglang/pyproject.toml +++ b/3rdparty/amd/wheel/sglang/pyproject.toml @@ -123,7 +123,7 @@ srt_musa = [ "sglang[runtime_common]", "torch", "torch_musa", - "torchada>=0.1.25", + "torchada>=0.1.45", "mthreads-ml-py", "numpy<2.0", ] diff --git a/README.md b/README.md index 523383d94eca..bdb9a5e047dc 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,9 @@ SGLang is currently hosted under the non-profit open-source organization [LMSYS] logo ## Contact Us -For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership 
inquiries, please contact us at sglang@lmsys.org +For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at [sglang@lmsys.org](mailto:sglang@lmsys.org). + +Long-term active SGLang contributors are eligible for coding agent sponsorship, such as Cursor, Claude Code, or OpenAI Codex. Email [sglang@lmsys.org](mailto:sglang@lmsys.org) with your most important commits or pull requests. ## Acknowledgment We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). diff --git a/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py b/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py index 8e7d845f238a..ea124c487bdd 100644 --- a/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py +++ b/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py @@ -51,8 +51,11 @@ def make_inputs( b = torch.randn(B, HV, device=device, dtype=dtype) # prefill params for chunk_kda must keep batch dim = 1 - prefill_g = torch.randn(1, B, HV, K, device=device, dtype=dtype) - prefill_beta = torch.sigmoid(torch.randn(1, B, HV, device=device, dtype=dtype)) + # chunk_kda requires g, beta, v to have the same head count as k (H), + # matching the real KimiLinear model where num_heads == num_kv_heads. 
+ prefill_v = torch.randn(1, B, H, V, device=device, dtype=dtype) + prefill_g = torch.randn(1, B, H, K, device=device, dtype=dtype) + prefill_beta = torch.sigmoid(torch.randn(1, B, H, device=device, dtype=dtype)) cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32) @@ -66,8 +69,11 @@ def make_inputs( b = torch.randn(B, 1, HV, device=device, dtype=dtype) # prefill params for chunk_kda dense path - prefill_g = torch.randn(B, 1, HV, K, device=device, dtype=dtype) - prefill_beta = torch.sigmoid(torch.randn(B, 1, HV, device=device, dtype=dtype)) + # chunk_kda requires g, beta, v to have the same head count as k (H), + # matching the real KimiLinear model where num_heads == num_kv_heads. + prefill_v = torch.randn(B, 1, H, V, device=device, dtype=dtype) + prefill_g = torch.randn(B, 1, H, K, device=device, dtype=dtype) + prefill_beta = torch.sigmoid(torch.randn(B, 1, H, device=device, dtype=dtype)) cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32) else: @@ -94,6 +100,7 @@ def make_inputs( v=v, a=a, b=b, + prefill_v=prefill_v, prefill_g=prefill_g, prefill_beta=prefill_beta, A_log=A_log, @@ -147,12 +154,13 @@ def run_cutedsl(inp): def run_prefill_then_decode_baseline(inp): ssm_states = inp["ssm_states"].clone() + prefill_v_clone = inp["prefill_v"].clone() v_clone = inp["v"].clone() _ = chunk_kda( q=inp["q"], k=inp["k"], - v=v_clone, + v=prefill_v_clone, g=inp["prefill_g"], beta=inp["prefill_beta"], initial_state=ssm_states, @@ -182,12 +190,13 @@ def run_prefill_then_decode_baseline(inp): def run_prefill_then_decode_cutedsl(inp): ssm_states = inp["ssm_states"].clone() + prefill_v_clone = inp["prefill_v"].clone() v_clone = inp["v"].clone() _ = chunk_kda( q=inp["q"], k=inp["k"], - v=v_clone, + v=prefill_v_clone, g=inp["prefill_g"], beta=inp["prefill_beta"], initial_state=ssm_states, diff --git a/benchmark/hicache/bench_mix.py b/benchmark/hicache/bench_mix.py index 833dbf780add..2a65574ea882 100644 --- a/benchmark/hicache/bench_mix.py +++ 
b/benchmark/hicache/bench_mix.py @@ -426,11 +426,13 @@ async def handle_request(self, user_data): def request_sender(self): async def request_loop(): + tasks = [] while True: if self.sent_requests - self.completed_requests < self.max_parallel: new_request = self.user_generator.pop() if new_request: - asyncio.create_task(self.handle_request(new_request)) + task = asyncio.create_task(self.handle_request(new_request)) + tasks.append(task) self.sent_requests += 1 else: await asyncio.sleep(0.05) @@ -440,6 +442,11 @@ async def request_loop(): self.done = True break + # Cancel all pending tasks and wait for them to finish + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(request_loop()) diff --git a/benchmark/kernels/bench_fused_temperature_softmax.py b/benchmark/kernels/bench_fused_temperature_softmax.py new file mode 100644 index 000000000000..fc624b721ecf --- /dev/null +++ b/benchmark/kernels/bench_fused_temperature_softmax.py @@ -0,0 +1,108 @@ +"""Benchmark: fused_temperature_softmax vs separate div_ + softmax vs flashinfer.sampling.softmax. + +Each path clones logits every iteration so timing is not skewed by in-place reuse. +Uses torch.cuda.Event timing; default 50 warmup, 200 timed iterations. + +Columns tri/base and fi/base are speedup vs PyTorch baseline; tri/fi is t_flashinfer/t_triton +(>1 means Triton is faster). 
+""" + +import argparse + +import torch + + +def benchmark_fn(fn, warmup=50, iters=200): + """Time a zero-arg callable using CUDA events.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / iters * 1000 # microseconds + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument("--iters", type=int, default=200) + args = parser.parse_args() + + from flashinfer.sampling import softmax as flashinfer_softmax + + from sglang.srt.layers.fused_sampling import ( + fused_temperature_softmax, + fused_temperature_softmax_inplace, + ) + + configs = [ + # (batch_size, vocab_size, dtype) + (1, 32000, torch.bfloat16), + (1, 128256, torch.bfloat16), + (32, 32000, torch.bfloat16), + (32, 128256, torch.bfloat16), + (128, 32000, torch.bfloat16), + (128, 128256, torch.bfloat16), + (512, 32000, torch.bfloat16), + (512, 128256, torch.bfloat16), + ] + + header = ( + f"{'bs':>5} {'vocab':>7} {'dtype':>8} " + f"{'baseline (us)':>14} {'triton (us)':>12} {'inplace (us)':>13} {'flashinfer (us)':>16} " + f"{'tri/base':>9} {'fi/base':>8} {'tri/fi':>7}" + ) + print(header) + print("-" * len(header)) + + for bs, vocab, dtype in configs: + temps = torch.rand(bs, 1, dtype=torch.float32, device="cuda") * 1.5 + 0.1 + temps_1d = temps.view(-1) + logits_src = torch.randn(bs, vocab, dtype=dtype, device="cuda") + + # --- Baseline: div_ + softmax --- + def run_baseline(src=logits_src, t=temps): + l = src.clone() + l.div_(t) + l[:] = torch.softmax(l, dim=-1) + + t_base = benchmark_fn(run_baseline, args.warmup, args.iters) + + # --- Triton fused (out-of-place) --- + def run_triton(src=logits_src, t=temps): + fused_temperature_softmax(src.clone(), t) + + t_triton = benchmark_fn(run_triton, args.warmup, 
args.iters) + + # --- Triton fused (in-place) --- + def run_inplace(src=logits_src, t=temps): + l = src.clone() + fused_temperature_softmax_inplace(l, t) + + t_ip = benchmark_fn(run_inplace, args.warmup, args.iters) + + # --- FlashInfer (clone each iter, same as other paths) --- + def run_flashinfer(src=logits_src, t=temps_1d): + l = src.clone() + flashinfer_softmax(l, temperature=t) + + t_fi = benchmark_fn(run_flashinfer, args.warmup, args.iters) + + sp_triton = t_base / t_triton + sp_fi = t_base / t_fi + tri_vs_fi = t_fi / t_triton + print( + f"{bs:>5} {vocab:>7} {str(dtype):>8} " + f"{t_base:>14.1f} {t_triton:>12.1f} {t_ip:>13.1f} {t_fi:>16.1f} " + f"{sp_triton:>8.2f}x {sp_fi:>7.2f}x {tri_vs_fi:>6.2f}x" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py index b7d5b4cc6448..37a9607b6014 100644 --- a/benchmark/kernels/fused_moe_triton/common_utils.py +++ b/benchmark/kernels/fused_moe_triton/common_utils.py @@ -37,11 +37,7 @@ def get_model_config( topk_ids_dir: str = None, ) -> Dict: config = get_config(model_name, trust_remote_code=True) - - # Replace config with text_config for encoder-decoder models after getting block_shape and architecture - if hasattr(config, "text_config"): - config = config.get_text_config() - + architecture = config.architectures[0] block_shape = None if ( hasattr(config, "quantization_config") @@ -61,8 +57,9 @@ def get_model_config( group_size = weights_config.get("group_size") block_shape = [0, group_size] assert len(block_shape) == 2 - - architecture = config.architectures[0] + # Replace config with text_config for encoder-decoder models after getting block_shape and architecture + if hasattr(config, "text_config"): + config = config.get_text_config() hidden_size = config.hidden_size if architecture == "DbrxForCausalLM": diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py 
b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 34aa83b38fd2..4cc397f65ed8 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -32,9 +32,10 @@ ServerArgs, set_global_server_args_for_scheduler, ) -from sglang.srt.utils import is_hip +from sglang.srt.utils import get_device, is_hip, is_xpu _is_hip = is_hip() +_is_xpu = is_xpu() def benchmark_config( @@ -236,8 +237,8 @@ def run(): class BenchmarkWorker: def __init__(self, seed: int, server_args: ServerArgs) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + torch.set_default_device(get_device()) + torch.get_device_module().manual_seed_all(0) self.seed = seed # Get the device ID to allocate tensors and kernels # on the respective GPU. @@ -330,7 +331,11 @@ def tune( ) -> Dict[str, int]: best_config = None best_time = float("inf") - with torch.cuda.device(self.device_id) if is_hip() else nullcontext(): + with ( + torch.get_device_module().device(self.device_id) + if _is_xpu or _is_hip + else nullcontext() + ): for config in tqdm(search_space): try: kernel_time = benchmark_config( diff --git a/benchmark/kernels/quantization/tuning_block_wise_kernel.py b/benchmark/kernels/quantization/tuning_block_wise_kernel.py index 396b14a75a9e..9e4368043524 100644 --- a/benchmark/kernels/quantization/tuning_block_wise_kernel.py +++ b/benchmark/kernels/quantization/tuning_block_wise_kernel.py @@ -31,7 +31,13 @@ _w8a8_block_fp8_matmul_unrolledx4, ) from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul -from sglang.srt.utils import get_device_core_count, get_device_name, is_hip +from sglang.srt.utils import ( + get_device, + get_device_core_count, + get_device_count, + get_device_name, + is_hip, +) _is_hip = is_hip() @@ -221,18 +227,18 @@ def benchmark_config( def run(): w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) - torch.cuda.synchronize() + 
torch.get_device_module().synchronize() # JIT complication & warmup for _ in range(5): run() - torch.cuda.synchronize() + torch.get_device_module().synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + start_event = torch.get_device_module().Event(enable_timing=True) + end_event = torch.get_device_module().Event(enable_timing=True) latencies: List[float] = [] for i in range(num_iters): - torch.cuda.synchronize() + torch.get_device_module().synchronize() start_event.record() run() end_event.record() @@ -244,6 +250,7 @@ def run(): def tune(M, N, K, block_size, out_dtype, search_space, input_type): factor_for_scale = 1e-2 + device = get_device() if input_type == "fp8": fp8_info = torch.finfo( @@ -252,14 +259,14 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): fp8_max, fp8_min = fp8_info.max, fp8_info.min A_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max ) A = A_fp32.clamp(min=fp8_min, max=fp8_max).to( torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn ) B_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max ) B = B_fp32.clamp(min=fp8_min, max=fp8_max).to( torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn @@ -269,12 +276,12 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): int8_max, int8_min = int8_info.max, int8_info.min A_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max + (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * int8_max ) A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) B_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max + (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * int8_max ) B = 
B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) @@ -282,9 +289,9 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale + As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale Bs = ( - torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device) * factor_for_scale ) @@ -351,11 +358,6 @@ def save_configs( lock.release() -def get_available_gpu_count(): - """Get the number of available GPUs.""" - return torch.cuda.device_count() - - def tune_on_gpu(args_dict): """Run tuning on a specific GPU.""" gpu_id = args_dict["gpu_id"] @@ -364,7 +366,7 @@ def tune_on_gpu(args_dict): args = args_dict["args"] lock = args_dict["lock"] - torch.cuda.set_device(gpu_id) + torch.get_device_module().set_device(gpu_id) print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") block_n = args.block_n @@ -415,12 +417,12 @@ def distribute_batch_sizes(batch_sizes, num_gpus): def main(args): print(args) - num_gpus = get_available_gpu_count() + num_gpus = get_device_count() if num_gpus == 0: raise RuntimeError("No GPU available for tuning") print(f"Found {num_gpus} GPUs for parallel tuning") - torch.cuda.init() + torch.get_device_module().init() if args.batch_size is None: batch_sizes = [ diff --git a/docker/Dockerfile b/docker/Dockerfile index 2c5ec107ab82..4cc092f1e97c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -19,7 +19,7 @@ ARG PIP_DEFAULT_INDEX ARG UBUNTU_MIRROR ARG GITHUB_ARTIFACTORY=github.com ARG INSTALL_FLASHINFER_JIT_CACHE=0 -ARG FLASHINFER_VERSION=0.6.6 +ARG FLASHINFER_VERSION=0.6.7 ARG MOONCAKE_VERSION=0.3.9 #if need other arg please add in MOONCAKE_COMPILE_ARG ARG MOONCAKE_COMPILE_ARG="-DUSE_HTTP=ON -DUSE_MNNVL=ON -DUSE_CUDA=ON -DWITH_EP=ON" @@ -114,6 +114,7 @@ RUN 
--mount=type=cache,target=/var/cache/apt,id=base-apt \ libczmq4 \ libczmq-dev \ libfabric-dev \ + linux-libc-dev \ # Package building tools devscripts \ debhelper \ @@ -306,9 +307,12 @@ RUN --mount=type=cache,target=/root/.cache/pip \ cubloaty \ google-cloud-storage -# Build and install sgl-model-gateway (install Rust, build, then remove to save space) +# Build and install sgl-model-gateway (install Rust, build, then remove Rust toolchain) +# Cleanup runs unconditionally via trap to ensure Rust artifacts don't bloat the layer RUN --mount=type=cache,target=/root/.cache/pip \ - curl --proto '=https' --tlsv1.2 --retry 3 --retry-delay 2 -sSf https://sh.rustup.rs | sh -s -- -y \ + cleanup() { rm -rf /root/.cargo /root/.rustup /sgl-workspace/sglang/sgl-model-gateway/target /sgl-workspace/sglang/sgl-model-gateway/bindings/python/target /sgl-workspace/sglang/sgl-model-gateway/bindings/python/dist; sed -i '/\.cargo\/env/d' /root/.profile /root/.bashrc 2>/dev/null; } \ + && trap cleanup EXIT \ + && curl --proto '=https' --tlsv1.2 --retry 3 --retry-delay 2 -sSf https://sh.rustup.rs | sh -s -- -y \ && export PATH="/root/.cargo/bin:${PATH}" \ && rustc --version && cargo --version \ && python3 -m pip install maturin \ @@ -316,10 +320,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \ && ulimit -n 65536 && maturin build --release --features vendored-openssl --out dist \ && python3 -m pip install --force-reinstall dist/*.whl \ && cd /sgl-workspace/sglang/sgl-model-gateway \ - && cargo build --release --bin sglang-router --features vendored-openssl \ - && cp target/release/sglang-router /usr/local/bin/sglang-router \ - && rm -rf /root/.cargo /root/.rustup target dist ~/.cargo \ - && sed -i '/\.cargo\/env/d' /root/.profile /root/.bashrc 2>/dev/null || true + && cargo build --release --bin sgl-model-gateway --features vendored-openssl \ + && cp target/release/sgl-model-gateway /usr/local/bin/sgl-model-gateway RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install 
"nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --force-reinstall --no-deps; @@ -448,7 +450,7 @@ RUN if [ "${CUDA_VERSION%%.*}" = "13" ] && [ -d /usr/local/lib/python3.12/dist-p ln -s /usr/local/cuda/bin/ptxas /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas; \ fi -RUN python3 -m pip install --upgrade "urllib3>=2.6.3" +RUN python3 -m pip install --upgrade "urllib3>=2.6.3" "pillow>=12.1.1" # Set workspace directory WORKDIR /sgl-workspace/sglang @@ -532,6 +534,7 @@ RUN --mount=type=cache,target=/var/cache/apt,id=runtime-apt \ libnccl-dev \ # GPG key verification gnupg2 \ + linux-libc-dev \ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 2 \ && update-alternatives --set python3 /usr/bin/python3.12 \ && ln -sf /usr/bin/python3.12 /usr/bin/python \ diff --git a/docker/rocm.Dockerfile b/docker/rocm.Dockerfile index ce2ae4de9dbb..d53f657f08fe 100644 --- a/docker/rocm.Dockerfile +++ b/docker/rocm.Dockerfile @@ -88,7 +88,7 @@ ARG MOONCAKE_REPO="https://github.com/kvcache-ai/Mooncake.git" ARG MOONCAKE_COMMIT="b6a841dc78c707ec655a563453277d969fb8f38d" ARG TILELANG_REPO="https://github.com/tile-ai/tilelang.git" -ARG TILELANG_COMMIT="ebf4a7cb8881432165ae8760e99d209d905c704a" +ARG TILELANG_COMMIT="a55a82302bf7f3c5af635b5c9146f728185cc900" ARG FHT_REPO="https://github.com/jeffdaily/fast-hadamard-transform.git" ARG FHT_BRANCH="rocm" @@ -98,7 +98,7 @@ ARG ENABLE_MORI=0 ARG NIC_BACKEND=none ARG MORI_REPO="https://github.com/ROCm/mori.git" -ARG MORI_COMMIT="2f88d06aba75400262ca5c1ca5986cf1fdf4cd82" +ARG MORI_COMMIT="v0.1.0" # AMD AINIC apt repo settings ARG AINIC_VERSION=1.117.5 diff --git a/docs/advanced_features/object_storage.md b/docs/advanced_features/object_storage.md new file mode 100644 index 000000000000..957ecdbafe31 --- /dev/null +++ b/docs/advanced_features/object_storage.md @@ -0,0 +1,108 @@ +# Loading Models from Object Storage + +SGLang supports direct loading of models from object storage (S3 and 
Google Cloud Storage) without requiring a full local download. This feature uses the `runai_streamer` load format to stream model weights directly from cloud storage, significantly reducing startup time and local storage requirements. + +## Overview + +When loading models from object storage, SGLang uses a two-phase approach: + +1. **Metadata Download** (once, before process launch): Configuration files and tokenizer files are downloaded to a local cache +2. **Weight Streaming** (lazy, during model loading): Model weights are streamed directly from object storage as needed + +## Supported Storage Backends + +1. **Amazon S3**: `s3://bucket-name/path/to/model/` +2. **Google Cloud Storage**: `gs://bucket-name/path/to/model/` +3. **Azure Blob**: `az://some-azure-container/path/` +4. **S3 compatible**: `s3://bucket-name/path/to/model/` + +## Quick Start + +### Basic Usage + +Simply provide an object storage URI as the model path: + +```bash +# S3 +python -m sglang.launch_server \ + --model-path s3://my-bucket/models/llama-3-8b/ \ + --load-format runai_streamer + +# Google Cloud Storage +python -m sglang.launch_server \ + --model-path gs://my-bucket/models/llama-3-8b/ \ + --load-format runai_streamer +``` + +**Note**: The `--load-format runai_streamer` is automatically detected when using object storage URIs, so you can omit it: + +```bash +python -m sglang.launch_server \ + --model-path s3://my-bucket/models/llama-3-8b/ +``` + +### With Tensor Parallelism + +```bash +python -m sglang.launch_server \ + --model-path gs://my-bucket/models/llama-70b/ \ + --tp 4 \ + --model-loader-extra-config '{"distributed": true}' +``` + +## Configuration + +### Load Format + +The `runai_streamer` load format is specifically designed for object storage, ssd and shared file systems + +```bash +python -m sglang.launch_server \ + --model-path s3://bucket/model/ \ + --load-format runai_streamer +``` + +### Extended Configuration Parameters + +Use `--model-loader-extra-config` to pass 
additional configuration as a JSON string:
+
+```bash
+python -m sglang.launch_server \
+  --model-path s3://bucket/model/ \
+  --model-loader-extra-config '{
+    "distributed": true,
+    "concurrency": 8,
+    "memory_limit": 2147483648
+  }'
+```
+
+#### Available Parameters
+
+| Parameter | Type | Description | Default |
+|-----------|------|-------------|---------|
+| `distributed` | bool | Enable distributed streaming for multi-GPU setups. Automatically set to `true` for object storage paths and CUDA-like devices. | Auto-detected |
+| `concurrency` | int | Number of concurrent download streams. Higher values can improve throughput for large models. | 4 |
+| `memory_limit` | int | Memory limit (in bytes) for the streaming buffer. | System-dependent |
+
+
+## Performance Considerations
+
+### Distributed Streaming
+
+For multi-GPU setups, enable distributed streaming to parallelize weight loading across the processes:
+
+```bash
+python -m sglang.launch_server \
+  --model-path s3://bucket/model/ \
+  --tp 8 \
+  --model-loader-extra-config '{"distributed": true}'
+```
+
+## Limitations
+
+- **Supported Formats**: Currently only supports the `.safetensors` weight format (recommended format)
+- **Supported Device**: Distributed streaming is supported on CUDA-like devices; otherwise it falls back to non-distributed streaming
+
+## See Also
+
+- [Runai model streamer documentation](https://github.com/run-ai/runai-model-streamer)
diff --git a/docs/advanced_features/quantization.md b/docs/advanced_features/quantization.md
index 8a30d5084660..5c816953bc2e 100644
--- a/docs/advanced_features/quantization.md
+++ b/docs/advanced_features/quantization.md
@@ -19,32 +19,35 @@ to guard against abnormal quantization loss regressions.
 
 ## Platform Compatibility
 
-The following table summarizes quantization method support across NVIDIA and AMD GPUs.
- -| Method | NVIDIA GPUs | AMD GPUs (MI300X/MI325X/MI350X) | Notes | -|--------|:-----------:|:-------------------------------:|-------| -| `fp8` | Yes | Yes | Aiter or Triton backend on AMD | -| `mxfp4` | Yes | Yes | Requires CDNA3/CDNA4 with MXFP support; uses Aiter | -| `blockwise_int8` | Yes | Yes | Triton-based, works on both platforms | -| `w8a8_int8` | Yes | Yes | | -| `w8a8_fp8` | Yes | Yes | Aiter or Triton FP8 on AMD | -| `awq` | Yes | Yes | Uses Triton dequantize on AMD (vs. optimized CUDA kernels on NVIDIA) | -| `gptq` | Yes | Yes | Uses Triton or vLLM kernels on AMD | -| `compressed-tensors` | Yes | Yes | Aiter paths for FP8/MoE on AMD | -| `quark` | Yes | Yes | AMD Quark quantization; Aiter GEMM paths on AMD | -| `auto-round` | Yes | Yes | Platform-agnostic (Intel auto-round) | -| `quark_int4fp8_moe` | No | Yes | AMD-only; online INT4-to-FP8 MoE quantization (CDNA3/CDNA4) | -| `awq_marlin` | Yes | No | Marlin kernels are CUDA-only | -| `gptq_marlin` | Yes | No | Marlin kernels are CUDA-only | -| `gguf` | Yes | No | CUDA-only kernels in sgl-kernel | -| `modelopt` / `modelopt_fp8` | Yes (Hopper/SM90+) | No | [NVIDIA ModelOpt](https://github.com/NVIDIA/Model-Optimizer); requires NVIDIA hardware | -| `modelopt_fp4` | Yes (Blackwell/SM100+) | No | [NVIDIA ModelOpt](https://github.com/NVIDIA/Model-Optimizer); native FP4 on Blackwell (B200, GB200) | -| `petit_nvfp4` | No | Yes (MI250/MI300X/MI325X) | Enables NVFP4 on ROCm via [Petit](https://github.com/causalflow-ai/petit-kernel); use `modelopt_fp4` on NVIDIA Blackwell. Auto-selected when loading NVFP4 models on AMD. See [LMSYS blog](https://lmsys.org/blog/2025-09-21-petit-amdgpu/) and [AMD ROCm blog](https://rocm.blogs.amd.com/artificial-intelligence/fp4-mixed-precision/README.html). | -| `bitsandbytes` | Yes | Experimental | Depends on bitsandbytes ROCm support | -| `torchao` (`int4wo`, etc.) 
| Yes | Partial | `int4wo` not supported on AMD; other methods may work | +The following table summarizes quantization method support across NVIDIA and AMD GPUs, Ascend NPUs. + +| Method | NVIDIA GPUs | AMD GPUs (MI300X/MI325X/MI350X) | Ascend NPUs (A2/A3) | Notes | +|--------|:-----------:|:-------------------------------:|:-----------------------:|-------| +| `fp8` | Yes | Yes | WIP | Aiter or Triton backend on AMD | +| `mxfp4` | Yes | Yes | WIP | Requires CDNA3/CDNA4 with MXFP support; uses Aiter | +| `blockwise_int8` | Yes | Yes | No | Triton-based, works on both platforms | +| `w8a8_int8` | Yes | Yes | No | | +| `w8a8_fp8` | Yes | Yes | No | Aiter or Triton FP8 on AMD | +| `awq` | Yes | Yes | Yes | Uses Triton dequantize on AMD (vs. optimized CUDA kernels on NVIDIA). Uses CANN kernels on Ascend| +| `gptq` | Yes | Yes | Yes | Uses Triton or vLLM kernels on AMD. Uses CANN kernels on Ascend| +| `compressed-tensors` | Yes | Yes | Partial | Aiter paths for FP8/MoE on AMD. Uses CANN kernels on Ascend, `FP8` not supported yet| +| `quark` | Yes | Yes | No | AMD Quark quantization; Aiter GEMM paths on AMD | +| `auto-round` | Yes | Yes | Partial | Platform-agnostic (Intel auto-round). 
Uses CANN kernels on Ascend| +| `quark_int4fp8_moe` | No | Yes | No | AMD-only; online INT4-to-FP8 MoE quantization (CDNA3/CDNA4) | +| `awq_marlin` | Yes | No | No | Marlin kernels are CUDA-only | +| `gptq_marlin` | Yes | No | No | Marlin kernels are CUDA-only | +| `gguf` | Yes | No | WIP | CUDA-only kernels in sgl-kernel | +| `modelopt` / `modelopt_fp8` | Yes (Hopper/SM90+) | No | No | [NVIDIA ModelOpt](https://github.com/NVIDIA/Model-Optimizer); requires NVIDIA hardware | +| `modelopt_fp4` | Yes (Blackwell/SM100+) | No | No | [NVIDIA ModelOpt](https://github.com/NVIDIA/Model-Optimizer); native FP4 on Blackwell (B200, GB200) | +| `petit_nvfp4` | No | Yes (MI250/MI300X/MI325X) | No | Enables NVFP4 on ROCm via [Petit](https://github.com/causalflow-ai/petit-kernel); use `modelopt_fp4` on NVIDIA Blackwell. Auto-selected when loading NVFP4 models on AMD. See [LMSYS blog](https://lmsys.org/blog/2025-09-21-petit-amdgpu/) and [AMD ROCm blog](https://rocm.blogs.amd.com/artificial-intelligence/fp4-mixed-precision/README.html). | +| `bitsandbytes` | Yes | Experimental | No | Depends on bitsandbytes ROCm support | +| `torchao` (`int4wo`, etc.) | Yes | Partial | No | `int4wo` not supported on AMD; other methods may work | +| `modelslim` | No | No | Yes | Ascend quantization; Uses CANN kernels | On AMD, several of these methods use [Aiter](https://github.com/ROCm/aiter) for acceleration -- set `SGLANG_USE_AITER=1` where noted. See [AMD GPU setup](../platforms/amd_gpu.md) for installation and configuration details. +On Ascend, various layers quantization configurations are supported, see [Ascend NPU quantization](../platforms/ascend/ascend_npu_quantization.md) for details. + ## GEMM Backends for FP4/FP8 Quantization :::{note} @@ -71,17 +74,18 @@ Backend selection is supported only for **blockwise FP8** and **NVFP4** GEMM. 
Wh | Backend | Hardware | Description | |---------|----------|-------------| | `auto` | SM100/120 | Auto-selects: `flashinfer_cudnn` on SM120; `flashinfer_cutlass` on SM100 | +| `cutlass` | SM100/120 | SGLang CUTLASS kernel | | `flashinfer_cutlass` | SM100/120 | FlashInfer CUTLASS backend | | `flashinfer_cudnn` | SM100/120 (CUDA 13+, cuDNN 9.15+) | FlashInfer cuDNN backend; used on SM120 for performance | | `flashinfer_trtllm` | SM100 | FlashInfer TensorRT-LLM backend | -When FlashInfer is unavailable for NVFP4, sgl-kernel CUTLASS is used as an automatic fallback. +When FlashInfer is unavailable for NVFP4, the SGLang CUTLASS kernel is used as an automatic fallback. ## Offline Quantization To load already quantized models, simply load the model weights and config. **Again, if the model has been quantized offline, there's no need to add `--quantization` argument when starting the engine. The quantization method will be parsed from the -downloaded Hugging Face config. For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.** +downloaded Hugging Face or msModelSlim config. For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.** ```bash python3 -m sglang.launch_server \ @@ -319,7 +323,6 @@ For detailed usage and supported model architectures, see [NVIDIA Model Optimize SGLang includes a streamlined workflow for quantizing models with ModelOpt and automatically exporting them for deployment. - ##### Installation First, install ModelOpt: @@ -477,6 +480,74 @@ model_loader.load_model(model_config=model_config, device_config=DeviceConfig()) - **Calibration-based**: Uses calibration datasets for optimal quantization quality - **Production Ready**: Enterprise-grade quantization with NVIDIA support +#### Using [ModelSlim](https://gitcode.com/Ascend/msmodelslim) +MindStudio-ModelSlim (msModelSlim) is a model offline quantization compression tool launched by MindStudio and optimized for Ascend hardware. 
+
+- **Installation**
+
+  ```bash
+  # Clone repo and install msmodelslim:
+  git clone https://gitcode.com/Ascend/msmodelslim.git
+  cd msmodelslim
+  bash install.sh
+  ```
+
+- **LLM quantization**
+
+  Download the original floating-point weights of the large model. Taking Qwen3-32B as an example, you can go to [Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) to obtain the original model weights. Then install other dependencies (related to the model, refer to the huggingface model card).
+  > Note: You can find pre-quantized validated models on [modelscope/Eco-Tech](https://modelscope.cn/models/Eco-Tech).
+
+  _Traditional quantization methods require the preparation of calibration data files (```.jsonl``` format) for calibration in the quantization process._
+  ```bash
+  Qwen3-32B/ # floating-point model downloaded from official HF (or modelscope) repo
+  msmodelslim/ # msmodelslim repo
+  |----- lab_calib # calibration data folder (put your dataset here in ```.jsonl``` format or use pre-prepared ones)
+  |----- some file (such as laos_calib.jsonl)
+  |----- lab_practice # best practice folder with configs for quantization
+  |----- model folder (such as qwen3_5_moe folder) # folder with quantization configs
+  |----- quant_config (such as qwen3_5_moe_w8a8.yaml) # quantization config
+  |----- other folders
+  output_folder/ # generated by below command
+  |----- quant_model_weights-00001-of-0001.safetensors # quantized weights
+  |----- quant_model_description.json # file with description of the quantization methods for each layer (```W4A4_DYNAMIC```, etc.)
+  |----- other files (such as config.json, tokenizer.json, etc.)
+  ```
+  Run quantization using one-click quantization (recommended):
+  ```bash
+  msmodelslim quant \
+    --model_path ${MODEL_PATH} \
+    --save_path ${SAVE_PATH} \
+    --device npu:0,1 \
+    --model_type Qwen3-32B \
+    --quant_type w8a8 \
+    --trust_remote_code True
+  ```
+
+- **Usage Example**
+  ```bash
+  python3 -m sglang.launch_server \
+    --model-path $PWD/Qwen3-32B-w8a8 \
+    --port 30000 --host 0.0.0.0
+  ```
+
+- **Available Quantization Methods**:
+  - [x] ```W4A4_DYNAMIC``` linear with online quantization of activations
+  - [x] ```W8A8``` linear with offline quantization of activations
+  - [x] ```W8A8_DYNAMIC``` linear with online quantization of activations
+  - [x] ```W4A4_DYNAMIC``` MoE with online quantization of activations
+  - [x] ```W4A8_DYNAMIC``` MoE with online quantization of activations
+  - [x] ```W8A8_DYNAMIC``` MoE with online quantization of activations
+  - [ ] ```W4A8``` linear TBD
+  - [ ] ```W4A16``` linear TBD
+  - [ ] ```W48A16``` linear TBD
+  - [ ] ```W4A16``` MoE in progress
+  - [ ] ```W8A16``` MoE in progress
+  - [ ] ```KV Cache``` in progress
+  - [ ] ```Attention``` in progress
+
+
+For more detailed examples of model quantization, as well as information about method support, see the [examples](https://gitcode.com/Ascend/msmodelslim/blob/master/example/README.md) section in the msModelSlim repo.
+
 ## Online Quantization
 
 To enable online quantization, you can simply specify `--quantization` in the command line. For example, you can launch the server with the following command to enable `FP8` quantization for model `meta-llama/Meta-Llama-3.1-8B-Instruct`:
@@ -529,3 +600,4 @@ Other layers (e.g.
projections in the attention layers) have their weights quant - [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao) - [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/) - [auto-round](https://github.com/intel/auto-round) +- [ModelSlim](https://gitcode.com/Ascend/msmodelslim) diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 61cfe91e07c6..e36b49a54809 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -84,7 +84,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--tokenizer-mode` | Tokenizer mode. 'auto' will use the fast tokenizer if available, and 'slow' will always use the slow tokenizer. | `auto` | `auto`, `slow` | | `--tokenizer-worker-num` | The worker num of the tokenizer manager. | `1` | Type: int | | `--skip-tokenizer-init` | If set, skip init tokenizer and pass input_ids in generate request. | `False` | bool flag (set to enable) | -| `--load-format` | The format of the model weights to load. "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. "pt" will load the weights in the pytorch bin format. "safetensors" will load the weights in the safetensors format. "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. "dummy" will initialize the weights with random values, which is mainly for profiling."gguf" will load the weights in the gguf format. "bitsandbytes" will load the weights using bitsandbytes quantization."layered" loads weights layer by layer so that one can quantize a layer before loading another to make the peak memory envelope smaller. "flash_rl" will load the weights in flash_rl format. "fastsafetensors" and "private" are also supported. 
| `auto` | `auto`, `pt`, `safetensors`, `npcache`, `dummy`, `sharded_state`, `gguf`, `bitsandbytes`, `layered`, `flash_rl`, `remote`, `remote_instance`, `fastsafetensors`, `private` | +| `--load-format` | The format of the model weights to load. "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. "pt" will load the weights in the pytorch bin format. "safetensors" will load the weights in the safetensors format. "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. "dummy" will initialize the weights with random values, which is mainly for profiling."gguf" will load the weights in the gguf format. "bitsandbytes" will load the weights using bitsandbytes quantization."layered" loads weights layer by layer so that one can quantize a layer before loading another to make the peak memory envelope smaller. "flash_rl" will load the weights in flash_rl format. "fastsafetensors" and "private" are also supported. "runai_streamer" enables direct model loading from object storage and shared file systems.| `auto` | `auto`, `pt`, `safetensors`, `npcache`, `dummy`, `sharded_state`, `gguf`, `bitsandbytes`, `layered`, `flash_rl`, `remote`, `remote_instance`, `fastsafetensors`, `private`, `runai_streamer` | | `--model-loader-extra-config` | Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format. | `{}` | Type: str | | `--trust-remote-code` | Whether or not to allow for custom models defined on the Hub in their own modeling files. | `False` | bool flag (set to enable) | | `--context-length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). 
| `None` | Type: int | @@ -185,6 +185,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--crash-dump-folder` | Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled. | `None` | Type: str | | `--show-time-cost` | Show time cost of custom marks. | `False` | bool flag (set to enable) | | `--enable-metrics` | Enable log prometheus metrics. | `False` | bool flag (set to enable) | +| `--enable-mfu-metrics` | Enable estimated MFU-related prometheus metrics. | `False` | bool flag (set to enable) | | `--enable-metrics-for-all-schedulers` | Enable --enable-metrics-for-all-schedulers when you want schedulers on all TP ranks (not just TP 0) to record request metrics separately. This is especially useful when dp_attention is enabled, as otherwise all metrics appear to come from TP 0. | `False` | bool flag (set to enable) | | `--tokenizer-metrics-custom-labels-header` | Specify the HTTP header for passing custom labels for tokenizer metrics. | `x-custom-labels` | Type: str | | `--tokenizer-metrics-allowed-custom-labels` | The custom labels allowed for tokenizer metrics. The labels are specified via a dict in '--tokenizer-metrics-custom-labels-header' field in HTTP requests, e.g., {'label1': 'value1', 'label2': 'value2'} is allowed if '--tokenizer-metrics-allowed-custom-labels label1 label2' is set. | `None` | List[str] | @@ -211,7 +212,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Argument | Description | Defaults | Options | | --- | --- | --- | --- | | `--api-key` | Set API key of the server. It is also used in the OpenAI API compatible server. | `None` | Type: str | -| `--admin-api-key` | Set **admin API key** for administrative/control endpoints (e.g., weights update, cache flush, `/get_server_info`). Endpoints marked as admin-only require `Authorization: Bearer ` when this is set. 
| `None` | Type: str | +| `--admin-api-key` | Set **admin API key** for administrative/control endpoints (e.g., weights update, cache flush, `/server_info`). Endpoints marked as admin-only require `Authorization: Bearer ` when this is set. | `None` | Type: str | | `--served-model-name` | Override the model name returned by the v1/models endpoint in OpenAI API server. | `None` | Type: str | | `--weight-version` | Version identifier for the model weights. Defaults to 'default' if not specified. | `default` | Type: str | | `--chat-template` | The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server. | `None` | Type: str | @@ -268,8 +269,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--mm-attention-backend` | Set multimodal attention backend. | `None` | `sdpa`, `fa3`, `fa4`, `triton_attn`, `ascend_attn`, `aiter_attn` | | `--nsa-prefill-backend` | Choose the NSA backend for the prefill stage (overrides `--attention-backend` when running DeepSeek NSA-style attention). | `flashmla_sparse` | `flashmla_sparse`, `flashmla_kv`, `flashmla_auto`, `fa3`, `tilelang`, `aiter`, `trtllm` | | `--nsa-decode-backend` | Choose the NSA backend for the decode stage when running DeepSeek NSA-style attention. Overrides `--attention-backend` for decoding. | `fa3` | `flashmla_sparse`, `flashmla_kv`, `fa3`, `tilelang`, `aiter`, `trtllm` | -| `--fp8-gemm-backend` | Choose the runner backend for Blockwise FP8 GEMM operations. 
Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (FlashInfer TRTLLM backend; SM100/SM103 only), 'flashinfer_cutlass' (FlashInfer CUTLASS backend, SM120 only), 'flashinfer_deepgemm' (Hopper SM90 only, uses swapAB optimization for small M dimensions in decoding), 'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only). **NOTE**: This replaces the deprecated environment variables SGLANG_ENABLE_FLASHINFER_FP8_GEMM and SGLANG_SUPPORT_CUTLASS_BLOCK_FP8. | `auto` | `auto`, `deep_gemm`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_deepgemm`, `cutlass`, `triton`, `aiter` | -| `--fp4-gemm-backend` | Choose the runner backend for NVFP4 GEMM operations. Options: 'flashinfer_cutlass' (default), 'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), 'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), 'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). All backends are from FlashInfer; when FlashInfer is unavailable, sgl-kernel CUTLASS is used as an automatic fallback. **NOTE**: This replaces the deprecated environment variable SGLANG_FLASHINFER_FP4_GEMM_BACKEND. | `flashinfer_cutlass` | `auto`, `flashinfer_cudnn`, `flashinfer_cutlass`, `flashinfer_trtllm` | +| `--fp8-gemm-backend` | Choose the runner backend for Blockwise FP8 GEMM operations. 
Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (FlashInfer TRTLLM backend; SM100/SM103 only), 'flashinfer_cutlass' (FlashInfer CUTLASS backend, SM120 only), 'flashinfer_deepgemm' (Hopper SM90 only, uses swapAB optimization for small M dimensions in decoding), 'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only).| `auto` | `auto`, `deep_gemm`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_deepgemm`, `cutlass`, `triton`, `aiter` | +| `--fp4-gemm-backend` | Choose the runner backend for NVFP4 GEMM operations. Options: 'flashinfer_cutlass' (default), 'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), 'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), 'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). All backends are from FlashInfer; when FlashInfer is unavailable, sgl-kernel CUTLASS is used as an automatic fallback.| `flashinfer_cutlass` | `auto`, `flashinfer_cudnn`, `flashinfer_cutlass`, `flashinfer_trtllm` | | `--disable-flashinfer-autotune` | Flashinfer autotune is enabled by default. Set this flag to disable the autotune. | `False` | bool flag (set to enable) | ## Speculative decoding @@ -294,12 +295,10 @@ Please consult the documentation below and [server_args.py](https://github.com/s ## Ngram speculative decoding | Argument | Description | Defaults | Options | | --- | --- | --- | --- | -| `--speculative-ngram-min-match-window-size` | The minimum window size for pattern matching in ngram speculative decoding. | `1` | Type: int | -| `--speculative-ngram-max-match-window-size` | The maximum window size for pattern matching in ngram speculative decoding. 
| `12` | Type: int | | `--speculative-ngram-min-bfs-breadth` | The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `1` | Type: int | | `--speculative-ngram-max-bfs-breadth` | The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `10` | Type: int | | `--speculative-ngram-match-type` | Ngram tree-building mode. `BFS` selects recency-based expansion and `PROB` selects frequency-based expansion. This setting is forwarded to the ngram cache implementation. | `BFS` | `BFS`, `PROB` | -| `--speculative-ngram-max-trie-depth` | The max trie depth for ngram speculative decoding. | `18` | Type: int | +| `--speculative-ngram-max-trie-depth` | Maximum suffix length stored and matched by the ngram trie. | `18` | Type: int | | `--speculative-ngram-capacity` | The cache capacity for ngram speculative decoding. | `10000000` | Type: int | ## Multi-layer Eagle speculative decoding diff --git a/docs/advanced_features/sgl_model_gateway.md b/docs/advanced_features/sgl_model_gateway.md index 753743b0b0bb..a718b8a37359 100644 --- a/docs/advanced_features/sgl_model_gateway.md +++ b/docs/advanced_features/sgl_model_gateway.md @@ -77,7 +77,7 @@ SGLang Model Gateway is a high-performance model-routing gateway for large-scale ### Control Plane -- **Worker Manager** discovers capabilities (`/get_server_info`, `/get_model_info`), tracks load, and registers/removes workers in the shared registry. +- **Worker Manager** discovers capabilities (`/server_info`, `/get_model_info`), tracks load, and registers/removes workers in the shared registry. - **Job Queue** serializes add/remove requests and exposes status (`/workers/{worker_id}`) so clients can track onboarding progress. - **Load Monitor** feeds cache-aware and power-of-two policies with live worker load statistics. - **Health Checker** continuously probes workers and updates readiness, circuit breaker state, and router metrics. 
@@ -552,7 +552,7 @@ Response: | `GET` | `/engine_metrics` | Engine-level metrics from workers | | `GET` | `/v1/models` | List available models | | `GET` | `/get_model_info` | Get model information | -| `GET` | `/get_server_info` | Get server information | +| `GET` | `/server_info` | Get server information | | `POST` | `/flush_cache` | Clear all caches | | `GET` | `/get_loads` | Get all worker loads | | `POST` | `/wasm` | Upload WASM module | diff --git a/docs/advanced_features/speculative_decoding.md b/docs/advanced_features/speculative_decoding.md index c573af0724a8..9806f244e2a1 100644 --- a/docs/advanced_features/speculative_decoding.md +++ b/docs/advanced_features/speculative_decoding.md @@ -387,13 +387,11 @@ Enable it with: | Parameter | Description | Default | |---|---|---| -| `--speculative-num-draft-tokens` | Number of draft tokens verified per step. If omitted, defaults to `--speculative-ngram-max-match-window-size`. | `12` (with default ngram settings) | -| `--speculative-ngram-min-match-window-size` | Minimum matching window size. | `1` | -| `--speculative-ngram-max-match-window-size` | Maximum matching window size. | `12` | +| `--speculative-num-draft-tokens` | Number of draft tokens verified per step. | `12` | | `--speculative-ngram-min-bfs-breadth` | Minimum BFS breadth. | `1` | | `--speculative-ngram-max-bfs-breadth` | Maximum BFS breadth. | `10` | | `--speculative-ngram-match-type` | Ngram tree-building mode: `"BFS"` for recency-based expansion or `"PROB"` for frequency-based expansion. | `"BFS"` | -| `--speculative-ngram-max-trie-depth` | The max trie depth for ngram speculative decoding. | `18` | +| `--speculative-ngram-max-trie-depth` | Maximum suffix length stored and matched by the ngram trie. | `18` | | `--speculative-ngram-capacity` | Cache capacity (number of entries). 
| `10,000,000` | Notes: @@ -408,7 +406,6 @@ python3 -m sglang.launch_server \ --model Qwen/Qwen2.5-7B-Instruct \ --speculative-algorithm NGRAM \ --speculative-num-draft-tokens 16 \ - --speculative-ngram-max-match-window-size 12 \ --speculative-ngram-max-bfs-breadth 10 \ --mem-fraction-static 0.7 \ --cuda-graph-max-bs 8 \ @@ -464,12 +461,10 @@ Below is a comprehensive list of all speculative decoding parameters available i | Parameter | Type | Default | Description | |---|---|---|---| -| `--speculative-ngram-min-match-window-size` | `int` | `1` | Minimum ngram matching window | -| `--speculative-ngram-max-match-window-size` | `int` | `12` | Maximum ngram matching window | | `--speculative-ngram-min-bfs-breadth` | `int` | `1` | Minimum BFS breadth | | `--speculative-ngram-max-bfs-breadth` | `int` | `10` | Maximum BFS breadth | | `--speculative-ngram-match-type` | `str` | `"BFS"` | Ngram tree-building mode: `"BFS"` for recency-based expansion or `"PROB"` for frequency-based expansion | -| `--speculative-ngram-max-trie-depth` | `int` | `18` | Max trie depth for ngram speculative decoding | +| `--speculative-ngram-max-trie-depth` | `int` | `18` | Maximum suffix length stored and matched by the ngram trie | | `--speculative-ngram-capacity` | `int` | `10,000,000` | Cache capacity | ### Environment variables diff --git a/docs/basic_usage/deepseek_v3.md b/docs/basic_usage/deepseek_v3.md index b558c2223abf..9770c2882f13 100644 --- a/docs/basic_usage/deepseek_v3.md +++ b/docs/basic_usage/deepseek_v3.md @@ -74,7 +74,7 @@ Detailed commands for reference: - [16 x A100 (INT8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization) - [32 x L40S (INT8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-32-l40s-with-int8-quantization) - [Xeon 6980P CPU](../platforms/cpu_server.md#example-running-deepseek-r1) -- [4 x Atlas 800I A3 
(int8)](../platforms/ascend_npu_deepseek_example.md#running-deepseek-with-pd-disaggregation-on-4-x-atlas-800i-a3) +- [4 x Atlas 800I A3 (int8)](../platforms/ascend/ascend_npu_deepseek_example.md#running-deepseek-with-pd-disaggregation-on-4-x-atlas-800i-a3) ### Download Weights If you encounter errors when starting the server, ensure the weights have finished downloading. It's recommended to download them beforehand or restart multiple times until all weights are downloaded. Please refer to [DeepSeek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base#61-inference-with-deepseek-infer-demo-example-only) official guide to download the weights. diff --git a/docs/basic_usage/native_api.ipynb b/docs/basic_usage/native_api.ipynb index 65ff5809efe3..d3ead5e349d6 100644 --- a/docs/basic_usage/native_api.ipynb +++ b/docs/basic_usage/native_api.ipynb @@ -10,7 +10,7 @@ "\n", "- `/generate` (text generation model)\n", "- `/get_model_info`\n", - "- `/get_server_info`\n", + "- `/server_info`\n", "- `/health`\n", "- `/health_generate`\n", "- `/flush_cache`\n", @@ -140,7 +140,7 @@ "metadata": {}, "outputs": [], "source": [ - "url = f\"http://localhost:{port}/get_server_info\"\n", + "url = f\"http://localhost:{port}/server_info\"\n", "\n", "response = requests.get(url)\n", "print_highlight(response.text)" diff --git a/docs/developer_guide/bench_serving.md b/docs/developer_guide/bench_serving.md index 5a67723c8ab7..fee65a117735 100644 --- a/docs/developer_guide/bench_serving.md +++ b/docs/developer_guide/bench_serving.md @@ -352,4 +352,4 @@ python3 -m sglang.bench_serving \ ### Notes - The script raises the file descriptor soft limit (`RLIMIT_NOFILE`) to help with many concurrent connections. -- For sglang, `/get_server_info` is queried post-run to report speculative decoding accept length when available. +- For sglang, `/server_info` is queried post-run to report speculative decoding accept length when available. 
diff --git a/docs/developer_guide/contribution_guide.md b/docs/developer_guide/contribution_guide.md index 8218dcc87af8..4f23b0f1a3c6 100644 --- a/docs/developer_guide/contribution_guide.md +++ b/docs/developer_guide/contribution_guide.md @@ -31,28 +31,14 @@ pre-commit run --all-files - Link checking with lychee is **enforced in CI**. By default, it is not blocking local commits. - To run local link checks manually, use: `pre-commit run --hook-stage manual lychee --all-files`. -### Link check guidance (lychee) - -- If your PR changes `docs/` or `README.md`, we recommend running local link checks before pushing. -- Local lychee is optional (CI is the source of truth), but if you want a system installation, see the official project: [lycheeverse/lychee](https://github.com/lycheeverse/lychee). -- Recommended local commands: - -```bash -# Fast local/offline check (pre-commit config) -pre-commit run --hook-stage manual lychee --all-files - -# CI-like online check (external links over network) -lychee --config .github/linters/lychee-ci.toml README.md "docs/**/*.md" "docs/**/*.rst" "docs/**/*.ipynb" -``` - ## Run and add unit tests If you add a new feature or fix a bug, please add corresponding unit tests to ensure coverage and prevent regression. -SGLang uses Python's built-in [unittest](https://docs.python.org/3/library/unittest.html) framework with [pytest](https://docs.pytest.org/) as the test runner. ### Unit tests (no server required) Unit tests live under [`test/registered/unit/`](https://github.com/sgl-project/sglang/tree/main/test/registered/unit), organized to mirror the `python/sglang/srt/` source tree. These tests validate component logic **without** launching a server or loading real model weights. +SGLang uses Python's built-in [unittest](https://docs.python.org/3/library/unittest.html) framework with [pytest](https://docs.pytest.org/) as the test runner. 
**When to add a unit test:** If you modify a file under `python/sglang/srt/`, check whether a corresponding test exists in `test/registered/unit/` and add coverage for your changes. For example: @@ -140,7 +126,6 @@ If you don’t have permission and you’re not the PR author, please ask mainta ### CI rate limits Due to CI scheduling and limited resources, higher-priority PRs may preempt running jobs. In such cases, you may need to rerun the tests. - We apply CI rate limits to prevent abuse and ensure fair usage of our CI resources. Each CI workflow has a default limit defined in its workflow configuration file. For example, in [pr-gate.yml](https://github.com/sgl-project/sglang/blob/main/.github/workflows/pr-gate.yml), the default cooldown period is 120 minutes, and each workflow can override it via the `cool-down-minutes` input parameter: @@ -154,7 +139,6 @@ cool-down-minutes: Users listed in [CI_PERMISSIONS.json](https://github.com/sgl-project/sglang/blob/main/.github/CI_PERMISSIONS.json) may have a per-user cooldown interval. In practice, we use the minimum of the workflow’s default window and the user-specific interval. - ## Code style guidance - Avoid code duplication. If the same code snippet (more than five lines) appears multiple times, extract it into a shared function. - Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code. @@ -188,7 +172,12 @@ Follow these steps: ## Tips for newcomers -If you want to contribute but don’t have a specific idea in mind, pick issues labeled [ā€œgood first issueā€ or ā€œhelp wantedā€](https://github.com/sgl-project/sglang/issues?q=is%3Aissue+label%3A%22good+first+issue%22%2C%22help+wanted%22). These tasks typically have lower complexity and provide an excellent introduction to the codebase. 
Also check out this [code walk-through](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/sglang/code-walk-through) for a deeper look into SGLang’s workflow. +If you want to contribute but don’t have a specific idea in mind, pick issues labeled [ā€œgood first issueā€ or ā€œhelp wantedā€](https://github.com/sgl-project/sglang/issues?q=is%3Aissue+label%3A%22good+first+issue%22%2C%22help+wanted%22). These tasks typically have lower complexity and provide an excellent introduction to the codebase. + +Also check out the following materials as startup guide: +- [Mini-SGLang](https://github.com/sgl-project/mini-sglang) for a quick overview on the structure of sglang. +- [Code Walk-through](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/sglang/code-walk-through) for a deeper look into SGLang’s workflow. +- [GTC-2026 Training Lab](https://drive.google.com/file/d/1mwOZEtipNLJzrflCTodj34KhuOZEoEw5/view?usp=drive_link) for hands-on practices of how to do optimization, benchmarking, or profiling on a launched SGLang instance. If you have any questions or want to start a discussion, please feel free to ask in our [Slack channel](https://slack.sglang.io). diff --git a/docs/diffusion/api/cli.md b/docs/diffusion/api/cli.md index 5ab0f00ce6b1..a4caaca7d4c5 100644 --- a/docs/diffusion/api/cli.md +++ b/docs/diffusion/api/cli.md @@ -2,6 +2,43 @@ Use the CLI for one-off generation with `sglang generate` or to start a persistent HTTP server with `sglang serve`. +### Overlay repos for non-diffusers models + +If `--model-path` points to a supported non-diffusers source repo, SGLang can resolve it +through a self-hosted overlay repo. + +SGLang first checks a built-in overlay registry. Concrete built-in mappings can be added over time without changing the CLI surface. 
+ +Override example: + +```bash +export SGLANG_DIFFUSION_MODEL_OVERLAY_REGISTRY='{ + "Wan-AI/Wan2.2-S2V-14B": { + "overlay_repo_id": "your-org/Wan2.2-S2V-14B-overlay", + "overlay_revision": "main" + } +}' + +sglang generate \ + --model-path Wan-AI/Wan2.2-S2V-14B \ + --config configs/wan_s2v.yaml +``` + +The overlay repo should be a complete diffusers-style/componentized repo + +You can also pass the overlay repo itself as `--model-path` if it contains `_overlay/overlay_manifest.json`. + +Notes: +1. `SGLANG_DIFFUSION_MODEL_OVERLAY_REGISTRY` is only an optional override for +development and debugging. It accepts either a JSON object or a path to a JSON +file, and can extend or replace built-in entries for the current process. +2. On the first load, SGLang will: + - download overlay metadata from the overlay repo + - download the required files from the original source repo + - materialize a local standard component repo under `~/.cache/sgl_diffusion/materialized_models/` +3. Later loads reuse the materialized local repo. The materialized repo is what the runtime loads as a normal componentized model directory. + + ## Quick Start ### Generate diff --git a/docs/diffusion/installation.md b/docs/diffusion/installation.md index 9531d70d2c68..8c0fa1d11192 100644 --- a/docs/diffusion/installation.md +++ b/docs/diffusion/installation.md @@ -84,7 +84,7 @@ pip install -e "python[all_musa]" ## Platform-Specific: Ascend NPU -For Ascend NPU, please follow the [NPU installation guide](../platforms/ascend_npu.md). +For Ascend NPU, please follow the [NPU installation guide](../platforms/ascend/ascend_npu.md). 
Quick test: diff --git a/docs/diffusion/performance/attention_backends.md b/docs/diffusion/performance/attention_backends.md index ebb19e111a06..55d6da27000e 100644 --- a/docs/diffusion/performance/attention_backends.md +++ b/docs/diffusion/performance/attention_backends.md @@ -14,7 +14,7 @@ When using the diffusers backend, `--attention-backend` is passed through to dif - **CUDA**: prefers FlashAttention (FA3/FA4) when supported; otherwise falls back to PyTorch SDPA. - **ROCm**: uses FlashAttention when available; otherwise falls back to PyTorch SDPA. - **MPS**: always uses PyTorch SDPA. -- **NPU**: always uses PyTorch SDPA. +- **NPU**: uses FlashAttention for ring attention; otherwise falls back to PyTorch SDPA. ## Backend options @@ -87,7 +87,7 @@ Some backends require additional configuration. You can pass these parameters vi | Backend | CUDA | ROCm | MPS | NPU | Notes | |---|---:|---:|---:|---:|---| -| `fa` | āœ… | āœ… | āŒ | āŒ | CUDA requires SM80+ and fp16/bf16. FlashAttention is only used when the required runtime is installed; otherwise it falls back to `torch_sdpa`. | +| `fa` | āœ… | āœ… | āŒ | āœ… | CUDA requires SM80+ and fp16/bf16. FlashAttention is only used when the required runtime is installed; otherwise it falls back to `torch_sdpa`. No extra installation is required for NPU. | | `torch_sdpa` | āœ… | āœ… | āœ… | āœ… | Most compatible option across platforms. | | `sliding_tile_attn` | āœ… | āŒ | āŒ | āŒ | CUDA-only. Requires `st_attn`. Configure via `--attention-backend-config`. | | `sage_attn` | āœ… | āŒ | āŒ | āŒ | CUDA-only (optional dependency). 
| diff --git a/docs/diffusion/performance/index.md b/docs/diffusion/performance/index.md index 4a3c064408a1..2a2abe54a239 100644 --- a/docs/diffusion/performance/index.md +++ b/docs/diffusion/performance/index.md @@ -30,6 +30,12 @@ cache/index profiling ``` +## Current Baseline Snapshot + +For Ring SP benchmark details, see: + +- [Ring SP Performance](ring_sp_performance.md) + ## References - [Cache-DiT Repository](https://github.com/vipshop/cache-dit) diff --git a/docs/diffusion/performance/ring_sp_performance.md b/docs/diffusion/performance/ring_sp_performance.md new file mode 100644 index 000000000000..138698bfc4f5 --- /dev/null +++ b/docs/diffusion/performance/ring_sp_performance.md @@ -0,0 +1,67 @@ +# Ring SP Benchmark: Wan2.2-TI2V-5B (u1r2 vs Baseline) + +This page reports Ring-SP performance for `Wan2.2-TI2V-5B-Diffusers` using: + +- Parallel config: `sp=2, ulysses=1, ring=2` (short: `u1r2`) +- Baseline config: `sp=1, ulysses=1, ring=1` (short: `u1r1`) + +## Benchmark Setup + +- Model: `Wan2.2-TI2V-5B-Diffusers` +- GPU: `48G RTX40 series * 2` + +## Online Serving + +### Ring SP (`u1r2`) + +```bash +sglang serve \ + --model-type diffusion \ + --model-path /model/HuggingFace/Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --num-gpus 2 --sp-degree 2 --ulysses-degree 1 --ring-degree 2 \ + --port 8898 +``` + +### Baseline (`u1r1`) + +```bash +sglang serve \ + --model-type diffusion \ + --model-path /model/HuggingFace/Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --num-gpus 1 --sp-degree 1 --ulysses-degree 1 --ring-degree 1 \ + --port 8898 +``` + +## Benchmarks + +### Benchmark Disclaimer + +These benchmarks are provided for reference under one specific setup and command configuration. Actual performance may vary with model settings, runtime environment, and request patterns. 
+ +### Stage Time Breakdown + +| Stage / Metric | `u1r2` (s) | `u1r1` baseline (s) | Speedup | +|---|---:|---:|---:| +| InputValidation | 0.1060 | 0.1029 | 0.97x | +| TextEncoding | 1.3965 | 2.2261 | 1.59x | +| LatentPreparation | 0.0002 | 0.0002 | 1.00x | +| TimestepPreparation | 0.0003 | 0.0004 | 1.33x | +| Denoising | 52.6358 | 71.6785 | 1.36x | +| Decoding | 7.6708 | 13.4314 | 1.75x | +| **Total** | **63.74** | **90.63** | **1.42x** | + +### Memory Usage + +| Memory Metric | `u1r2` (GB) | `u1r1` baseline (GB) | Delta | +|---|---:|---:|---:| +| Peak GPU Memory | 20.07 | 27.40 | -7.33 | +| Peak Allocated | 13.35 | 20.40 | -7.05 | +| Memory Overhead | 6.72 | 7.00 | -0.28 | +| Overhead Ratio | 33.5% | 25.6% | +7.9pp | + +## Summary + +- End-to-end latency improves from `90.63s` to `63.74s` (`1.42x`). +- Main gains come from `Denoising` (`1.36x`) and `Decoding` (`1.75x`). +- Absolute memory usage drops noticeably on Ring-SP (`Peak GPU Memory -7.33GB`, `Peak Allocated -7.05GB`). +- Overhead ratio rises (`+7.9pp`), so future tuning can focus on reducing communication/runtime overhead while preserving the latency gain. diff --git a/docs/diffusion/quantization.md b/docs/diffusion/quantization.md index ab1b634a8b4d..04fa798b679a 100644 --- a/docs/diffusion/quantization.md +++ b/docs/diffusion/quantization.md @@ -44,7 +44,8 @@ backend. 
|------------------|--------------------------------------------------------------------------------------------|------------------------------------------------------|--------------------------------------------------------------|---------------------------------------|-----------------------------------------------------------------------------------------------------------------------| | `fp8` | Quantized transformer component folder, or safetensors with `quantization_config` metadata | `--transformer-path` or `--transformer-weights-path` | ALL | None | Component-folder and single-file flows are both supported | | `nvfp4-modelopt` | NVFP4 safetensors file, sharded directory, or repo providing transformer weights | `--transformer-weights-path` | FLUX.2 | `comfy-kitchen` optional on Blackwell | Blackwell can use a best-performance kit when available; otherwise SGLang falls back to the generic ModelOpt FP4 path | -| `nunchaku-svdq` | Pre-quantized Nunchaku transformer weights, usually named `svdq-{int4\|fp4}_r{rank}-...` | `--transformer-weights-path` | Model-specific support such as Qwen-Image, FLUX, and Z-Image | `nunchaku` | SGLang can infer precision and rank from the filename and supports both `int4` and `nvfp4` | +| `nunchaku-svdq` | Pre-quantized Nunchaku transformer weights, usually named `svdq-{int4\|fp4}_r{rank}-...` | `--transformer-weights-path` | Model-specific support such as Qwen-Image, FLUX, and Z-Image | `nunchaku` | SGLang can infer precision and rank from the filename and supports both `int4` and `nvfp4` | +| `msmodelslim` | Pre-quantized msmodelslim transformer weights | `--model-path` | Wan2.2 family | None | Currently only compatible with the Ascend NPU family and supports both `w8a8` and `w4a4` | ## NVFP4 @@ -171,3 +172,68 @@ sglang generate \ as `4` or `8`. - Current runtime validation only allows Nunchaku on NVIDIA CUDA Ampere (SM8x) or SM12x GPUs. Hopper (SM90) is currently rejected. 
+ +## [ModelSlim](https://gitcode.com/Ascend/msmodelslim) +MindStudio-ModelSlim (msModelSlim) is a model offline quantization compression tool launched by MindStudio and optimized for Ascend hardware. + +- **Installation** + + ```bash + # Clone repo and install msmodelslim: + git clone https://gitcode.com/Ascend/msmodelslim.git + cd msmodelslim + bash install.sh + ``` + +- **Multimodal_sd quantization** + + Download the original floating-point weights of the large model. Taking Wan2.2-T2V-A14B as an example, you can go to [Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) to obtain the original model weights. Then install other dependencies (related to the model, refer to the modelscope model card). + > Note: You can find pre-quantized validated models on [modelscope/Eco-Tech](https://modelscope.cn/models/Eco-Tech). + + Run quantization using one-click quantization (recommended): + + ```bash + msmodelslim quant \ + --model_path /path/to/wan2_2_float_weights \ + --save_path /path/to/wan2_2_quantized_weights \ + --device npu \ + --model_type Wan2_2 \ + --quant_type w8a8 \ + --trust_remote_code True + ``` + + For more detailed examples of quantization of models, as well as information about their support, see the [examples](https://gitcode.com/Ascend/msmodelslim/blob/master/example/multimodal_sd/README.md) section in ModelSLim repo. + + > Note: SGLang does not support quantized embeddings, please disable this option when quantizing using msmodelslim. + +- **Auto-Detection and different formats** + + For msmodelslim checkpoints, it's enough to specify only ```--model-path```, the detection of quantization occurs automatically for each layer using parsing of `quant_model_description.json` config. 
+ + In the case of `Wan2.2`, only the `Diffusers` weights storage format is supported, whereas modelslim saves the quantized model in the original `Wan2.2` format; + for conversion use the `python/sglang/multimodal_gen/tools/wan_repack.py` script: + + ```bash + python wan_repack.py \ + --input-path {path_to_quantized_model} \ + --output-path {path_to_converted_model} + ``` + + After that, please copy all files from the original `Diffusers` checkpoint (except the `transformer`/`transformer_2` folders). + +- **Usage Example** + + With auto-detected flow: + + ```bash + sglang generate \ + --model-path Eco-Tech/Wan2.2-T2V-A14B-Diffusers-w8a8 \ + --prompt "a beautiful sunset" \ + --save-output + ``` + +- **Available Quantization Methods**: + - [x] ```W4A4_DYNAMIC``` linear with online quantization of activations + - [x] ```W8A8``` linear with offline quantization of activations + - [x] ```W8A8_DYNAMIC``` linear with online quantization of activations + - [ ] ```mxfp8``` linear in progress diff --git a/docs/get_started/install.md b/docs/get_started/install.md index 9306afc95539..a8aab8697996 100644 --- a/docs/get_started/install.md +++ b/docs/get_started/install.md @@ -2,7 +2,7 @@ You can install SGLang using one of the methods below. This page primarily applies to common NVIDIA GPU platforms. -For other or newer platforms, please refer to the dedicated pages for [AMD GPUs](../platforms/amd_gpu.md), [Intel Xeon CPUs](../platforms/cpu_server.md), [TPU](../platforms/tpu.md), [NVIDIA DGX Spark](https://lmsys.org/blog/2025-11-03-gpt-oss-on-nvidia-dgx-spark/), [NVIDIA Jetson](../platforms/nvidia_jetson.md), [Ascend NPUs](../platforms/ascend_npu.md), and [Intel XPU](../platforms/xpu.md). 
+For other or newer platforms, please refer to the dedicated pages for [AMD GPUs](../platforms/amd_gpu.md), [Intel Xeon CPUs](../platforms/cpu_server.md), [TPU](../platforms/tpu.md), [NVIDIA DGX Spark](https://lmsys.org/blog/2025-11-03-gpt-oss-on-nvidia-dgx-spark/), [NVIDIA Jetson](../platforms/nvidia_jetson.md), [Ascend NPUs](../platforms/ascend/ascend_npu.md), and [Intel XPU](../platforms/xpu.md). ## Method 1: With pip or uv diff --git a/docs/index.rst b/docs/index.rst index e61b40ef45a7..1a4defe9ff4c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -41,6 +41,7 @@ Its core features include: :caption: Advanced Features advanced_features/server_arguments.md + advanced_features/object_storage.md advanced_features/hyperparameter_tuning.md advanced_features/attention_backend.md advanced_features/speculative_decoding.ipynb @@ -86,6 +87,8 @@ Its core features include: diffusion/compatibility_matrix diffusion/api/cli diffusion/api/openai_api + diffusion/performance/index + diffusion/performance/ring_sp_performance diffusion/performance/attention_backends diffusion/performance/cache/index diffusion/quantization @@ -99,7 +102,7 @@ Its core features include: platforms/cpu_server.md platforms/tpu.md platforms/nvidia_jetson.md - platforms/ascend_npu_support.rst + platforms/ascend/ascend_npu_support.rst platforms/xpu.md .. toctree:: diff --git a/docs/platforms/ascend_contribution_guide.md b/docs/platforms/ascend/ascend_contribution_guide.md similarity index 89% rename from docs/platforms/ascend_contribution_guide.md rename to docs/platforms/ascend/ascend_contribution_guide.md index fa87161ff651..5823bd22a0ba 100644 --- a/docs/platforms/ascend_contribution_guide.md +++ b/docs/platforms/ascend/ascend_contribution_guide.md @@ -6,7 +6,7 @@ Welcome to **SGLang**! We appreciate your interest in contributing. This guide p ### Prepare Environment -Before contributing, please ensure that your environment is set up correctly. 
Follow the steps in the [Installation Guide](../platforms/ascend_npu.md) to install the necessary dependencies. We recommend [using docker](../platforms/ascend_npu.md#method-2-using-docker-image) to build the environment. +Before contributing, please ensure that your environment is set up correctly. Follow the steps in the [Installation Guide](ascend_npu.md) to install the necessary dependencies. We recommend [using docker](ascend_npu.md#method-2-using-docker-image) to build the environment. ### Fork and clone the repository @@ -38,6 +38,18 @@ If you add a new feature or fix a bug, please add corresponding unit tests to en SGLang uses Python's built-in [unittest](https://docs.python.org/3/library/unittest.html) framework. For detailed instructions on running tests and integrating them into CI, refer to [test/README.md](https://github.com/sgl-project/sglang/tree/main/test/README.md). +If you need to use a model which is not in the `python/sglang/test/ascend/test_ascend_utils.py` list, follow these steps: +1. Register an account and upload your model to [modelscope](https://modelscope.cn/models). +2. Make sure your model is pre-cached on the CI server at the path "/data/ascend-ci-share-pkking-sglang/modelscope/hub/models/{your_model_repo}/{your_model}". +If this is not the case, use the following command on the CI server: + ```bash + modelscope download \ + --model {your_model_repo}/{your_model} \ + --local_dir /data/ascend-ci-share-pkking-sglang/modelscope/hub/models/{your_model_repo}/{your_model} + ``` + > Note: If you don’t have access to the CI server, please ask maintainers (zl19940307@163.com) to download your model. +3. Add the model to `python/sglang/test/ascend/test_ascend_utils.py` (use the docker path `"/root/.cache/modelscope/hub/models/{your_model_repo}/{your_model}"`). + ## Write documentations We recommend new contributors start from writing documentation, which helps you quickly understand SGLang codebase. 
@@ -64,7 +76,7 @@ You can find additional accuracy eval examples in: - [test_moe_eval_accuracy_large.py](https://github.com/sgl-project/sglang/blob/main/test/registered/eval/test_moe_eval_accuracy_large.py) ## Benchmark the speed -Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md). +Refer to [Benchmark and Profiling](../../developer_guide/benchmark_and_profiling.md). ## Requesting a review for merge You can follow the pull request merge process described in [MAINTAINER.md](https://github.com/sgl-project/sglang/blob/main/.github/MAINTAINER.md). @@ -108,7 +120,6 @@ cool-down-minutes: Users listed in [CI_PERMISSIONS.json](https://github.com/sgl-project/sglang/blob/main/.github/CI_PERMISSIONS.json) may have a per-user cooldown interval. In practice, we use the minimum of the workflow’s default window and the user-specific interval. - ## Code style guidance - Avoid code duplication. If the same code snippet (more than five lines) appears multiple times, extract it into a shared function. - Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code. diff --git a/docs/platforms/ascend_npu.md b/docs/platforms/ascend/ascend_npu.md similarity index 99% rename from docs/platforms/ascend_npu.md rename to docs/platforms/ascend/ascend_npu.md index 860eb0a7d76b..6a0eef31db26 100644 --- a/docs/platforms/ascend_npu.md +++ b/docs/platforms/ascend/ascend_npu.md @@ -170,7 +170,7 @@ export SGLANG_SET_CPU_AFFINITY=1 python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend ascend ``` -#### PD Separation Scene +#### PD Disaggregation Scene 1. 
Launch Prefill Server ```shell # Enabling CPU Affinity diff --git a/docs/platforms/ascend_npu_best_practice.md b/docs/platforms/ascend/ascend_npu_best_practice.md similarity index 87% rename from docs/platforms/ascend_npu_best_practice.md rename to docs/platforms/ascend/ascend_npu_best_practice.md index aba6d2012666..39d49db48a30 100644 --- a/docs/platforms/ascend_npu_best_practice.md +++ b/docs/platforms/ascend/ascend_npu_best_practice.md @@ -7,23 +7,23 @@ you encounter issues or have any questions, please [open an issue](https://githu ### Low Latency -| Model | Hardware | Cards | Deploy Mode | Dataset | TPOT | Quantization | Configuration | -|-------------------|---------------|-------|---------------|-----------|------|--------------|---------------------------------------------------------------------------------------| -| Deepseek-R1 | Atlas 800I A3 | 32 | PD Separation | 6K+1.6K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-separation-mode) | -| Deepseek-R1 | Atlas 800I A3 | 32 | PD Separation | 3.9K+1K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-3_9k-1k-20ms-on-a3-32-cards-separation-mode) | -| Deepseek-R1 | Atlas 800I A3 | 32 | PD Separation | 3.5K+1.5K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-20ms-on-a3-32-cards-separation-mode) | -| Deepseek-R1 | Atlas 800I A3 | 32 | PD Separation | 3.5K+1K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1k-20ms-on-a3-32-cards-separation-mode) | -| DeepSeek-V3.2-Exp | Atlas 800I A3 | 32 | PD Separation | 64K+3K | 30ms | W8A8 INT8 | [Optimal Configuration](#deepseek-v32-exp-64k-3k-30ms-on-a3-32-cards-separation-mode) | +| Model | Hardware | Cards | Deploy Mode | Dataset | TPOT | Quantization | Configuration | +|-------------------|---------------|-------|-------------------|-----------|------|--------------|-------------------------------------------------------------------------------------------| +| Deepseek-R1 | Atlas 800I A3 
| 32 | PD Disaggregation | 6K+1.6K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-disaggregation-mode) | +| Deepseek-R1 | Atlas 800I A3 | 32 | PD Disaggregation | 3.9K+1K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-3_9k-1k-20ms-on-a3-32-cards-disaggregation-mode) | +| Deepseek-R1 | Atlas 800I A3 | 32 | PD Disaggregation | 3.5K+1.5K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-20ms-on-a3-32-cards-disaggregation-mode) | +| Deepseek-R1 | Atlas 800I A3 | 32 | PD Disaggregation | 3.5K+1K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1k-20ms-on-a3-32-cards-disaggregation-mode) | +| DeepSeek-V3.2 | Atlas 800I A3 | 32 | PD Disaggregation | 128K+1K | 20ms | W8A8 INT8 | [Optimal Configuration](#deepseek-v32-128k-1k-20ms-on-a3-32-cards-disaggregation-mode) | ### High Throughput -| Model | Hardware | Cards | Deploy Mode | Dataset | TPOT | Quantization | Configuration | -|-------------|---------------|-------|---------------|-----------|------|--------------|-------------------------------------------------------------------------------------| -| Deepseek-R1 | Atlas 800I A3 | 32 | PD Separation | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-32-cards-separation-mode) | -| Deepseek-R1 | Atlas 800I A3 | 8 | PD Mixed | 2K+2K | 50ms | W4A8 INT8 | [Optimal Configuration](#deepseek-r1-2k-2k-50ms-on-a3-8-cards-mixed-mode) | -| Deepseek-R1 | Atlas 800I A3 | 16 | PD Separation | 2K+2K | 50ms | W4A8 INT8 | [Optimal Configuration](#deepseek-r1-2k-2k-50ms-on-a3-16-cards-separation-mode) | -| Deepseek-R1 | Atlas 800I A3 | 8 | PD Mixed | 3.5K+1.5K | 50ms | W4A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode) | -| Deepseek-R1 | Atlas 800I A3 | 16 | PD Separation | 3.5K+1.5K | 50ms | W4A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-16-cards-separation-mode) | +| Model | Hardware | Cards | Deploy Mode | 
Dataset | TPOT | Quantization | Configuration | +|-------------|---------------|-------|-------------------|-----------|------|--------------|-----------------------------------------------------------------------------------------| +| Deepseek-R1 | Atlas 800I A3 | 32 | PD Disaggregation | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-32-cards-disaggregation-mode) | +| Deepseek-R1 | Atlas 800I A3 | 8 | PD Mixed | 2K+2K | 50ms | W4A8 INT8 | [Optimal Configuration](#deepseek-r1-2k-2k-50ms-on-a3-8-cards-mixed-mode) | +| Deepseek-R1 | Atlas 800I A3 | 16 | PD Disaggregation | 2K+2K | 50ms | W4A8 INT8 | [Optimal Configuration](#deepseek-r1-2k-2k-50ms-on-a3-16-cards-disaggregation-mode) | +| Deepseek-R1 | Atlas 800I A3 | 8 | PD Mixed | 3.5K+1.5K | 50ms | W4A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode) | +| Deepseek-R1 | Atlas 800I A3 | 16 | PD Disaggregation | 3.5K+1.5K | 50ms | W4A8 INT8 | [Optimal Configuration](#deepseek-r1-3_5k-1_5k-50ms-on-a3-16-cards-disaggregation-mode) | ## Qwen Series Models @@ -40,32 +40,32 @@ you encounter issues or have any questions, please [open an issue](https://githu ### High Throughput -| Model | Hardware | Cards | Deploy Mode | Dataset | TPOT | Quantization | Configuration | -|--------------------------------|---------------|-------|---------------|-----------|-------|--------------|--------------------------------------------------------------------------------------------------------| -| Qwen3-235B-A22B | Atlas 800I A3 | 24 | PD Separation | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-3_5k-1_5k-50ms-on-a3-24-cards-separation-mode) | -| Qwen3-235B-A22B | Atlas 800I A3 | 8 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode) | -| Qwen3-235B-A22B | Atlas 800I A3 | 8 | PD Mixed | 2K+2K | 100ms | W8A8 INT8 | [Optimal 
Configuration](#qwen3-235b-a22b-2k-2k-100ms-on-a3-8-cards-mixed-mode) | -| Qwen3-235B-A22B | Atlas 800I A3 | 8 | PD Mixed | 2K+2K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-2k-2k-50ms-on-a3-8-cards-mixed-mode) | -| Qwen3-235B-A22B | Atlas 800I A3 | 16 | PD Mixed | 2K+2K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-2k-2k-50ms-on-a3-16-cards-mixed-mode) | -| Qwen3-32B | Atlas 800I A3 | 2 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-32b-3_5k-1_5k-50ms-on-a3-2-cards-mixed-mode) | -| Qwen3-32B | Atlas 800I A3 | 2 | PD Mixed | 2K+2K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-32b-2k-2k-50ms-on-a3-2-cards-mixed-mode) | -| Qwen3-30B-A3B | Atlas 800I A3 | 1 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-30b-a3b-3_5k-1_5k-50ms-on-a3-1-card-mixed-mode) | -| Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 24 | PD Separation | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-24-cards-separation-mode) | -| Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 16 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-16-cards-mixed-mode) | -| Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 8 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode) | -| Qwen3-Next-80B-A3B-Instruct | Atlas 800I A3 | 2 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-next-80B-a3b-instruct-3_5k-1_5k-50ms-on-a3-2-cards-mixed-mode) | -| Qwen3-32B | Atlas 800I A2 | 8 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-32b-3_5k-1_5k-50ms-on-a2-8-cards-mixed-mode) | -| Qwen3-32B | Atlas 800I A2 | 8 | PD Mixed | 2K+2K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-32b-2k-2k-50ms-on-a2-8-cards-mixed-mode) | +| Model | Hardware | Cards | Deploy Mode | 
Dataset | TPOT | Quantization | Configuration | +|--------------------------------|---------------|-------|-------------------|-----------|-------|--------------|------------------------------------------------------------------------------------------------------------| +| Qwen3-235B-A22B | Atlas 800I A3 | 24 | PD Disaggregation | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-3_5k-1_5k-50ms-on-a3-24-cards-disaggregation-mode) | +| Qwen3-235B-A22B | Atlas 800I A3 | 8 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode) | +| Qwen3-235B-A22B | Atlas 800I A3 | 8 | PD Mixed | 2K+2K | 100ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-2k-2k-100ms-on-a3-8-cards-mixed-mode) | +| Qwen3-235B-A22B | Atlas 800I A3 | 8 | PD Mixed | 2K+2K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-2k-2k-50ms-on-a3-8-cards-mixed-mode) | +| Qwen3-235B-A22B | Atlas 800I A3 | 16 | PD Mixed | 2K+2K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-235b-a22b-2k-2k-50ms-on-a3-16-cards-mixed-mode) | +| Qwen3-32B | Atlas 800I A3 | 2 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-32b-3_5k-1_5k-50ms-on-a3-2-cards-mixed-mode) | +| Qwen3-32B | Atlas 800I A3 | 2 | PD Mixed | 2K+2K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-32b-2k-2k-50ms-on-a3-2-cards-mixed-mode) | +| Qwen3-30B-A3B | Atlas 800I A3 | 1 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-30b-a3b-3_5k-1_5k-50ms-on-a3-1-card-mixed-mode) | +| Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 24 | PD Disaggregation | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-24-cards-disaggregation-mode) | +| Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 16 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-16-cards-mixed-mode) | +| 
Qwen3-Coder-480B-A35B-Instruct | Atlas 800I A3 | 8 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-coder-480b-a35b-instruct-3_5k-1_5k-50ms-on-a3-8-cards-mixed-mode) | +| Qwen3-Next-80B-A3B-Instruct | Atlas 800I A3 | 2 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-next-80B-a3b-instruct-3_5k-1_5k-50ms-on-a3-2-cards-mixed-mode) | +| Qwen3-32B | Atlas 800I A2 | 8 | PD Mixed | 3.5K+1.5K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-32b-3_5k-1_5k-50ms-on-a2-8-cards-mixed-mode) | +| Qwen3-32B | Atlas 800I A2 | 8 | PD Mixed | 2K+2K | 50ms | W8A8 INT8 | [Optimal Configuration](#qwen3-32b-2k-2k-50ms-on-a2-8-cards-mixed-mode) | ## Optimal Configuration -### DeepSeek-R1 3_5K-1_5K 50ms on A3 32 Cards Separation Mode +### DeepSeek-R1 3_5K-1_5K 50ms on A3 32 Cards Disaggregation Mode Model: Deepseek R1 Hardware: Atlas 800I A3 32Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -158,7 +158,6 @@ done ``` ```shell -export SGLANG_DP_ROUND_ROBIN=1 python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ @@ -178,13 +177,13 @@ We tested it based on the `RANDOM` dataset. python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 768 --random-input-len 3500 --random-output-len 1500 --num-prompts 3072 --random-range-ratio 1 --request-rate 16 ``` -### DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Separation Mode +### DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Disaggregation Mode Model: Deepseek R1 Hardware: Atlas 800I A3 32Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -275,7 +274,6 @@ done ``` ```shell -export SGLANG_DP_ROUND_ROBIN=1 python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ @@ -295,13 +293,13 @@ We tested it based on the `RANDOM` dataset. 
python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 32 --random-input-len 6000 --random-output-len 1600 --num-prompts 32 --random-range-ratio 1 ``` -### DeepSeek-R1 3_9K-1K 20ms on A3 32 Cards Separation Mode +### DeepSeek-R1 3_9K-1K 20ms on A3 32 Cards Disaggregation Mode Model: Deepseek R1 Hardware: Atlas 800I A3 32Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -311,7 +309,7 @@ TPOT: 20ms #### Model Deployment -Please Turn to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Separation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-separation-mode) +Please Turn to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Disaggregation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-disaggregation-mode) #### Benchmark @@ -321,13 +319,13 @@ We tested it based on the `RANDOM` dataset. python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 768 --random-input-len 3900 --random-output-len 1000 --num-prompts 768 --random-range-ratio 1 --request-rate 16 ``` -### DeepSeek-R1 3_5K-1_5K 20ms on A3 32 Cards Separation Mode +### DeepSeek-R1 3_5K-1_5K 20ms on A3 32 Cards Disaggregation Mode Model: Deepseek R1 Hardware: Atlas 800I A3 32Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -337,7 +335,7 @@ TPOT: 20ms #### Model Deployment -Please Turn to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Separation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-separation-mode) +Please Turn to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Disaggregation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-disaggregation-mode) #### Benchmark @@ -347,13 +345,13 @@ We tested it based on the `RANDOM` dataset. 
python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 768 --random-input-len 3500 --random-output-len 1500 --num-prompts 768 --random-range-ratio 1 --request-rate 16 ``` -### DeepSeek-R1 3_5K-1K 20ms on A3 32 Cards Separation Mode +### DeepSeek-R1 3_5K-1K 20ms on A3 32 Cards Disaggregation Mode Model: Deepseek R1 Hardware: Atlas 800I A3 32Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -363,7 +361,7 @@ TPOT: 20ms #### Model Deployment -Please Turn to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Separation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-separation-mode) +Please Turn to [DeepSeek-R1 6K-1_6K 20ms on A3 32 Cards Disaggregation Mode](#deepseek-r1-6k-1_6k-20ms-on-a3-32-cards-disaggregation-mode) #### Benchmark @@ -453,13 +451,13 @@ We tested it based on the `RANDOM` dataset. python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6699 --max-concurrency 256 --random-input-len 2048 --random-output-len 2048 --num-prompts 1024 --random-range-ratio 1 ``` -### DeepSeek-R1 2K-2K 50ms on A3 16 Cards Separation Mode +### DeepSeek-R1 2K-2K 50ms on A3 16 Cards Disaggregation Mode Model: Deepseek R1 Hardware: Atlas 800I A3 16Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -559,7 +557,6 @@ done ``` ```shell -export SGLANG_DP_ROUND_ROBIN=1 python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ @@ -655,13 +652,13 @@ We tested it based on the `RANDOM` dataset. 
python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6699 --max-concurrency 144 --random-input-len 3500 --random-output-len 1500 --num-prompts 576 --random-range-ratio 1 ``` -### DeepSeek-R1 3_5K-1_5K 50ms on A3 16 Cards Separation Mode +### DeepSeek-R1 3_5K-1_5K 50ms on A3 16 Cards Disaggregation Mode Model: Deepseek R1 Hardware: Atlas 800I A3 16Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -760,7 +757,6 @@ done ``` ```shell -export SGLANG_DP_ROUND_ROBIN=1 python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ @@ -779,24 +775,22 @@ We tested it based on the `RANDOM` dataset. python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 384 --random-input-len 3500 --random-output-len 1500 --num-prompts 1536 --random-range-ratio 1 ``` -### DeepSeek-V3.2-Exp 64K-3K 30ms on A3 32 Cards Separation Mode +### DeepSeek-V3.2 128K-1K 20ms on A3 32 Cards Disaggregation Mode -Model: DeepSeek-V3.2-Exp-W8A8 +Model: DeepSeek-V3.2-W8A8 Hardware: Atlas 800I A3 32Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random -Input Output Length: 64K+3K +Input Output Length: 128K+1K -TPOT: 30ms +TPOT: 20ms #### Model Deployment -Deploy Prefill Instance - ```shell echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor sysctl -w vm.swappiness=0 @@ -815,167 +809,115 @@ source /usr/local/Ascend/nnal/atb/set_env.sh export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH} export PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH -export ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest - export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True export STREAMS_PER_DEVICE=32 +export ASCEND_MF_STORE_URL="tcp://your prefill ip1:24670" -export HCCL_BUFFSIZE=1024 -export DEEPEP_NORMAL_LONG_SEQ_ROUND=5 -export 
DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=512 - +P_IP=('your prefill ip1' 'your prefill ip2') +D_IP=('your decode ip1' 'your decode ip2') MODEL_PATH=xxx -export SGLANG_NPU_USE_MLAPO=1 -export DEEP_NORMAL_MODE_USE_INT8_QUANT=1 -export SGLANG_NPU_USE_MULTI_STREAM=1 -export HCCL_OP_EXPANSION_MODE=AIV - -IPs=('your prefill ip1' 'your prefill ip2') +LOCAL_HOST1=`hostname -I|awk -F " " '{print$1}'` +LOCAL_HOST2=`hostname -I|awk -F " " '{print$2}'` +echo "${LOCAL_HOST1}" +echo "${LOCAL_HOST2}" -# get IP in current node -LOCAL_HOST=`hostname -I|awk -F " " '{print$1}'` -echo "LOCAL_HOST = " ${LOCAL_HOST} -# get node index -for i in "${!IPs[@]}"; +# prefill +for i in "${!P_IP[@]}"; do - echo "LOCAL_HOST=${LOCAL_HOST}, IPs[${i}]=${IPs[$i]}" - if [ "$LOCAL_HOST" == "${IPs[$i]}" ]; then - echo "Node Rank : ${i}" - VC_TASK_INDEX=$i - break - fi -done - -IFNAMES=('xxx' 'xxx') - -export HCCL_SOCKET_IFNAME=${IFNAMES[$VC_TASK_INDEX]} -export GLOO_SOCKET_IFNAME=${HCCL_SOCKET_IFNAME} -echo "HCCL_SOCKET_IFNAME : ${HCCL_SOCKET_IFNAME}" -nnodes=${#IPs[@]} -tp_size=`expr 16 \* ${nnodes}` -export ASCEND_MF_STORE_URL=tcp://${IPs[0]}:24667 - -python3 -m sglang.launch_server --model-path ${MODEL_PATH} \ ---tp $tp_size \ ---trust-remote-code \ ---attention-backend ascend \ ---device npu \ ---watchdog-timeout 9000 \ ---host ${IPs[$VC_TASK_INDEX]} --port 8000 \ ---mem-fraction-static 0.73 \ ---disable-radix-cache --chunked-prefill-size -1 --max-prefill-tokens 68000 \ ---max-running-requests 1 \ ---moe-a2a-backend deepep --deepep-mode normal \ ---quantization modelslim \ ---disaggregation-transfer-backend ascend \ ---disaggregation-mode prefill \ ---disable-cuda-graph \ ---nnodes $nnodes --node-rank $VC_TASK_INDEX \ ---disaggregation-bootstrap-port 8995 \ ---enable-nsa-prefill-context-parallel --moe-dense-tp-size 1 \ ---speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \ ---dist-init-addr ${IPs[0]}:10000 -``` - -Deploy Decode Instance 
- -```shell -echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor -sysctl -w vm.swappiness=0 -sysctl -w kernel.numa_balancing=0 -sysctl -w kernel.sched_migration_cost_ns=50000 - -export SGLANG_SET_CPU_AFFINITY=1 -unset https_proxy -unset http_proxy -unset HTTPS_PROXY -unset HTTP_PROXY -unset ASCEND_LAUNCH_BLOCKING -source /usr/local/Ascend/ascend-toolkit/set_env.sh -source /usr/local/Ascend/nnal/atb/set_env.sh -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH} -export PATH=/usr/local/Ascend/8.5.0/compiler/bishengir/bin:$PATH -export ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest - -export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True -export STREAMS_PER_DEVICE=32 - -MODEL_PATH=xxx + if [[ "$LOCAL_HOST1" == "${P_IP[$i]}" || "$LOCAL_HOST2" == "${P_IP[$i]}" ]]; + then + echo "${P_IP[$i]}" + export HCCL_BUFFSIZE=1200 + export DEEP_NORMAL_MODE_USE_INT8_QUANT=1 + export TASK_QUEUE_ENABLE=2 + export HCCL_SOCKET_IFNAME=xxx + export GLOO_SOCKET_IFNAME=xxx -export SGLANG_NPU_USE_MULTI_STREAM=1 -export SGLANG_NPU_USE_MLAPO=1 -export HCCL_OP_EXPANSION_MODE=AIV -export SGLANG_SCHEDULER_SKIP_ALL_GATHER=1 -export TASK_QUEUE_ENABLE=0 -export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 -export SGLANG_ENABLE_SPEC_V2=1 + python3 -m sglang.launch_server --model-path ${MODEL_PATH} \ + --tp 32 \ + --trust-remote-code \ + --attention-backend ascend \ + --device npu \ + --watchdog-timeout 9000 \ + --host ${P_IP[$i]} --port 8000 \ + --mem-fraction-static 0.73 \ + --disable-radix-cache --chunked-prefill-size -1 --max-prefill-tokens 68000 \ + --max-running-requests 1 \ + --moe-a2a-backend deepep --deepep-mode normal \ + --quantization modelslim \ + --disaggregation-transfer-backend ascend \ + --disaggregation-mode prefill \ + --disable-cuda-graph \ + --nnodes 2 --node-rank $i \ + --disaggregation-bootstrap-port 8995 \ + --moe-dense-tp-size 1 \ + --enable-nsa-prefill-context-parallel \ + 
--nsa-prefill-cp-mode in-seq-split \ + --attn-cp-size 32 \ + --speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \ + --dist-init-addr ${P_IP[0]}:10000 + break + fi +done -IPs=('your decode ip1' 'your decode ip2') -export prefill_ip=your prefill ip1 -# get IP in current node -LOCAL_HOST=`hostname -I|awk -F " " '{print$1}'` -echo "LOCAL_HOST = " ${LOCAL_HOST} -# get node index -for i in "${!IPs[@]}"; +# decode +for i in "${!D_IP[@]}"; do - echo "LOCAL_HOST=${LOCAL_HOST}, IPs[${i}]=${IPs[$i]}" - if [ "$LOCAL_HOST" == "${IPs[$i]}" ]; then - echo "Node Rank : ${i}" - VC_TASK_INDEX=$i - break - fi -done - -IFNAMES=('xxx' 'xxx') + if [[ "$LOCAL_HOST1" == "${D_IP[$i]}" || "$LOCAL_HOST2" == "${D_IP[$i]}" ]]; + then + echo "${D_IP[$i]}" + export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 + export SGLANG_ENABLE_SPEC_V2=1 -export HCCL_SOCKET_IFNAME=${IFNAMES[$VC_TASK_INDEX]} -export GLOO_SOCKET_IFNAME=${HCCL_SOCKET_IFNAME} -nnodes=${#IPs[@]} -tp_size=`expr 16 \* ${nnodes}` -export ASCEND_MF_STORE_URL=tcp://${prefill_ip}:24667 + export TASK_QUEUE_ENABLE=0 + export SGLANG_SCHEDULER_SKIP_ALL_GATHER=1 -CHUNKED_SIZE=65536 -DP=8 -export HCCL_BUFFSIZE=400 -export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8 + export HCCL_SOCKET_IFNAME=xxx + export GLOO_SOCKET_IFNAME=xxx -python3 -m sglang.launch_server --model-path ${MODEL_PATH} \ ---tp $tp_size \ ---dp ${DP} \ ---ep $tp_size \ ---moe-dense-tp-size 1 \ ---enable-dp-attention \ ---enable-dp-lm-head \ ---trust-remote-code \ ---attention-backend ascend \ ---device npu \ ---watchdog-timeout 9000 \ ---host ${IPs[$VC_TASK_INDEX]} --port 8001 \ ---mem-fraction-static 0.79 \ ---disable-radix-cache \ ---chunked-prefill-size -1 --max-prefill-tokens 68000 \ ---max-running-requests 32 \ ---cuda-graph-max-bs 4 \ ---moe-a2a-backend deepep \ ---deepep-mode low_latency \ ---quantization modelslim \ ---speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 
--speculative-num-draft-tokens 4 \ ---disaggregation-transfer-backend ascend \ ---disaggregation-mode decode \ ---load-balance-method round_robin \ ---nnodes $nnodes --node-rank $VC_TASK_INDEX \ ---dist-init-addr ${IPs[0]}:10000 --load-balance-method decode_round_robin + DP=8 + export HCCL_BUFFSIZE=400 + export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8 + + python3 -m sglang.launch_server --model-path ${MODEL_PATH} \ + --tp 32 \ + --dp ${DP} \ + --ep 32 \ + --moe-dense-tp-size 1 \ + --enable-dp-attention \ + --enable-dp-lm-head \ + --trust-remote-code \ + --attention-backend ascend \ + --device npu \ + --watchdog-timeout 9000 \ + --host ${D_IP[$i]} --port 8001 \ + --mem-fraction-static 0.79 \ + --disable-radix-cache \ + --chunked-prefill-size -1 --max-prefill-tokens 68000 \ + --max-running-requests 32 \ + --cuda-graph-max-bs 4 \ + --moe-a2a-backend deepep \ + --deepep-mode low_latency \ + --quantization modelslim \ + --speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 \ + --disaggregation-transfer-backend ascend \ + --disaggregation-mode decode \ + --nnodes 2 --node-rank $i \ + --dist-init-addr ${D_IP[0]}:10000 + break + fi +done ``` + ```shell -export SGLANG_DP_ROUND_ROBIN=1 python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ - --prefill http://PIP1:8000 8995 \ - --decode http://DIP1:8001 \ + --prefill http://P_IP1:8000 8995 \ + --decode http://D_IP1:8001 \ --host 127.0.0.1 \ --port 6688 \ --mini-lb @@ -986,16 +928,16 @@ python -m sglang_router.launch_router \ We tested it based on the `RANDOM` dataset. 
```shell -python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 32 --random-input-len 64000 --random-output-len 3000 --num-prompts 64 --random-range-ratio 1 +python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 6688 --max-concurrency 8 --random-input-len 131076 --random-output-len 1024 --num-prompts 8 --random-range-ratio 1 ``` -### Qwen3-235B-A22B 3_5K-1_5K 50ms on A3 24 Cards Separation Mode +### Qwen3-235B-A22B 3_5K-1_5K 50ms on A3 24 Cards Disaggregation Mode Model: Qwen3-235B-A22B-W8A8 Hardware: Atlas 800I A3 24Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -1030,7 +972,6 @@ D_IP=('your decode ip1' 'your decode ip2') export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600 export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 export SGLANG_ENABLE_SPEC_V2=1 -export SGLANG_DP_ROUND_ROBIN=1 LOCAL_HOST1=`hostname -I|awk -F " " '{print$1}'` LOCAL_HOST2=`hostname -I|awk -F " " '{print$2}'` @@ -1106,7 +1047,6 @@ done ``` ```shell -export SGLANG_DP_ROUND_ROBIN=1 python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ @@ -1920,13 +1860,13 @@ We tested it based on the `RANDOM` dataset. 
python -m sglang.bench_serving --dataset-name random --backend sglang --host 127.0.0.1 --port 7239 --max-concurrency 156 --random-input-len 3500 --random-output-len 1500 --num-prompts 624 --random-range-ratio 1 ``` -### Qwen3-Coder-480B-A35B-Instruct 3_5K-1_5K 50ms on A3 24 Cards Separation Mode +### Qwen3-Coder-480B-A35B-Instruct 3_5K-1_5K 50ms on A3 24 Cards Disaggregation Mode Model: Qwen3-Coder-480B-A35B-Instruct Hardware: Atlas 800I A3 24Card -DeployMode: PD Separation +DeployMode: PD Disaggregation Dataset: random @@ -2025,7 +1965,6 @@ done ``` ```shell -export SGLANG_DP_ROUND_ROBIN=1 python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ diff --git a/docs/platforms/ascend_npu_deepseek_example.md b/docs/platforms/ascend/ascend_npu_deepseek_example.md similarity index 99% rename from docs/platforms/ascend_npu_deepseek_example.md rename to docs/platforms/ascend/ascend_npu_deepseek_example.md index cdecb544c13a..abda404d5995 100644 --- a/docs/platforms/ascend_npu_deepseek_example.md +++ b/docs/platforms/ascend/ascend_npu_deepseek_example.md @@ -262,7 +262,6 @@ done 2. SGLang Model Gateway (former Router): ```shell -export SGLANG_DP_ROUND_ROBIN=1 python -m sglang_router.launch_router \ --pd-disaggregation \ --policy cache_aware \ diff --git a/docs/platforms/ascend_npu_environment_variables.md b/docs/platforms/ascend/ascend_npu_environment_variables.md similarity index 100% rename from docs/platforms/ascend_npu_environment_variables.md rename to docs/platforms/ascend/ascend_npu_environment_variables.md diff --git a/docs/platforms/ascend_npu_glm5_examples.md b/docs/platforms/ascend/ascend_npu_glm5_examples.md similarity index 98% rename from docs/platforms/ascend_npu_glm5_examples.md rename to docs/platforms/ascend/ascend_npu_glm5_examples.md index f748b6408a10..f613e8956583 100644 --- a/docs/platforms/ascend_npu_glm5_examples.md +++ b/docs/platforms/ascend/ascend_npu_glm5_examples.md @@ -191,4 +191,4 @@ Not test yet. 
### Using Benchmark -Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md) for details. +Refer to [Benchmark and Profiling](../../developer_guide/benchmark_and_profiling.md) for details. diff --git a/docs/platforms/ascend/ascend_npu_quantization.md b/docs/platforms/ascend/ascend_npu_quantization.md new file mode 100644 index 000000000000..8b2e30ba1e5d --- /dev/null +++ b/docs/platforms/ascend/ascend_npu_quantization.md @@ -0,0 +1,52 @@ +# Quantization on Ascend + +To load already quantized models, simply load the model weights and config. If the model has been quantized offline, there's no need to add the `--quantization` argument when starting the engine. The quantization method will be automatically parsed from the downloaded `quant_model_description.json` or `config.json` config. + +SGLang supports **mix-bits** quantization (each layer is independently defined and loaded depending on the type of quantization specified in `quant_model_description.json`). [Advanced mix-bits for MoE](https://github.com/sgl-project/sglang/pull/17361) is in progress and will add independent quantization determination for the w13 (up-gate) and w2 (down) layers.
+ +[ModelSlim on Ascend support](https://github.com/sgl-project/sglang/pull/14504) +| Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported | Diffusion models | +|-----------------------------------------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:------------------------------------------:|:------------------------------------------:| +| W4A4 dynamic | Linear | **√** | **√** | **TBD** | **√** | +| W8A8 static | Linear | **√** | **√** | **TBD** | **√** | +| W8A8 dynamic | Linear | **√** | **√** | **TBD** | **√** | +| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | Linear | **x** | **x** | **WIP** | **WIP** | +| W4A4 dynamic | MoE | **√** | **√** | **TBD** | **x** | +| W4A8 dynamic | MoE | **√** | **√** | **TBD** | **x** | +| W8A8 dynamic | MoE | **√** | **√** | **TBD** | **x** | +| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | MoE | **x** | **x** | **WIP** | **x** | + +[AWQ on Ascend support](https://github.com/sgl-project/sglang/pull/10158): +| Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported | +|--------------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:------------------------------------------:| +| W4A16 | Linear | **√** | **√** | **TBD** | +| W8A16 | Linear | **√** | **√** | **TBD** | +| W4A16 | MoE | **√** | **√** | **TBD** | + +GPTQ on Ascend support +| Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported | +|----------------------------------------------------------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:-----------------------------------------:| +| [W4A16](https://github.com/sgl-project/sglang/pull/15203) | Linear | **√** | **√** | **TBD** | +| 
[W8A16](https://github.com/sgl-project/sglang/pull/15203) | Linear | **√** | **√** | **TBD** | +| [W4A16 MOE](https://github.com/sgl-project/sglang/pull/16364) | MoE | **√** | **√** | **TBD** | +| [W8A16 MOE](https://github.com/sgl-project/sglang/pull/16364) | MoE | **√** | **√** | **TBD** | + +[Auto-round on Ascend support](https://github.com/sgl-project/sglang/pull/16699) +| Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported | +|--------------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:-----------------------------------------:| +| W4A16 | Linear | **√** | **√** | **TBD** | +| W8A16 | Linear | **√** | **√** | **TBD** | +| W4A16 | MoE | **√** | **√** | **TBD** | +| W8A16 | MoE | **√** | **√** | **TBD** | + +Compressed-tensors (LLM Compressor) on Ascend support: +| Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported | +|-----------------------------------------------------------------------------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:-----------------------------------------:| +| [W8A8 dynamic](https://github.com/sgl-project/sglang/pull/14504) | Linear | **√** | **√** | **TBD** | +| [W4A8 dynamic with/without activation clip](https://github.com/sgl-project/sglang/pull/14736) | MoE | **√** | **√** | **TBD** | +| [W4A16 MOE](https://github.com/sgl-project/sglang/pull/12759) | MoE | **√** | **√** | **TBD** | +| [W8A8 dynamic](https://github.com/sgl-project/sglang/pull/14504) | MoE | **√** | **√** | **TBD** | + +[GGUF on Ascend support](https://github.com/sgl-project/sglang/pull/17883) + +in progress diff --git a/docs/platforms/ascend_npu_qwen3_5_examples.md b/docs/platforms/ascend/ascend_npu_qwen3_5_examples.md similarity index 98% rename from docs/platforms/ascend_npu_qwen3_5_examples.md rename to 
docs/platforms/ascend/ascend_npu_qwen3_5_examples.md index b19f7321e84e..8660f17cc5ea 100644 --- a/docs/platforms/ascend_npu_qwen3_5_examples.md +++ b/docs/platforms/ascend/ascend_npu_qwen3_5_examples.md @@ -228,4 +228,4 @@ Not test yet. ### Using Benchmark -Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md) for details. +Refer to [Benchmark and Profiling](../../developer_guide/benchmark_and_profiling.md) for details. diff --git a/docs/platforms/ascend_npu_qwen3_examples.md b/docs/platforms/ascend/ascend_npu_qwen3_examples.md similarity index 56% rename from docs/platforms/ascend_npu_qwen3_examples.md rename to docs/platforms/ascend/ascend_npu_qwen3_examples.md index 5278a22a1001..7ceedd351d1f 100644 --- a/docs/platforms/ascend_npu_qwen3_examples.md +++ b/docs/platforms/ascend/ascend_npu_qwen3_examples.md @@ -94,6 +94,97 @@ python -m sglang.launch_server \ --mem-fraction-static 0.8 ``` +#### Running Qwen3-235B-A22B-Instruct-2507 with 256K long sequence on 2 x Atlas 800I A3 without CP + +This example uses **PD disaggregation** for long-sequence inference and keeps **context parallel disabled**. 
+ +Set the shared environment variables on both nodes first: + +```shell +export ASCEND_USE_FIA=1 +export SGLANG_SET_CPU_AFFINITY=1 +export ASCEND_MF_STORE_URL="tcp://:12345" +export HCCL_SOCKET_IFNAME= +export GLOO_SOCKET_IFNAME= + +MODEL_PATH=/root/.cache/modelscope/hub/models/zcgy26/Qwen3-235B-A22B-Instruct-2507-w8a8 +``` + +**Prefill node:** + +```shell +export ASCEND_LAUNCH_BLOCKING=1 +export DEEP_NORMAL_MODE_USE_INT8_QUANT=1 +export HCCL_BUFFSIZE=1500 +export DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=1024 +export DEEPEP_NORMAL_LONG_SEQ_ROUND=128 +export DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ=1 + +python3 -m sglang.launch_server \ + --model-path ${MODEL_PATH} \ + --disaggregation-mode prefill \ + --disaggregation-transfer-backend ascend \ + --disaggregation-bootstrap-port 8995 \ + --attention-backend ascend \ + --disable-radix-cache \ + --quantization modelslim \ + --chunked-prefill-size -1 \ + --skip-server-warmup \ + --device npu \ + --tp-size 16 \ + --mem-fraction-static 0.45 \ + --max-running-requests 1 \ + --host \ + --port 8000 \ + --dist-init-addr :5000 \ + --nnodes 1 \ + --node-rank 0 \ + --moe-a2a-backend deepep \ + --deepep-mode normal +``` + +**Decode node:** + +```shell +export SGLANG_DEEPEP_BF16_DISPATCH=0 +export HCCL_BUFFSIZE=4000 +export DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=4096 +export DEEPEP_NORMAL_LONG_SEQ_ROUND=16 + +python3 -m sglang.launch_server \ + --model-path ${MODEL_PATH} \ + --disaggregation-mode decode \ + --disaggregation-transfer-backend ascend \ + --attention-backend ascend \ + --mem-fraction-static 0.8 \ + --disable-cuda-graph \ + --device npu \ + --disable-radix-cache \ + --quantization modelslim \ + --chunked-prefill-size 8192 \ + --skip-server-warmup \ + --tp-size 16 \ + --max-running-requests 1 \ + --host \ + --port 8232 \ + --moe-a2a-backend deepep \ + --deepep-mode low_latency \ + --disable-overlap-schedule +``` + +**Router:** + +```shell +python3 -m sglang_router.launch_router \ + --pd-disaggregation \ + --policy 
cache_aware \ + --prefill http://:8000 8995 \ + --decode http://:8232 \ + --host \ + --port 6689 \ + --prometheus-port 29010 +``` + #### Running Qwen3-VL-8B-Instruct on 1 x Atlas 800I A3. Model weights could be found [here](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct) diff --git a/docs/platforms/ascend_npu_support.rst b/docs/platforms/ascend/ascend_npu_support.rst similarity index 86% rename from docs/platforms/ascend_npu_support.rst rename to docs/platforms/ascend/ascend_npu_support.rst index cd64c58f6cd5..a579cbfb1169 100644 --- a/docs/platforms/ascend_npu_support.rst +++ b/docs/platforms/ascend/ascend_npu_support.rst @@ -7,11 +7,13 @@ Ascend NPUs ascend_npu.md ascend_npu_support_features.md ascend_npu_support_models.md + ascend_npu_quantization.md ascend_npu_deepseek_example.md ascend_npu_qwen3_examples.md mindspore_backend.md ascend_contribution_guide.md ascend_npu_best_practice.md + ascend_npu_ring_sp_performance.md ascend_npu_qwen3_5_examples.md ascend_npu_glm5_examples.md ascend_npu_environment_variables.md diff --git a/docs/platforms/ascend_npu_support_features.md b/docs/platforms/ascend/ascend_npu_support_features.md similarity index 99% rename from docs/platforms/ascend_npu_support_features.md rename to docs/platforms/ascend/ascend_npu_support_features.md index 80bf6ce890d4..54a9cf81384a 100644 --- a/docs/platforms/ascend_npu_support_features.md +++ b/docs/platforms/ascend/ascend_npu_support_features.md @@ -104,7 +104,6 @@ click [Server Arguments](https://docs.sglang.io/advanced_features/server_argumen | `--base-gpu-id` | `0` | Type: int | A2, A3 | | `--gpu-id-step` | `1` | Type: int | A2, A3 | | `--sleep-on-idle` | `False` | bool flag (set to enable) | A2, A3 | -| `--custom-sigquit-handler` | `None` | Optional[Callable] | A2, A3 | ## Logging diff --git a/docs/platforms/ascend_npu_support_models.md b/docs/platforms/ascend/ascend_npu_support_models.md similarity index 100% rename from docs/platforms/ascend_npu_support_models.md rename to 
docs/platforms/ascend/ascend_npu_support_models.md diff --git a/docs/platforms/mindspore_backend.md b/docs/platforms/ascend/mindspore_backend.md similarity index 100% rename from docs/platforms/mindspore_backend.md rename to docs/platforms/ascend/mindspore_backend.md diff --git a/docs/platforms/ascend_npu_quantization.md b/docs/platforms/ascend_npu_quantization.md deleted file mode 100644 index fb4adb54fb35..000000000000 --- a/docs/platforms/ascend_npu_quantization.md +++ /dev/null @@ -1,27 +0,0 @@ -Quantization on Ascend. - -To load already quantized models, simply load the model weights and config. Again, if the model has been quantized offline, there's no need to add `--quantization` argument when starting the engine. The quantization method will be automatically parsed from the downloaded `quant_model_description.json` or `config.json` config. - -[ModelSlim on Ascend support](https://github.com/sgl-project/sglang/pull/14504): -- [x] W4A4 dynamic linear -- [x] W8A8 static linear -- [x] W8A8 dynamic linear -- [x] W4A4 dynamic MOE -- [x] W4A8 dynamic MOE -- [x] W8A8 dynamic MOE - -[AWQ on Ascend support](https://github.com/sgl-project/sglang/pull/10158): -- [x] W4A16 linear -- [x] W8A16 linear # Need to test -- [x] W4A16 MOE # Need to test - -Compressed-tensors (LLM Compressor) on Ascend support: -- [x] [W4A8 dynamic MOE with/without activation clip](https://github.com/sgl-project/sglang/pull/14736) # Need to test -- [x] [W4A16 MOE](https://github.com/sgl-project/sglang/pull/12759) -- [x] [W8A8 dynamic linear](https://github.com/sgl-project/sglang/pull/14504) -- [x] [W8A8 dynamic MOE](https://github.com/sgl-project/sglang/pull/14504) - -Diffusion model [modelslim](https://github.com/sgl-project/sglang/pull/17996) quantization on Ascend support: -- [x] W4A4 dynamic linear -- [x] W8A8 static linear -- [x] W8A8 dynamic linear diff --git a/docs/platforms/ascend_npu_ring_sp_performance.md b/docs/platforms/ascend_npu_ring_sp_performance.md new file mode 100644 index 
000000000000..014328aefa4f --- /dev/null +++ b/docs/platforms/ascend_npu_ring_sp_performance.md @@ -0,0 +1,55 @@ +# Ascend NPU Ring-SP Performance (Wan2.1-T2V-1.3B) + +This page reports Ring-SP performance on Ascend NPU with `torch_npu==2.10.0`. + +- Baseline config: `ulysses=1, ring=1` (short: `u1r1`) +- Ring-SP config: `ulysses=1, ring=2` (short: `u1r2`) + +## Benchmark Setup + +- Model: `Wan2.1-T2V-1.3B-Diffusers` +- Prompt: `"a cat is playing piano"` +- Framework command: `sglang generate` +- Runtime: `torch_npu==2.10.0` + +## Generate Commands + +### Baseline (`u1r1`) + +```bash +sglang generate --model-path /nas/disk1/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "a cat is playing piano" --num-gpus 1 --ring-degree 1 \ + --save-output +``` + +### Ring-SP (`u1r2`) + +```bash +sglang generate --model-path /nas/disk1/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "a cat is playing piano" --num-gpus 2 --ring-degree 2 \ + --save-output +``` + +## Benchmarks + +Benchmark Disclaimer + +These numbers are from one fixed setup and one prompt case. Actual performance may vary by model settings, environment, and workload. + +### Stage Time Breakdown + +| Stage / Metric | `u1r2` (s) | `u1r1` baseline (s) | Speedup | +|---|---:|---:|---:| +| InputValidation | 0.0003 | 0.0002 | 0.67x | +| TextEncoding | 3.5936 | 3.5820 | 1.00x | +| LatentPreparation | 0.0007 | 0.0055 | 7.86x | +| TimestepPreparation | 0.0008 | 0.0007 | 0.88x | +| Denoising | 121.2788 | 239.2580 | 1.97x | +| Decoding | 13.8685 | 16.4969 | 1.19x | +| **Total (Pixel data generated)** | **141.86** | **266.50** | **1.88x** | + +## Summary + +- With `torch_npu==2.10.0`, Ring-SP (`u1r2`) runs successfully on NPU for this case. +- End-to-end generation time improves from `266.50s` to `141.86s` (`1.88x`). +- The main gain comes from `DenoisingStage` (`1.97x`), while decoding also improves (`1.19x`). 
diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 9d95c4f237aa..8a28436d802e 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -45,6 +45,8 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_NCCL_ALL_GATHER_IN_OVERLAP_SCHEDULER_SYNC_BATCH` | Enable NCCL for gathering when preparing mlp sync batch under overlap scheduler (without this flag gloo is used for gathering) | `false` | | `SGLANG_SYMM_MEM_PREALLOC_GB_SIZE` | Size of preallocated GPU buffer (in GB) for NCCL symmetric memory pool to limit memory fragmentation. Only have an effect when server arg `--enable-symm-mem` is set. | `-1` | | `SGLANG_CUSTOM_ALLREDUCE_ALGO` | The algorithm of custom all-reduce. Set to `oneshot` or `1stage` to force use one-shot. Set to `twoshot` or `2stage` to force use two-shot. | `` | +| `SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM prefill attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` | +| `SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM decode attention in flashinfer. `None` means standard attention. 
See https://arxiv.org/abs/2512.12087 | `None` | ## DeepGEMM Configuration (Advanced Optimization) @@ -119,12 +121,11 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` | | `SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2` | Apply per token group quantization kernel with fused silu and mul and masked m | `false` | | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` | +| `SGLANG_FORCE_NVFP4_MARLIN` | Force using NVFP4 Marlin fallback kernels even on Blackwell GPUs with native FP4 support | `false` | | `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` (deprecated) | Select backend for `mm_fp4` on Blackwell GPUs. **DEPRECATED**: Please use `--fp4-gemm-backend` instead. | `` | | `SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN` | Quantize q_b_proj from BF16 to FP8 when launching DeepSeek NVFP4 checkpoint | `false` | | `SGLANG_MOE_NVFP4_DISPATCH` | Use nvfp4 for moe dispatch (on flashinfer_cutlass or flashinfer_cutedsl moe runner backend) | `"false"` | | `SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE` | Quantize moe of nextn layer from BF16 to FP8 when launching DeepSeek NVFP4 checkpoint | `false` | -| `SGLANG_ENABLE_FLASHINFER_FP8_GEMM` (deprecated) | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs. **DEPRECATED**: Please use `--fp8-gemm-backend=flashinfer_trtllm` (SM100/SM103) or `--fp8-gemm-backend=flashinfer_cutlass` (SM120/SM121 and newer) instead. | `false` | -| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` (deprecated) | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs. **DEPRECATED**: Please use `--fp8-gemm-backend=cutlass` instead. | `false` | | `SGLANG_QUANT_ALLOW_DOWNCASTING` | Allow weight dtype downcasting during loading (e.g., fp32 → fp16). By default, SGLang rejects this kind of downcasting when using quantization. 
| `false` | | `SGLANG_FP8_IGNORED_LAYERS` | A comma-separated list of layer names to ignore during FP8 quantization. For example: `model.layers.0,model.layers.1.,qkv_proj`. | `""` | diff --git a/docs/references/production_metrics.md b/docs/references/production_metrics.md index 85a6ff8a64a6..d104584ee4bc 100644 --- a/docs/references/production_metrics.md +++ b/docs/references/production_metrics.md @@ -142,7 +142,8 @@ This section describes how to set up the monitoring stack (Prometheus + Grafana) python -m sglang.launch_server \ --model-path \ --port 30000 \ - --enable-metrics + --enable-metrics \ + --enable-mfu-metrics ``` Replace `` with the actual path to your model (e.g., `meta-llama/Meta-Llama-3.1-8B-Instruct`). Ensure the server is accessible from the monitoring stack (you might need `--host 0.0.0.0` if running in Docker). By default, the metrics endpoint will be available at `http://:30000/metrics`. @@ -229,3 +230,38 @@ python3 -m sglang.bench_serving \ to generate some requests. Then you should be able to see the metrics in the Grafana dashboard. + +## Estimated Performance Metrics (MFU-related) + +SGLang exports the following estimated per-GPU counters that can be used to derive +Model FLOPs Utilization (MFU)-related signals: + +- `sglang:estimated_flops_per_gpu_total`: Estimated floating-point operations. +- `sglang:estimated_read_bytes_per_gpu_total`: Estimated bytes read from memory. +- `sglang:estimated_write_bytes_per_gpu_total`: Estimated bytes written to memory. + +These metrics are available when both `--enable-metrics` and +`--enable-mfu-metrics` are enabled. + +These are cumulative counters. Use Prometheus `rate(...)` to get per-second values. 
+ +### PromQL examples + +Average TFLOPS per GPU: + +```promql +rate(sglang:estimated_flops_per_gpu_total[1m]) / 1e12 +``` + +Average estimated memory bandwidth in GB/s: + +```promql +(rate(sglang:estimated_read_bytes_per_gpu_total[1m]) + + rate(sglang:estimated_write_bytes_per_gpu_total[1m])) / 1e9 +``` + +### Notes + +- These metrics are estimates intended for observability and trend analysis. +- Estimated memory bytes reflect modeled traffic and are not a direct hardware + counter from GPU profilers. diff --git a/docs/supported_models/extending/mindspore_models.md b/docs/supported_models/extending/mindspore_models.md index 1e5583293ca8..caa5ade9c166 100644 --- a/docs/supported_models/extending/mindspore_models.md +++ b/docs/supported_models/extending/mindspore_models.md @@ -19,7 +19,7 @@ Currently, the following models are supported: ## Installation -> **Note**: Currently, MindSpore models are provided by an independent package `sgl-mindspore`. Support for MindSpore is built upon current SGLang support for Ascend NPU platform. Please first [install SGLang for Ascend NPU](../../platforms/ascend_npu.md) and then install `sgl-mindspore`: +> **Note**: Currently, MindSpore models are provided by an independent package `sgl-mindspore`. Support for MindSpore is built upon current SGLang support for Ascend NPU platform. 
Please first [install SGLang for Ascend NPU](../../platforms/ascend/ascend_npu.md) and then install `sgl-mindspore`: ```shell git clone https://github.com/mindspore-lab/sgl-mindspore.git diff --git a/python/pyproject.toml b/python/pyproject.toml index 483cf432bccd..7510b439185a 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,8 +27,8 @@ dependencies = [ "datasets", "einops", "fastapi", - "flashinfer_python==0.6.6", # keep it aligned with jit-cache version in Dockerfile - "flashinfer_cubin==0.6.6", + "flashinfer_python==0.6.7", # keep it aligned with jit-cache version in Dockerfile + "flashinfer_cubin==0.6.7", "gguf", "interegular", "llguidance>=0.7.11,<0.8.0", @@ -97,6 +97,7 @@ torch = [ [project.optional-dependencies] checkpoint-engine = ["checkpoint-engine==0.1.2"] +runai = ["runai-model-streamer[s3,gcs,azure]>=0.15.7"] diffusion = [ "PyYAML==6.0.1", "cloudpickle==3.1.2", @@ -108,7 +109,7 @@ diffusion = [ "remote-pdb==2.1.0", "st_attn==0.0.7 ; platform_machine != 'aarch64' and platform_machine != 'arm64'", "vsa==0.0.4 ; platform_machine != 'aarch64' and platform_machine != 'arm64'", - "runai_model_streamer>=0.15.5", + "runai_model_streamer>=0.15.7", "cache-dit==1.3.0", "addict==2.4.0", "av==16.1.0", @@ -139,6 +140,7 @@ test = [ "pandas", "parameterized", "peft>=0.18.0", + "polars", "pytest", "pytest-cov", "diff-cover", diff --git a/python/pyproject_npu.toml b/python/pyproject_npu.toml index ddb0844b23c2..c7f989467184 100644 --- a/python/pyproject_npu.toml +++ b/python/pyproject_npu.toml @@ -26,6 +26,8 @@ dependencies = [ "einops", "fastapi", "gguf", + "hf_transfer", + "huggingface_hub", "interegular", "llguidance>=0.7.11,<0.8.0", "modelscope", diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml index 395c7450e3e8..ba19cbc13fa3 100755 --- a/python/pyproject_other.toml +++ b/python/pyproject_other.toml @@ -98,6 +98,7 @@ srt_hip = [ diffusion_hip = [ "sglang[diffusion_common]", + "peft>=0.18.0", "st_attn==0.0.7", "vsa==0.0.4", 
"runai_model_streamer>=0.15.5", @@ -113,7 +114,7 @@ srt_musa = [ "sglang[runtime_common]", "torch", "torch_musa", - "torchada>=0.1.25", + "torchada>=0.1.45", "mthreads-ml-py", "numpy<2.0", ] diff --git a/python/sglang/_triton_stub.py b/python/sglang/_triton_stub.py index d78cdfb01f03..b2e252bf1860 100644 --- a/python/sglang/_triton_stub.py +++ b/python/sglang/_triton_stub.py @@ -125,25 +125,23 @@ class _TritonFinder: ``triton.*`` sub-module that isn't already in ``sys.modules``. """ - def find_module(self, fullname, path=None): + def find_spec(self, fullname, path=None, target=None): + """PEP 451 meta-path finder for ``triton.*`` sub-modules.""" if fullname == "triton" or fullname.startswith("triton."): - return self + if fullname in sys.modules: + return getattr(sys.modules[fullname], "__spec__", None) + # Create and register the mock so the import machinery finds it + mod = _MockModule(fullname) + sys.modules[fullname] = mod + parts = fullname.rsplit(".", 1) + if len(parts) == 2: + parent_name, child_name = parts + parent = sys.modules.get(parent_name) + if parent is not None: + setattr(parent, child_name, mod) + return mod.__spec__ return None - def load_module(self, fullname): - if fullname in sys.modules: - return sys.modules[fullname] - mod = _MockModule(fullname) - sys.modules[fullname] = mod - # Wire up the parent relationship - parts = fullname.rsplit(".", 1) - if len(parts) == 2: - parent_name, child_name = parts - parent = sys.modules.get(parent_name) - if parent is not None: - setattr(parent, child_name, mod) - return mod - def _make_mock(name: str) -> _MockModule: """Create a ``_MockModule`` and register it in ``sys.modules``.""" diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 31244c8851c4..b462bb7d2c26 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -66,6 +66,7 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed.parallel_state import 
destroy_distributed_environment from sglang.srt.entrypoints.engine import _set_envs_and_config +from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config @@ -453,7 +454,8 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner): prepare_mlp_sync_batch_raw( batch, dp_size=model_runner.server_args.dp_size, - attn_tp_size=1, + attn_tp_size=get_attention_tp_size(), + attn_cp_size=model_runner.attn_cp_size, tp_group=model_runner.tp_group, get_idle_batch=None, disable_cuda_graph=model_runner.server_args.disable_cuda_graph, diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 4e79ece7e3bc..161b4c3be1d6 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -44,6 +44,7 @@ remove_prefix, set_ulimit, ) +from sglang.srt.utils.network import NetworkAddress _ROUTING_KEY_HEADER = "X-SMG-Routing-Key" @@ -1401,7 +1402,7 @@ async def limited_request_func(request_func_input, pbar): if "sglang" in backend: server_info = requests.get( - base_url + "/get_server_info", headers=get_auth_headers() + base_url + "/server_info", headers=get_auth_headers() ) if server_info.status_code == 200: server_info_json = server_info.json() @@ -1537,7 +1538,7 @@ async def limited_request_func(request_func_input, pbar): print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms)) print("=" * 50) - resp = requests.get(base_url + "/get_server_info", headers=get_auth_headers()) + resp = requests.get(base_url + "/server_info", headers=get_auth_headers()) server_info = resp.json() if resp.status_code == 200 else None if ( @@ -1726,10 +1727,16 @@ def run_benchmark(args_: argparse.Namespace): "truss": 8080, }.get(args.backend, 30000) + # Build base URL with proper IPv6 bracket wrapping (only when base_url is 
not provided) + if not args.base_url: + _na = NetworkAddress(args.host, args.port) + _host_base = _na.to_url() + else: + _na = None + _host_base = None + model_url = ( - f"{args.base_url}/v1/models" - if args.base_url - else f"http://{args.host}:{args.port}/v1/models" + f"{args.base_url}/v1/models" if args.base_url else f"{_host_base}/v1/models" ) if args.backend == "sglang-embedding": @@ -1740,43 +1747,39 @@ def run_benchmark(args_: argparse.Namespace): ) elif args.backend in ["sglang", "sglang-native"]: api_url = ( - f"{args.base_url}/generate" - if args.base_url - else f"http://{args.host}:{args.port}/generate" + f"{args.base_url}/generate" if args.base_url else f"{_host_base}/generate" ) elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]: api_url = ( f"{args.base_url}/v1/completions" if args.base_url - else f"http://{args.host}:{args.port}/v1/completions" + else f"{_host_base}/v1/completions" ) elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]: api_url = ( f"{args.base_url}/v1/chat/completions" if args.base_url - else f"http://{args.host}:{args.port}/v1/chat/completions" + else f"{_host_base}/v1/chat/completions" ) elif args.backend == "trt": api_url = ( f"{args.base_url}/v2/models/ensemble/generate_stream" if args.base_url - else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + else f"{_host_base}/v2/models/ensemble/generate_stream" ) if args.model is None: print("Please provide a model using `--model` when using `trt` backend.") sys.exit(1) elif args.backend == "gserver": - api_url = args.base_url if args.base_url else f"{args.host}:{args.port}" + api_url = args.base_url if args.base_url else _na.to_host_port_str() args.model = args.model or "default" elif args.backend == "truss": api_url = ( f"{args.base_url}/v1/models/model:predict" if args.base_url - else f"http://{args.host}:{args.port}/v1/models/model:predict" + else f"{_host_base}/v1/models/model:predict" ) - base_url = ( - 
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url - ) + base_url = _host_base if args.base_url is None else args.base_url # Wait for server to be ready if args.ready_check_timeout_sec > 0: @@ -1937,6 +1940,7 @@ def __call__(self, parser, namespace, values, option_string=None): "mmmu", "image", "mooncake", + "longbench_v2", ], help="Name of the dataset to benchmark on.", ) diff --git a/python/sglang/benchmark/datasets/__init__.py b/python/sglang/benchmark/datasets/__init__.py index 63612d52e414..615e3a2419e4 100644 --- a/python/sglang/benchmark/datasets/__init__.py +++ b/python/sglang/benchmark/datasets/__init__.py @@ -6,6 +6,7 @@ GeneratedSharedPrefixDataset, ) from sglang.benchmark.datasets.image import ImageDataset +from sglang.benchmark.datasets.longbench_v2 import LongBenchV2Dataset from sglang.benchmark.datasets.mmmu import MMMUDataset from sglang.benchmark.datasets.mooncake import MooncakeDataset from sglang.benchmark.datasets.openai_dataset import OpenAIDataset @@ -24,6 +25,7 @@ "mmmu": MMMUDataset, "image": ImageDataset, "mooncake": MooncakeDataset, + "longbench_v2": LongBenchV2Dataset, } diff --git a/python/sglang/benchmark/datasets/image.py b/python/sglang/benchmark/datasets/image.py index a32576b376a5..160c319901ff 100644 --- a/python/sglang/benchmark/datasets/image.py +++ b/python/sglang/benchmark/datasets/image.py @@ -148,15 +148,18 @@ def create_mm_data_row( # Vision tokens = total tokens - text tokens vision_prompt_len = prompt_len - text_prompt_len - use_raw_prompt = backend in [ - "sglang", - "sglang-oai", - "sglang-oai-chat", - "vllm", - "vllm-chat", - "lmdeploy", - "lmdeploy-chat", - ] + supported_backends = ["sglang", "sglang-native", "sglang-oai-chat"] + if backend not in supported_backends: + raise ValueError( + f"Image dataset only supports backends: {supported_backends}, " + f"got '{backend}'." + ) + + # sglang-oai-chat: server's chat handler applies chat template, so send raw text. 
+ # sglang/sglang-native: /generate does not apply chat template, so send prompt_str + # which contains image placeholder tokens needed by the multimodal processor. + use_raw_prompt = backend == "sglang-oai-chat" + return DatasetRow( prompt=text_prompt if use_raw_prompt else prompt_str, prompt_len=prompt_len, diff --git a/python/sglang/benchmark/datasets/longbench_v2.py b/python/sglang/benchmark/datasets/longbench_v2.py new file mode 100644 index 000000000000..e8a64295798f --- /dev/null +++ b/python/sglang/benchmark/datasets/longbench_v2.py @@ -0,0 +1,104 @@ +import random +from argparse import Namespace +from dataclasses import dataclass +from typing import List, Optional + +from transformers import PreTrainedTokenizerBase + +from sglang.benchmark.datasets.common import BaseDataset, DatasetRow + +LONGBENCH_V2_REPO_ID = "THUDM/LongBench-v2" +LONGBENCH_V2_DEFAULT_OUTPUT_LEN = 10 # answer letter + short explanation + + +def _format_prompt(example: dict) -> str: + return ( + f"{example['context']}\n\n" + f"Question: {example['question']}\n" + f"A. {example['choice_A']}\n" + f"B. {example['choice_B']}\n" + f"C. {example['choice_C']}\n" + f"D. 
{example['choice_D']}\n" + f"Answer:" + ) + + +@dataclass +class LongBenchV2Dataset(BaseDataset): + dataset_path: str + num_requests: int + fixed_output_len: Optional[int] + context_len: Optional[int] + + @classmethod + def from_args(cls, args: Namespace) -> "LongBenchV2Dataset": + return cls( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + fixed_output_len=args.sharegpt_output_len, + context_len=args.sharegpt_context_len, + ) + + def load( + self, tokenizer: PreTrainedTokenizerBase, model_id=None + ) -> List[DatasetRow]: + return sample_longbench_v2_requests( + dataset_path=self.dataset_path, + num_requests=self.num_requests, + tokenizer=tokenizer, + fixed_output_len=self.fixed_output_len, + context_len=self.context_len, + ) + + +def sample_longbench_v2_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, + context_len: Optional[int] = None, +) -> List[DatasetRow]: + output_len = ( + fixed_output_len + if fixed_output_len is not None + else LONGBENCH_V2_DEFAULT_OUTPUT_LEN + ) + + # Load dataset + if dataset_path: + # Local file (parquet or JSON lines) + import pandas as pd + + if dataset_path.endswith(".parquet"): + df = pd.read_parquet(dataset_path) + examples = df.to_dict(orient="records") + else: + import json + + with open(dataset_path) as f: + examples = [json.loads(line) for line in f if line.strip()] + else: + from datasets import load_dataset + + ds = load_dataset(LONGBENCH_V2_REPO_ID, split="train") + examples = list(ds) + + random.shuffle(examples) + + rows: List[DatasetRow] = [] + for example in examples: + if len(rows) >= num_requests: + break + + prompt = _format_prompt(example) + prompt_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_ids) + + if context_len is not None and prompt_len + output_len > context_len: + continue + + rows.append( + DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len) + ) + + return rows diff 
--git a/python/sglang/check_env.py b/python/sglang/check_env.py index 516e90237e65..a9a8a62b207a 100644 --- a/python/sglang/check_env.py +++ b/python/sglang/check_env.py @@ -194,22 +194,12 @@ def _get_cuda_driver_version(self): """ Get CUDA driver version. """ - versions = set() - try: - output = subprocess.check_output( - [ - "nvidia-smi", - "--query-gpu=driver_version", - "--format=csv,noheader,nounits", - ] - ) - versions = set(output.decode().strip().split("\n")) - if len(versions) == 1: - return {"CUDA Driver Version": versions.pop()} - else: - return {"CUDA Driver Versions": ", ".join(sorted(versions))} - except subprocess.SubprocessError: + from sglang.srt.utils.common import get_nvidia_driver_version_str + + ver = get_nvidia_driver_version_str() + if ver is None: return {"CUDA Driver Version": "Not Available"} + return {"CUDA Driver Version": ver} def get_topology(self): """ diff --git a/python/sglang/cli/killall.py b/python/sglang/cli/killall.py index 39a10f053aed..629284e5bffc 100755 --- a/python/sglang/cli/killall.py +++ b/python/sglang/cli/killall.py @@ -77,19 +77,25 @@ def _run_smi(query, query_type="gpu"): def _get_smi_version(): - """Return nvidia-smi driver version and CUDA version, or None on failure.""" + """Return nvidia-smi driver version and GPU name, or None on failure.""" + # Inline nvidia-smi query — killall.py runs before pip install, so sglang + # internals may not be importable. 
try: - out = subprocess.check_output( + result = subprocess.run( [ "nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader,nounits", ], + capture_output=True, text=True, + check=True, timeout=10, ) - driver = out.strip().splitlines()[0].strip() if out.strip() else "unknown" - except (subprocess.SubprocessError, FileNotFoundError, IndexError): + driver = result.stdout.strip().split("\n")[0].strip() or None + except (subprocess.SubprocessError, FileNotFoundError): + driver = None + if driver is None: return None try: out = subprocess.check_output( diff --git a/python/sglang/cli/utils.py b/python/sglang/cli/utils.py index 1d867eccf08b..60fb10d9220d 100644 --- a/python/sglang/cli/utils.py +++ b/python/sglang/cli/utils.py @@ -5,10 +5,24 @@ from functools import lru_cache from sglang.srt.environ import envs +from sglang.utils import ( + has_diffusion_overlay_registry_match, + is_known_non_diffusers_diffusion_model, + load_diffusion_overlay_registry_from_env, +) logger = logging.getLogger(__name__) +@lru_cache(maxsize=1) +def _load_overlay_registry() -> dict: + return load_diffusion_overlay_registry_from_env() + + +def _is_overlay_diffusion_model(model_path: str) -> bool: + return has_diffusion_overlay_registry_match(model_path, _load_overlay_registry()) + + def _is_diffusers_model_dir(model_dir: str) -> bool: """Check if a local directory contains a valid diffusers model_index.json.""" config_path = os.path.join(model_dir, "model_index.json") @@ -29,19 +43,16 @@ def get_is_diffusion_model(model_path: str) -> bool: Returns False on any failure (network error, 404, offline mode, etc.) so that the caller falls through to the standard LLM server path. 
""" - try: - from sglang.multimodal_gen.registry import ( - is_known_non_diffusers_multimodal_model, - ) - except ImportError: - is_known_non_diffusers_multimodal_model = lambda _: False + if _is_overlay_diffusion_model(model_path): + # short-circuit, if applicable for the overlay mechanism (diffusion-only) + return True if os.path.isdir(model_path): if _is_diffusers_model_dir(model_path): return True - return is_known_non_diffusers_multimodal_model(model_path) + return is_known_non_diffusers_diffusion_model(model_path) - if is_known_non_diffusers_multimodal_model(model_path): + if is_known_non_diffusers_diffusion_model(model_path): return True try: diff --git a/python/sglang/jit_kernel/all_reduce.py b/python/sglang/jit_kernel/all_reduce.py index dd02100822d3..48f763259beb 100644 --- a/python/sglang/jit_kernel/all_reduce.py +++ b/python/sglang/jit_kernel/all_reduce.py @@ -95,7 +95,7 @@ def config_pull( def _jit_custom_all_reduce_pull_module(dtype: torch.dtype, world_size: int): args = make_cpp_args(dtype, world_size, is_arch_support_pdl()) return load_jit( - "custom_all_reduce", + "custom_all_reduce_pull", *args, extra_ldflags=["-lcuda"], cuda_files=["distributed/custom_all_reduce_pull.cuh"], @@ -107,7 +107,7 @@ def _jit_custom_all_reduce_pull_module(dtype: torch.dtype, world_size: int): def _jit_custom_all_reduce_push_module(dtype: torch.dtype, world_size: int): args = make_cpp_args(dtype, world_size, is_arch_support_pdl()) return load_jit( - "custom_all_reduce", + "custom_all_reduce_push", *args, extra_ldflags=["-lcuda"], cuda_files=["distributed/custom_all_reduce_push.cuh"], diff --git a/python/sglang/jit_kernel/benchmark/bench_cast.py b/python/sglang/jit_kernel/benchmark/bench_cast.py index 18dbbf726f99..97c71bcb01d2 100644 --- a/python/sglang/jit_kernel/benchmark/bench_cast.py +++ b/python/sglang/jit_kernel/benchmark/bench_cast.py @@ -1,7 +1,6 @@ import torch import triton import triton.testing -from sgl_kernel import downcast_fp8 as downcast_fp8_aot from 
sglang.jit_kernel.benchmark.utils import ( DEFAULT_DEVICE, @@ -9,6 +8,9 @@ run_benchmark, ) from sglang.jit_kernel.cast import downcast_fp8 as downcast_fp8_jit +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=10, suite="stage-b-kernel-benchmark-1-gpu-large") DEVICE = DEFAULT_DEVICE DTYPE = torch.bfloat16 @@ -28,9 +30,9 @@ CONFIGS = [(sl, h, d, sl * 2) for sl in SL_LIST for h, d in HEAD_DIM_LIST] -LINE_VALS = ["aot", "jit"] -LINE_NAMES = ["AOT (sgl-kernel)", "JIT (cast.cuh, 256 threads, 2D grid)"] -STYLES = [("blue", "--"), ("orange", "-")] +LINE_VALS = ["jit"] +LINE_NAMES = ["JIT (cast.cuh, 256 threads, 2D grid)"] +STYLES = [("orange", "-")] # ── Perf report ──────────────────────────────────────────────────────────────── @@ -45,7 +47,7 @@ line_names=LINE_NAMES, styles=STYLES, ylabel="us", - plot_name="downcast-fp8-aot-vs-jit", + plot_name="downcast-fp8-jit", args={}, ) ) @@ -58,10 +60,7 @@ def benchmark(input_sl, head, dim, out_sl, provider): v_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE) loc = torch.arange(input_sl, dtype=torch.int64, device=DEVICE) - if provider == "aot": - fn = lambda: downcast_fp8_aot(k, v, k_out, v_out, k_scale, v_scale, loc) - else: - fn = lambda: downcast_fp8_jit(k, v, k_out, v_out, k_scale, v_scale, loc) + fn = lambda: downcast_fp8_jit(k, v, k_out, v_out, k_scale, v_scale, loc) return run_benchmark(fn) @@ -81,26 +80,19 @@ def _report_bandwidth(input_sl, head, dim, dtype): v_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE) loc = torch.arange(input_sl, dtype=torch.int64, device=DEVICE) - aot_fn = lambda: downcast_fp8_aot(k, v, k_out, v_out, k_scale, v_scale, loc) jit_fn = lambda: downcast_fp8_jit(k, v, k_out, v_out, k_scale, v_scale, loc) - aot_ms, _, _ = triton.testing.do_bench(aot_fn, quantiles=[0.5, 0.2, 0.8]) jit_ms, _, _ = triton.testing.do_bench(jit_fn, quantiles=[0.5, 0.2, 0.8]) def fmt(ms): return f"{ms*1000:6.2f}us {total_bytes/(ms*1e-3)/1e9:6.0f}GB/s" - print( - 
f" sl={input_sl:5d} h={head:2d} d={dim:4d}" - f" | aot {fmt(aot_ms)}" - f" | jit {fmt(jit_ms)}" - f" | speedup {aot_ms/jit_ms:.2f}x" - ) + print(f" sl={input_sl:5d} h={head:2d} d={dim:4d}" f" | jit {fmt(jit_ms)}") def report_bandwidth(): print(f"\n{'='*95}") - print(" AOT (sgl-kernel) vs JIT (cast.cuh, 256 threads, 2D grid)") + print(" JIT (cast.cuh, 256 threads, 2D grid)") print(f" dtype={DTYPE}, device={DEVICE}") print(f"{'='*95}") for sl in [64, 256, 1024, 2048]: diff --git a/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py b/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py deleted file mode 100644 index a842be84b72b..000000000000 --- a/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py +++ /dev/null @@ -1,75 +0,0 @@ -import itertools - -import torch -import triton -import triton.testing -from flashinfer import fused_add_rmsnorm as fi_fused_add_rmsnorm - -from sglang.jit_kernel.benchmark.utils import run_benchmark -from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.utils import is_in_ci - -register_cuda_ci(est_time=6, suite="stage-b-kernel-benchmark-1-gpu-large") - -IS_CI = is_in_ci() - - -def sglang_jit_fused_add_rmsnorm( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float -) -> None: - jit_fused_add_rmsnorm(input, residual, weight, eps) - - -def flashinfer_fused_add_rmsnorm( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float -) -> None: - fi_fused_add_rmsnorm(input, residual, weight, eps=eps) - - -DTYPE = torch.bfloat16 -DEVICE = "cuda" - -if IS_CI: - BS_LIST = [16] - HIDDEN_SIZE_LIST = [512, 2048] -else: - BS_LIST = [2**n for n in range(0, 14)] - HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192] - -LINE_VALS = ["jit", "flashinfer"] -LINE_NAMES = ["SGL JIT Kernel", "FlashInfer"] -STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")] - -configs = 
list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["hidden_size", "batch_size"], - x_vals=configs, - line_arg="provider", - line_vals=LINE_VALS, - line_names=LINE_NAMES, - styles=STYLES, - ylabel="us", - plot_name="fused-add-rmsnorm-performance", - args={}, - ) -) -def benchmark(hidden_size: int, batch_size: int, provider: str): - input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - residual = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE) - FN_MAP = { - "jit": sglang_jit_fused_add_rmsnorm, - "flashinfer": flashinfer_fused_add_rmsnorm, - } - fn = lambda: FN_MAP[provider]( - input.clone(), residual.clone(), weight, torch.finfo(torch.bfloat16).eps - ) - return run_benchmark(fn) - - -if __name__ == "__main__": - benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py b/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py index e6ef6bc77a36..905d41d5d6b0 100644 --- a/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py +++ b/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py @@ -1,8 +1,8 @@ """ Benchmark: fused_qknorm_rope JIT vs AOT (sgl_kernel) -Measures throughput (us) for fused_qk_norm_rope across typical -LLM configurations (head_dim x num_heads x num_tokens). +Measures throughput (µs) for fused_qk_norm_rope across typical +LLM configurations (head_dim × num_heads × num_tokens). 
Run: python python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py @@ -18,6 +18,9 @@ from sglang.jit_kernel.fused_qknorm_rope import ( fused_qk_norm_rope as fused_qk_norm_rope_jit, ) +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=6, suite="stage-b-kernel-benchmark-1-gpu-large") try: from sgl_kernel import fused_qk_norm_rope as fused_qk_norm_rope_aot @@ -36,7 +39,7 @@ ci_range=[64, 512], ) -# (head_dim, num_heads_q, num_heads_k, num_heads_v) - typical MoE/dense configs +# (head_dim, num_heads_q, num_heads_k, num_heads_v) — typical MoE/dense configs MODEL_CONFIGS = get_benchmark_range( full_range=[ (64, 32, 8, 8), # small @@ -46,6 +49,16 @@ ci_range=[(128, 32, 8, 8)], ) +# Real production shapes (self-attention; num_heads_k == num_heads_v == num_heads_q). +# Format: (name, num_tokens, num_heads_q, num_heads_k, num_heads_v, head_dim, rotary_dim) +PRODUCTION_SHAPES = [ + ("flux_1024", 4096, 24, 24, 24, 128, 128), + ("qwen_image_1024", 4096, 32, 32, 32, 128, 128), + ("qwen_image_partial", 4096, 32, 32, 32, 128, 64), + ("zimage_1024", 4096, 30, 30, 30, 128, 128), + ("batch2_medium", 4096, 24, 24, 24, 128, 128), # B=2, T=2048 +] + LINE_VALS = ["jit", "aot"] if AOT_AVAILABLE else ["jit"] LINE_NAMES = ["JIT (new)", "AOT sgl_kernel"] if AOT_AVAILABLE else ["JIT (new)"] STYLES = [("blue", "--"), ("orange", "-")] if AOT_AVAILABLE else [("blue", "--")] @@ -120,6 +133,75 @@ def bench_fused_qknorm_rope( return run_benchmark(fn) +# --------------------------------------------------------------------------- +# Benchmark: fused_qk_norm_rope — real production shapes (with speedup column) +# --------------------------------------------------------------------------- + + +def bench_fused_qknorm_rope_production(): + device = "cuda" + header = f"{'name':<22} {'tokens':>6} {'nq':>4} {'nk':>4} {'nv':>4} {'hd':>4} {'rdim':>5} {'JIT(us)':>9} {'AOT(us)':>9} {'speedup':>8}" + sep = "-" * len(header) + print("\nfused-qknorm-rope-production-shapes:") 
+ print(sep) + print(header) + print(sep) + + for ( + name, + num_tokens, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + rotary_dim, + ) in PRODUCTION_SHAPES: + total_heads = num_heads_q + num_heads_k + num_heads_v + qkv = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device) + + common_kwargs = dict( + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_dim=head_dim, + eps=1e-5, + q_weight=q_weight, + k_weight=k_weight, + base=10000.0, + is_neox=False, + position_ids=position_ids, + factor=1.0, + low=1.0, + high=32.0, + attention_factor=1.0, + rotary_dim=rotary_dim, + ) + + jit_us, _, _ = run_benchmark( + lambda: fused_qk_norm_rope_jit(qkv.clone(), **common_kwargs) + ) + if AOT_AVAILABLE: + aot_us, _, _ = run_benchmark( + lambda: fused_qk_norm_rope_aot(qkv.clone(), **common_kwargs) + ) + speedup = f"{aot_us / jit_us:.2f}x" + aot_str = f"{aot_us:9.3f}" + else: + aot_str = f"{'N/A':>9}" + speedup = "N/A" + + print( + f"{name:<22} {num_tokens:>6} {num_heads_q:>4} {num_heads_k:>4} {num_heads_v:>4}" + f" {head_dim:>4} {rotary_dim:>5} {jit_us:9.3f} {aot_str} {speedup:>8}" + ) + print(sep) + + # --------------------------------------------------------------------------- # Quick correctness diff # --------------------------------------------------------------------------- @@ -127,7 +209,7 @@ def bench_fused_qknorm_rope( def calculate_diff(): if not AOT_AVAILABLE: - print("sgl_kernel not available - skipping AOT diff check") + print("sgl_kernel not available — skipping AOT diff check") return device = "cuda" @@ -181,3 +263,5 @@ def calculate_diff(): calculate_diff() print() bench_fused_qknorm_rope.run(print_data=True) + print() + bench_fused_qknorm_rope_production() diff 
--git a/python/sglang/jit_kernel/benchmark/bench_norm.py b/python/sglang/jit_kernel/benchmark/bench_norm.py index d046ecf2d2a8..345388ef7fc3 100644 --- a/python/sglang/jit_kernel/benchmark/bench_norm.py +++ b/python/sglang/jit_kernel/benchmark/bench_norm.py @@ -6,40 +6,39 @@ from flashinfer.norm import fused_add_rmsnorm as fi_fused_add_rmsnorm from flashinfer.norm import rmsnorm as fi_rmsnorm -from sglang.jit_kernel.benchmark.utils import run_benchmark +from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm from sglang.test.ci.ci_register import register_cuda_ci -from sglang.utils import is_in_ci -register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large") +register_cuda_ci(est_time=30, suite="stage-b-kernel-benchmark-1-gpu-large") -IS_CI = is_in_ci() DTYPE = torch.bfloat16 DEVICE = "cuda" -# JIT rmsnorm: hidden_size in {64,128,256} or (multiple of 256, <=8192) -# JIT fused_add_rmsnorm: hidden_size % 8 == 0, <=8192 -# Use multiples of 256 <=8192 to satisfy both kernels -if IS_CI: - BS_LIST = [16] - HIDDEN_SIZE_LIST = [512, 2048] -else: - BS_LIST = [2**n for n in range(0, 14)] - HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192] - -LINE_VALS = ["jit", "flashinfer"] -LINE_NAMES = ["SGL JIT Kernel", "FlashInfer"] +BS_LIST = get_benchmark_range( + full_range=[2**n for n in range(0, 14)], + ci_range=[16, 32], +) +HIDDEN_SIZE_LIST = get_benchmark_range( + full_range=sorted([1536, *range(1024, 8192 + 1, 1024)]), + ci_range=[512, 2048], +) + +LINE_VALS = ["flashinfer", "jit"] +LINE_NAMES = ["FlashInfer", "SGL JIT Kernel"] STYLES = [("blue", "--"), ("green", "-.")] +NUM_LAYERS = 4 # avoid L2 effect -configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) +configs_0 = list(itertools.product(HIDDEN_SIZE_LIST + [16384], BS_LIST)) +configs_1 = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) 
@triton.testing.perf_report( triton.testing.Benchmark( x_names=["hidden_size", "batch_size"], - x_vals=configs, + x_vals=configs_0, line_arg="provider", line_vals=LINE_VALS, line_names=LINE_NAMES, @@ -50,20 +49,24 @@ ) ) def benchmark_rmsnorm(hidden_size: int, batch_size: int, provider: str): - input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE) - FN_MAP = { - "jit": lambda: jit_rmsnorm(input.clone(), weight), - "flashinfer": lambda: fi_rmsnorm(input.clone(), weight, out=input.clone()), - } - fn = FN_MAP[provider] - return run_benchmark(fn) + input = torch.randn( + (NUM_LAYERS, batch_size, hidden_size), dtype=DTYPE, device=DEVICE + ) + weight = torch.randn((NUM_LAYERS, hidden_size), dtype=DTYPE, device=DEVICE) + FN_MAP = {"jit": jit_rmsnorm, "flashinfer": fi_rmsnorm} + + def f(): + fn = FN_MAP[provider] + for i in range(NUM_LAYERS): + fn(input[i], weight[i], out=input[i]) + + return run_benchmark(f, scale=NUM_LAYERS) @triton.testing.perf_report( triton.testing.Benchmark( x_names=["hidden_size", "batch_size"], - x_vals=configs, + x_vals=configs_1, line_arg="provider", line_vals=LINE_VALS, line_names=LINE_NAMES, @@ -74,19 +77,19 @@ def benchmark_rmsnorm(hidden_size: int, batch_size: int, provider: str): ) ) def benchmark_fused_add_rmsnorm(hidden_size: int, batch_size: int, provider: str): - input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - residual = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE) - FN_MAP = { - "jit": lambda: jit_fused_add_rmsnorm( - input.clone(), residual.clone(), weight, torch.finfo(DTYPE).eps - ), - "flashinfer": lambda: fi_fused_add_rmsnorm( - input.clone(), residual.clone(), weight, eps=torch.finfo(DTYPE).eps - ), - } - fn = FN_MAP[provider] - return run_benchmark(fn) + input = torch.randn( + (NUM_LAYERS, batch_size, hidden_size), dtype=DTYPE, 
device=DEVICE + ) + residual = torch.randn_like(input) + weight = torch.randn((NUM_LAYERS, hidden_size), dtype=DTYPE, device=DEVICE) + FN_MAP = {"jit": jit_fused_add_rmsnorm, "flashinfer": fi_fused_add_rmsnorm} + + def f(): + fn = FN_MAP[provider] + for i in range(NUM_LAYERS): + fn(input[i], residual[i], weight[i]) + + return run_benchmark(f, scale=NUM_LAYERS) if __name__ == "__main__": diff --git a/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py b/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py index f7af1e7c5100..4278f0348260 100644 --- a/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py +++ b/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py @@ -7,7 +7,7 @@ from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant -from sglang.srt.utils import is_sm100_supported +from sglang.srt.utils import is_sm100_supported, is_sm120_supported from sglang.test.ci.ci_register import register_cuda_ci register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large") @@ -15,7 +15,7 @@ FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max BLOCK_SIZE = 16 -_NVFP4_SUPPORTED = is_sm100_supported() +_NVFP4_SUPPORTED = is_sm100_supported() or is_sm120_supported() K_E2M1_TO_FLOAT = [ 0.0, @@ -178,7 +178,7 @@ def benchmark(m, n, k, provider): if __name__ == "__main__": if not _NVFP4_SUPPORTED: - print("[skip] NVFP4 scaled_mm benchmark requires sm100+ with CUDA 12.8+.") + print("[skip] NVFP4 scaled_mm benchmark requires sm100/sm120 with CUDA 12.8+.") sys.exit(0) if not _AOT_SCALED_MM_AVAILABLE: print( diff --git a/python/sglang/jit_kernel/benchmark/bench_renorm.py b/python/sglang/jit_kernel/benchmark/bench_renorm.py index cd4ab36b4326..f65a615ac194 100644 --- a/python/sglang/jit_kernel/benchmark/bench_renorm.py +++ b/python/sglang/jit_kernel/benchmark/bench_renorm.py @@ -82,31 +82,6 @@ def 
torch_top_p_renorm_probs(probs, top_p, eps=1e-5): return renorm_probs -def torch_top_k_mask_logits(logits, top_k): - """Vectorized PyTorch implementation of top-k logits masking.""" - batch_size, vocab_size = logits.shape - - # Handle scalar or tensor k - if isinstance(top_k, int): - k_val = min(max(top_k, 1), vocab_size) - # Get top-k indices for all batches at once - _, topk_indices = torch.topk(logits, k_val, dim=1, largest=True) - - # Create masked logits: start with -inf everywhere - masked_logits = torch.full_like(logits, float("-inf")) - # Scatter the top-k values back - masked_logits.scatter_(1, topk_indices, logits.gather(1, topk_indices)) - else: - # Variable k per batch - need to handle separately - masked_logits = torch.full_like(logits, float("-inf")) - for i in range(batch_size): - k_val = min(max(top_k[i].item(), 1), vocab_size) - _, topk_indices = torch.topk(logits[i], k_val, largest=True) - masked_logits[i, topk_indices] = logits[i, topk_indices] - - return masked_logits - - def calculate_diff_top_k_renorm(batch_size, vocab_size, k): """Compare Torch reference and SGLang kernel for top-k renorm correctness.""" torch.manual_seed(42) @@ -139,20 +114,6 @@ def calculate_diff_top_p_renorm(batch_size, vocab_size, p): torch.testing.assert_close(torch_output, sglang_output, rtol=1e-3, atol=1e-3) -def calculate_diff_top_k_mask(batch_size, vocab_size, k): - """Compare Torch reference and SGLang kernel for top-k mask correctness.""" - torch.manual_seed(42) - device = torch.device("cuda") - - logits = torch.randn(batch_size, vocab_size, device=device) * 5 - top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32) - - torch_output = torch_top_k_mask_logits(logits, top_k_tensor) - sglang_output = sgl_kernel.top_k_mask_logits(logits, top_k_tensor) - - torch.testing.assert_close(torch_output, sglang_output, rtol=1e-3, atol=1e-3) - - # Parameter space - simplified for CI if is_in_ci(): batch_size_range = [16] @@ -231,38 +192,6 @@ def 
benchmark_top_p_renorm(batch_size, vocab_size, p, provider): return run_benchmark_no_cudagraph(fn) -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "vocab_size", "k"], - x_vals=configs_k, - line_arg="provider", - line_vals=["torch", "sglang"], - line_names=["Torch Reference", "SGL Kernel"], - styles=[("red", "-"), ("orange", "-")], - ylabel="us", - plot_name="top-k-mask-logits-performance", - args={}, - ) -) -def benchmark_top_k_mask(batch_size, vocab_size, k, provider): - # Skip invalid configurations - if k >= vocab_size: - return float("nan"), float("nan"), float("nan") - - torch.manual_seed(42) - device = torch.device("cuda") - - logits = torch.randn(batch_size, vocab_size, device=device) * 5 - top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32) - - if provider == "torch": - fn = lambda: torch_top_k_mask_logits(logits.clone(), top_k_tensor) - elif provider == "sglang": - fn = lambda: sgl_kernel.top_k_mask_logits(logits.clone(), top_k_tensor) - - return run_benchmark_no_cudagraph(fn) - - if __name__ == "__main__": print("=" * 60) print("Running correctness checks...") @@ -291,15 +220,6 @@ def benchmark_top_k_mask(batch_size, vocab_size, k, provider): batch_size, vocab_size, p = cfg print(f" ✓ Passed: batch_size={batch_size}, vocab_size={vocab_size}, p={p}") - print("\n3. Testing top_k_mask_logits...") - for cfg in test_configs_k: - batch_size, vocab_size, k = cfg - if k < vocab_size: # Skip invalid configs - calculate_diff_top_k_mask(batch_size, vocab_size, k) - print( - f" ✓ Passed: batch_size={batch_size}, vocab_size={vocab_size}, k={k}" - ) - print("\n" + "=" * 60) print("All correctness checks passed!") print("=" * 60) @@ -314,9 +234,6 @@ def benchmark_top_k_mask(batch_size, vocab_size, k, provider): print("\n2. Benchmarking top_p_renorm_probs...") benchmark_top_p_renorm.run(print_data=True) - print("\n3. 
Benchmarking top_k_mask_logits...") - benchmark_top_k_mask.run(print_data=True) - print("\n" + "=" * 60) print("Benchmarking complete!") print("=" * 60) diff --git a/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py b/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py deleted file mode 100644 index 779b8ad7e207..000000000000 --- a/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py +++ /dev/null @@ -1,98 +0,0 @@ -import itertools - -import torch -import triton -import triton.testing -from flashinfer import rmsnorm as fi_rmsnorm -from sgl_kernel import rmsnorm - -from sglang.jit_kernel.benchmark.utils import ( - DEFAULT_DEVICE, - DEFAULT_DTYPE, - get_benchmark_range, - run_benchmark, -) -from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm -from sglang.test.ci.ci_register import register_cuda_ci - -register_cuda_ci(est_time=21, suite="stage-b-kernel-benchmark-1-gpu-large") - - -def sglang_aot_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, -) -> None: - rmsnorm(input, weight, out=input) - - -def sglang_jit_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, -) -> None: - jit_rmsnorm(input, weight, output=input) - - -def flashinfer_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, -) -> None: - fi_rmsnorm(input, weight, out=input) - - -@torch.compile() -def torch_impl_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, - eps: float = 1e-6, -) -> None: - mean = input.float().pow(2).mean(dim=-1, keepdim=True) - norm = (mean + eps).rsqrt() - input.copy_(input.float() * norm * weight.float()) - - -BS_LIST = get_benchmark_range( - full_range=[2**n for n in range(0, 14)], - ci_range=[16], -) -HIDDEN_SIZE_LIST = get_benchmark_range( - full_range=[1536, 3072, 4096, 5120, 8192], - ci_range=[512, 2048], -) - -LINE_VALS = ["aot", "jit", "flashinfer", "torch"] -LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "FlashInfer", "PyTorch"] -STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")] - -configs = 
list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["hidden_size", "batch_size"], - x_vals=configs, - line_arg="provider", - line_vals=LINE_VALS, - line_names=LINE_NAMES, - styles=STYLES, - ylabel="us", - plot_name="rmsnorm-performance", - args={}, - ) -) -def benchmark(hidden_size: int, batch_size: int, provider: str): - input = torch.randn( - (batch_size, hidden_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE - ) - weight = torch.randn(hidden_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) - FN_MAP = { - "aot": sglang_aot_rmsnorm, - "jit": sglang_jit_rmsnorm, - "flashinfer": flashinfer_rmsnorm, - "torch": torch_impl_rmsnorm, - } - fn = lambda: FN_MAP[provider](input.clone(), weight) - return run_benchmark(fn) - - -if __name__ == "__main__": - benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/benchmark/diffusion/bench_fused_norm_scale_shift.py b/python/sglang/jit_kernel/benchmark/diffusion/bench_fused_norm_scale_shift.py index ae9ce7ff8cbb..759241a6d970 100644 --- a/python/sglang/jit_kernel/benchmark/diffusion/bench_fused_norm_scale_shift.py +++ b/python/sglang/jit_kernel/benchmark/diffusion/bench_fused_norm_scale_shift.py @@ -18,7 +18,11 @@ from sglang.test.ci.ci_register import register_cuda_ci from sglang.utils import is_in_ci -register_cuda_ci(est_time=17, suite="stage-b-kernel-benchmark-1-gpu-large") +register_cuda_ci( + est_time=17, + suite="stage-b-kernel-benchmark-1-gpu-large", + disabled="Temporarily skipped to unblock flashinfer upgrade. 
Ref: https://github.com/sgl-project/sglang/actions/runs/23735552939/job/69139238979?pr=21422", +) if is_in_ci(): B_RANGE, S_RANGE, D_RANGE = [1], [128], [1024] diff --git a/python/sglang/jit_kernel/benchmark/utils.py b/python/sglang/jit_kernel/benchmark/utils.py index c17ef4f9a0a1..3bd5e793d945 100644 --- a/python/sglang/jit_kernel/benchmark/utils.py +++ b/python/sglang/jit_kernel/benchmark/utils.py @@ -1,6 +1,6 @@ """Common utilities for jit_kernel benchmark files.""" -from typing import Callable, List, Tuple +from typing import Callable, List, Sequence, Tuple import torch import triton.testing @@ -19,25 +19,30 @@ def get_benchmark_range(full_range: List, ci_range: List) -> List: def run_benchmark( - fn: Callable, quantiles: List[float] = None + fn: Callable, + quantiles: Sequence[float] = (), + scale: float = 1.0, ) -> Tuple[float, float, float]: """Execute benchmark using CUDA graph and return times in microseconds. Args: fn: Function to benchmark quantiles: Quantiles for timing measurements [median, min, max] + scale: Scale the result down (usually num_layers). 
Returns: Tuple of (median_us, max_us, min_us) """ - quantiles = quantiles or DEFAULT_QUANTILES + quantiles = list(quantiles or DEFAULT_QUANTILES) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) - return 1000 * ms, 1000 * max_ms, 1000 * min_ms + return 1000 * ms / scale, 1000 * max_ms / scale, 1000 * min_ms / scale def run_benchmark_no_cudagraph( - fn: Callable, quantiles: List[float] = None + fn: Callable, + quantiles: Sequence[float] = (), + scale: float = 1.0, ) -> Tuple[float, float, float]: - quantiles = quantiles or DEFAULT_QUANTILES + quantiles = list(quantiles or DEFAULT_QUANTILES) ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) - return 1000 * ms, 1000 * max_ms, 1000 * min_ms + return 1000 * ms / scale, 1000 * max_ms / scale, 1000 * min_ms / scale diff --git a/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh b/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh index 40401572b3b8..1c1f41dccf47 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh @@ -39,11 +39,11 @@ namespace { // When factor != 1.0, blends interpolated and extrapolated frequencies. // --------------------------------------------------------------------------- -__device__ inline float -compute_freq_yarn(float base, int rotary_dim, int half_dim, float factor, float low, float high) { +template +__device__ inline float compute_freq(float base, int rotary_dim, int half_dim, float factor, float low, float high) { float freq = powf(base, -2.0f * half_dim / static_cast(rotary_dim)); - if (factor != 1.0f) { + if constexpr (yarn) { float inv_freq_extrapolation = freq; float inv_freq_interpolation = freq / factor; @@ -68,11 +68,14 @@ compute_freq_yarn(float base, int rotary_dim, int half_dim, float factor, float // // Each warp processes one (token, head) pair. 
// head_dim: compile-time head dimension (64, 128, or 256) -// interleave: true -> interleave / GPT-J style RoPE (!is_neox) -// false -> NeoX style RoPE (is_neox) +// interleave: true → interleave / GPT-J style RoPE (!is_neox) +// false → NeoX style RoPE (is_neox) // --------------------------------------------------------------------------- -template +// interleave (GPT-J) pairs (2k,2k+1) share the same freq/theta, +// so sin/cos is computed once per pair and copied to the odd element, +// halving powf + __sincosf calls vs a naive per-element approach. +template __global__ void fusedQKNormRopeKernel( __nv_bfloat16* qkv, // [num_tokens, (nq+nk+nv)*head_dim], in-place int const num_heads_q, @@ -139,36 +142,65 @@ __global__ void fusedQKNormRopeKernel( // Apply RMSNorm // ------------------------------------------------------------------- float rms_rcp = rsqrtf(sumOfSquares / static_cast(head_dim) + eps); - for (int i = 0; i < numElemsPerThread; i++) { - int dim = laneId * numElemsPerThread + i; - float weight = isQ ? device::cast(q_weight[dim]) : device::cast(k_weight[dim]); - elements[i] *= rms_rcp * weight; + { + vec_T wvec; + wvec.load((isQ ? q_weight : k_weight) + offsetThread - offsetWarp); + for (int i = 0; i < numElemsPerThread; i++) { + elements[i] *= rms_rcp * device::cast(wvec[i]); + } } // ------------------------------------------------------------------- // Apply RoPE to the first rotary_dim elements // ------------------------------------------------------------------- - float elements2[numElemsPerThread]; - float cos_vals[numElemsPerThread]; - float sin_vals[numElemsPerThread]; float pos_id = static_cast(position_ids[tokenIdx]); int const rotary_lanes = rotary_dim / numElemsPerThread; bool const applyRotary = (laneId < rotary_lanes); if (applyRotary) { if constexpr (interleave) { - // Interleave (GPT-J) style: pairs of consecutive elements share a frequency - for (int i = 0; i < numElemsPerThread; i++) { - elements2[i] = (i % 2 == 0) ? 
-elements[i + 1] : elements[i - 1]; + // Pairs (2k, 2k+1) share the same half_dim → same freq/theta. + // numElemsPerThread is always even (head_dim/32, head_dim in {64,128,256}), + // so we step by 2 and handle both elements of each pair per iteration. + // + // freq follows a geometric series across pairs: freq[k] = freq[0] * ratio^k, + // where ratio = base^(-2/rotary_dim). Pre-compute both outside the loop to + // replace all but the first powf call with a single multiply per iteration. + // + // sin/cos are applied immediately to e0/e1, eliminating the elements2, + // cos_vals, sin_vals intermediate arrays and reducing register pressure. + int const half_dim_start = laneId * numElemsPerThread / 2; + float freq = powf(base, -2.0f * static_cast(half_dim_start) / static_cast(rotary_dim)); + float const freq_ratio = powf(base, -2.0f / static_cast(rotary_dim)); + + for (int i = 0; i < numElemsPerThread; i += 2) { + float e0 = elements[i]; + float e1 = elements[i + 1]; + + float f = freq; + if constexpr (yarn) { + int half_dim = half_dim_start + i / 2; + float inv_freq_interpolation = freq / factor; + float high_adj = (fabsf(low - high) <= 1e-6f) ? 
high + 0.001f : high; + float linear_func = (static_cast(half_dim) - low) / (high_adj - low); + float ramp_func = fminf(fmaxf(linear_func, 0.0f), 1.0f); + float extrap_factor = 1.0f - ramp_func; + f = inv_freq_interpolation * (1.0f - extrap_factor) + freq * extrap_factor; + } - int dim_idx = laneId * numElemsPerThread + i; - int half_dim = dim_idx / 2; - float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high); - float theta = pos_id * freq; - __sincosf(theta, &sin_vals[i], &cos_vals[i]); + float s, c; + __sincosf(pos_id * f, &s, &c); + elements[i] = (e0 * c - e1 * s) * attention_factor; + elements[i + 1] = (e1 * c + e0 * s) * attention_factor; + + freq *= freq_ratio; } } else { // NeoX style: first and second halves of the rotary region are paired + float elements2[numElemsPerThread]; + float cos_vals[numElemsPerThread]; + float sin_vals[numElemsPerThread]; + __syncwarp(); int const half_rotary_lanes = rotary_lanes / 2; // Avoid UB from (1u << 32) when rotary_lanes == 32 @@ -183,15 +215,15 @@ __global__ void fusedQKNormRopeKernel( // Remap so that both halves use the same set of frequencies dim_idx = (dim_idx * 2) % rotary_dim; int half_dim = dim_idx / 2; - float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high); + float freq = compute_freq(base, rotary_dim, half_dim, factor, low, high); float theta = pos_id * freq; __sincosf(theta, &sin_vals[i], &cos_vals[i]); } __syncwarp(); - } - for (int i = 0; i < numElemsPerThread; i++) { - elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor; + for (int i = 0; i < numElemsPerThread; i++) { + elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor; + } } } @@ -209,14 +241,8 @@ __global__ void fusedQKNormRopeKernel( // --------------------------------------------------------------------------- // Host-side tvm-ffi entry point -// -// HEAD_DIM and INTERLEAVE are compile-time template parameters, passed as -// 
template arguments from Python via the cuda_wrappers specialisation in -// fused_qknorm_rope.py (e.g. fused_qk_norm_rope<128, false>). This avoids -// both runtime dispatch and macro-based specialisation. // --------------------------------------------------------------------------- -template void fused_qk_norm_rope( tvm::ffi::TensorView qkv, // [num_tokens, (nq+nk+nv)*head_dim] bf16 tvm::ffi::TensorView q_weight, // [head_dim] bf16 @@ -225,8 +251,10 @@ void fused_qk_norm_rope( int num_heads_q, int num_heads_k, int num_heads_v, + int head_dim, float eps, float base, + int is_neox, // 0 = interleave style, 1 = NeoX style float factor, float low, float high, @@ -234,8 +262,6 @@ void fused_qk_norm_rope( int rotary_dim) { using namespace host; - static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256, "HEAD_DIM must be 64, 128, or 256"); - RuntimeCheck(qkv.device().device_type == kDLCUDA, "qkv must be a CUDA tensor"); RuntimeCheck(qkv.is_contiguous(), "qkv must be contiguous"); RuntimeCheck(qkv.dtype().code == kDLBfloat && qkv.dtype().bits == 16, "qkv must be bfloat16"); @@ -244,12 +270,12 @@ void fused_qk_norm_rope( RuntimeCheck(q_weight.is_contiguous(), "q_weight must be contiguous"); RuntimeCheck(q_weight.dtype().code == kDLBfloat && q_weight.dtype().bits == 16, "q_weight must be bfloat16"); RuntimeCheck( - q_weight.ndim() == 1 && static_cast(q_weight.size(0)) == HEAD_DIM, "q_weight must be 1D of size head_dim"); + q_weight.ndim() == 1 && static_cast(q_weight.size(0)) == head_dim, "q_weight must be 1D of size head_dim"); RuntimeCheck(k_weight.is_contiguous(), "k_weight must be contiguous"); RuntimeCheck(k_weight.dtype().code == kDLBfloat && k_weight.dtype().bits == 16, "k_weight must be bfloat16"); RuntimeCheck( - k_weight.ndim() == 1 && static_cast(k_weight.size(0)) == HEAD_DIM, "k_weight must be 1D of size head_dim"); + k_weight.ndim() == 1 && static_cast(k_weight.size(0)) == head_dim, "k_weight must be 1D of size head_dim"); 
RuntimeCheck(position_ids.device().device_type == kDLCUDA, "position_ids must be a CUDA tensor"); RuntimeCheck(position_ids.is_contiguous(), "position_ids must be contiguous"); @@ -259,13 +285,20 @@ void fused_qk_norm_rope( int num_tokens = static_cast(qkv.size(0)); int total_heads = num_heads_q + num_heads_k + num_heads_v; RuntimeCheck( - static_cast(qkv.size(1)) == total_heads * HEAD_DIM, "qkv.size(1) must equal (nq + nk + nv) * head_dim"); + static_cast(qkv.size(1)) == total_heads * head_dim, "qkv.size(1) must equal (nq + nk + nv) * head_dim"); RuntimeCheck(static_cast(position_ids.size(0)) == num_tokens, "position_ids must have num_tokens elements"); - constexpr int numElemsPerThread = HEAD_DIM / 32; + static_assert( + JIT_HEAD_DIM == 64 || JIT_HEAD_DIM == 128 || JIT_HEAD_DIM == 256, "JIT_HEAD_DIM must be 64, 128, or 256"); + static_assert(JIT_INTERLEAVE == 0 || JIT_INTERLEAVE == 1, "JIT_INTERLEAVE must be 0 or 1"); + static_assert(JIT_YARN == 0 || JIT_YARN == 1, "JIT_YARN must be 0 or 1"); + RuntimeCheck(head_dim == JIT_HEAD_DIM, "head_dim mismatch with JIT-compiled kernel"); + + int numElemsPerThread = head_dim / 32; RuntimeCheck(rotary_dim % numElemsPerThread == 0, "rotary_dim must be divisible by (head_dim / 32)"); - if constexpr (!INTERLEAVE) { + bool neox = static_cast(is_neox); + if (neox) { // NeoX uses __shfl_xor_sync which requires half_rotary_lanes to be a power of 2 int rotary_lanes = rotary_dim / numElemsPerThread; int half_rotary_lanes = rotary_lanes / 2; @@ -273,35 +306,41 @@ void fused_qk_norm_rope( RuntimeCheck(is_pow2, "half_rotary_lanes must be a power of 2 for NeoX style RoPE"); } + bool interleave = !neox; + RuntimeCheck(interleave == static_cast(JIT_INTERLEAVE), "interleave mismatch with JIT-compiled kernel"); + bool use_yarn = (factor != 1.0f); + RuntimeCheck(use_yarn == static_cast(JIT_YARN), "yarn mismatch with JIT-compiled kernel"); + cudaStream_t stream = LaunchKernel::resolve_device(qkv.device()); constexpr int blockSize = 256; int 
warpsPerBlock = blockSize / 32; int totalQKHeads = num_heads_q + num_heads_k; int totalWarps = num_tokens * totalQKHeads; - int gridSize = host::div_ceil(totalWarps, warpsPerBlock); + int gridSize = div_ceil(totalWarps, warpsPerBlock); auto* qkv_ptr = reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()); auto const* qw_ptr = reinterpret_cast<__nv_bfloat16 const*>(q_weight.data_ptr()); auto const* kw_ptr = reinterpret_cast<__nv_bfloat16 const*>(k_weight.data_ptr()); auto const* pos_ptr = reinterpret_cast(position_ids.data_ptr()); - fusedQKNormRopeKernel<<>>( - qkv_ptr, - num_heads_q, - num_heads_k, - num_heads_v, - eps, - qw_ptr, - kw_ptr, - base, - pos_ptr, - num_tokens, - factor, - low, - high, - attention_factor, - rotary_dim); + fusedQKNormRopeKernel(JIT_INTERLEAVE), static_cast(JIT_YARN)> + <<>>( + qkv_ptr, + num_heads_q, + num_heads_k, + num_heads_v, + eps, + qw_ptr, + kw_ptr, + base, + pos_ptr, + num_tokens, + factor, + low, + high, + attention_factor, + rotary_dim); } } // namespace diff --git a/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh index 4f24b09736e1..2e1edd692d87 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -47,6 +48,130 @@ __global__ void rmsnorm_cta(const RMSNormParams __grid_constant__ params) { PDLTriggerSecondary(); // launch secondary kernel } +// Pre-Blackwell: 16B vector, each thread loads/stores twice +template +__global__ __launch_bounds__(kDim / 16) void rmsnorm_cta_double(const RMSNormParams __grid_constant__ params) { + using namespace device; + using Float2 = packed_t; + using Storage = AlignedVector; + + constexpr auto kNumThreads = kDim / 16; + constexpr auto kNumWarps = kNumThreads / kWarpThreads; + + const auto& [input, weight_ptr, output, input_stride, output_stride, num_tokens, eps] = params; + const auto gmem = 
tile::Memory::cta(kNumThreads); + __shared__ float smem[32]; + + PDLWaitPrimary(); + + const auto input_ptr = pointer::offset(input, blockIdx.x * input_stride); + const auto output_ptr = pointer::offset(output, blockIdx.x * output_stride); + + const auto input_first = gmem.load(input_ptr, 0); + const auto input_second = gmem.load(input_ptr, 1); + const auto weight_first = gmem.load(weight_ptr, 0); + const auto weight_second = gmem.load(weight_ptr, 1); + + float sum_of_squares = 0.0f; +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto [x, y] = cast(input_first[j]); + sum_of_squares += x * x + y * y; + } +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto [x, y] = cast(input_second[j]); + sum_of_squares += x * x + y * y; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = sum_of_squares; + __syncthreads(); + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_sum = tx < kNumWarps ? 
smem[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem[tx] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + const float norm_factor = smem[warp_id]; + + Storage output_first, output_second; +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto [ix, iy] = cast(input_first[j]); + const auto [wx, wy] = cast(weight_first[j]); + output_first[j] = cast(fp32x2_t{ix * norm_factor * wx, iy * norm_factor * wy}); + } +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto [ix, iy] = cast(input_second[j]); + const auto [wx, wy] = cast(weight_second[j]); + output_second[j] = cast(fp32x2_t{ix * norm_factor * wx, iy * norm_factor * wy}); + } + + gmem.store(output_ptr, output_first, 0); + gmem.store(output_ptr, output_second, 1); + + PDLTriggerSecondary(); +} + +// Blackwell: 32B vector, each thread loads/stores once +template +__global__ __launch_bounds__(kDim / 16) void rmsnorm_cta_wide(const RMSNormParams __grid_constant__ params) { + using namespace device; + using Float2 = packed_t; + using Storage = AlignedVector; + + constexpr auto kNumThreads = kDim / 16; + constexpr auto kNumWarps = kNumThreads / kWarpThreads; + + const auto& [input, weight_ptr, output, input_stride, output_stride, num_tokens, eps] = params; + const auto gmem = tile::Memory::cta(kNumThreads); + __shared__ float smem[32]; + + PDLWaitPrimary(); + + const auto input_ptr = pointer::offset(input, blockIdx.x * input_stride); + const auto output_ptr = pointer::offset(output, blockIdx.x * output_stride); + + const auto input_vec = gmem.load(input_ptr); + const auto weight_vec = gmem.load(weight_ptr); + + float sum_of_squares = 0.0f; +#pragma unroll + for (auto j = 0u; j < 8u; ++j) { + const auto [x, y] = cast(input_vec[j]); + sum_of_squares += x * x + y * y; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = sum_of_squares; + __syncthreads(); + if (warp_id == 0) { + const auto tx = 
threadIdx.x; + const auto local_sum = tx < kNumWarps ? smem[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem[tx] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + const float norm_factor = smem[warp_id]; + + Storage output_vec; +#pragma unroll + for (auto j = 0u; j < 8u; ++j) { + const auto [ix, iy] = cast(input_vec[j]); + const auto [wx, wy] = cast(weight_vec[j]); + output_vec[j] = cast(fp32x2_t{ix * norm_factor * wx, iy * norm_factor * wy}); + } + + gmem.store(output_ptr, output_vec); + + PDLTriggerSecondary(); +} + template __global__ void rmsnorm_warp(const RMSNormParams __grid_constant__ params) { using namespace device; @@ -178,4 +303,59 @@ struct RMSNormKernel { } }; +template +struct RMSNormHalfKernel { + static_assert(kDim % 512 == 0 && sizeof(DType) == 2); +#if SGL_ARCH_BLACKWELL_OR_GREATER + static constexpr auto kernel = rmsnorm_cta_wide; +#else + static constexpr auto kernel = rmsnorm_cta_double; +#endif + static constexpr auto kBlockSize = static_cast(kDim / 16); + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView output, + float eps) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_size"}; + auto SI = SymbolicSize{"input_stride"}; + auto SO = SymbolicSize{"output_stride"}; + auto device = SymbolicDevice{}; + D.set_value(kDim); + device.set_options(); + + TensorMatcher({N, D}) // input + .with_strides({SI, 1}) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({D}) // weight + .with_dtype() + .with_device(device) + .verify(weight); + TensorMatcher({N, D}) // output + .with_strides({SO, 1}) + .with_dtype() + .with_device(device) + .verify(output); + + const auto num_tokens = static_cast(N.unwrap()); + const auto params = RMSNormParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .output = output.data_ptr(), + .input_stride = SI.unwrap(), + .output_stride = 
SO.unwrap(), + .num_tokens = num_tokens, + .eps = eps, + }; + + LaunchKernel(num_tokens, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + } // namespace diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h b/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h index 6c4112e633fd..651710a963f7 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h +++ b/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h @@ -484,11 +484,11 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) - : 1; + // FP4 (kFE2M1f) uses FP8 scales (1 byte/element), others use FP16 (2 bytes) + int s_gl_stride = prob_n / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -540,8 +540,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -563,15 +562,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -876,7 +867,7 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h index bf7dcb202301..566fa5f59606 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h +++ b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h @@ -626,11 +626,10 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) - : 1; + int s_gl_stride = prob_n / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? 
thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -682,8 +681,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -705,15 +703,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -1038,18 +1028,15 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 
2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (w_type_id != host::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } else { reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2]; + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } } @@ -1243,17 +1230,19 @@ __global__ void Marlin( } } - // Commented out FP4/FP8 scale dequantization since we don't generate - // kFE2M1f kernels to reduce compilation time - // if constexpr (w_type == host::kFE2M1f) { - // int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; - // int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - // - // dequant_fp8_scales( - // s_quant_0, reinterpret_cast(&frag_s[k2])); - // dequant_fp8_scales( - // s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); - // } +#ifdef SGL_MOE_MARLIN_FP4 + // Convert FP8 per-group scales to BF16/FP16 before applying them. + // Required for kFE2M1f (NVFP4): frag_s holds raw float8_e4m3fn bytes; + // without this conversion scale would misinterpret them as + // BF16/FP16, producing NaN/Inf multipliers. + if constexpr (w_type == host::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } +#endif // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. 
diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh index 81c021dc8ecc..d89954200c88 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh +++ b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh @@ -453,7 +453,9 @@ MarlinFuncPtr get_marlin_kernel( COMMON_GET_IF(host::kU4B8) COMMON_GET_IF(host::kU8B128) +#ifdef SGL_MOE_MARLIN_FP4 NVFP4_GET_IF(host::kFE2M1f) +#endif BIGGROUP_GET_IF(host::kFE4M3fn) diff --git a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_common.cuh b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_common.cuh new file mode 100644 index 000000000000..f5ebca05b37c --- /dev/null +++ b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_common.cuh @@ -0,0 +1,66 @@ +/* Copyright 2026 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include + +using namespace host; + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + RuntimeCheck(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +using namespace cute; + +inline uint32_t next_pow_2(uint32_t x) noexcept { + if (x <= 1) return 1; + return 1u << (32 - __builtin_clz(x - 1)); +} + +inline auto alloc_workspace_tensor(size_t required_bytes, DLDevice device) -> tvm::ffi::Tensor { + if (required_bytes == 0) return {}; + DLDataType u8 = {kDLUInt, 8, 1}; + int64_t shape[] = {static_cast(required_bytes)}; + return ffi::empty(tvm::ffi::ShapeView(shape, 1), u8, device); +} + +inline int getSMVersion(int device_id) { + int sm_major = 0; + int sm_minor = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id)); + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id)); + return sm_major * 10 + sm_minor; +} diff --git a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh index 9cc309f14b55..8c5cfefd7956 100644 --- a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh +++ b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh @@ -1,4 +1,4 @@ -/* Copyright 2025 SGLang Team. All Rights Reserved. +/* Copyright 2026 SGLang Team. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,593 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include - -#include -#include - -#include -#include -#include -#include - -using namespace host; - -// clang-format off -#include "cutlass/cutlass.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/util/packed_stride.hpp" -// clang-format on - -/** - * Helper function for checking CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - RuntimeCheck(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ - } - -using namespace cute; - -// Helper function for next power of 2 -inline uint32_t next_pow_2(uint32_t x) { - if (x == 0) return 1; - x--; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; - return x + 1; -} - -struct WorkspaceKey { - int device_id; - uintptr_t stream; - auto operator==(const WorkspaceKey&) const -> bool = default; -}; - -struct WorkspaceKeyHash { - auto operator()(const WorkspaceKey& key) const -> size_t { - size_t h1 = std::hash{}(key.device_id); - size_t h2 = std::hash{}(key.stream); - return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); - } -}; - -struct WorkspaceState { - void* ptr = nullptr; - size_t bytes = 0; -}; - -inline auto get_cached_workspace(size_t required_bytes, int device_id, cudaStream_t stream) -> void* { - if (required_bytes == 0) { - return nullptr; - } - - thread_local std::unordered_map cache; - WorkspaceKey key{device_id, reinterpret_cast(stream)}; - auto& ws = cache[key]; - - if (ws.ptr != 
nullptr && ws.bytes >= required_bytes) { - return ws.ptr; - } - - RuntimeDeviceCheck(cudaSetDevice(device_id)); - if (ws.ptr != nullptr) { - RuntimeDeviceCheck(cudaFreeAsync(ws.ptr, stream)); - ws.ptr = nullptr; - ws.bytes = 0; - } - RuntimeDeviceCheck(cudaMallocAsync(&ws.ptr, required_bytes, stream)); - ws.bytes = required_bytes; - return ws.ptr; -} - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || \ - defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) -// Config(half_t/bfloat16_t) for M <= 128 -template -struct KernelConfigM128 { - using OutputType = T; - using MmaTileShape = Shape<_128, _256, _256>; - using ClusterShape = Shape; - using EpilogueTile = Shape<_128, _64>; // Avoid register spilling - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -template -const dim3 KernelConfigM128::preferred_cluster(1, 4, 1); -template -const dim3 KernelConfigM128::fallback_cluster(1, 2, 1); - -// Config(half_t/bfloat16_t) for M <= 256 -template -struct KernelConfigM256 { - using OutputType = T; - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape; - using EpilogueTile = Shape<_128, _64>; // Avoid register spilling - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -template -const dim3 KernelConfigM256::preferred_cluster(2, 4, 1); -template -const dim3 KernelConfigM256::fallback_cluster(2, 1, 1); - -// Default config(half_t/bfloat16_t) for M > 256 -template -struct KernelConfigDefault { - using OutputType = T; - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape; - using EpilogueTile = Shape<_128, _64>; // Avoid register spilling - using 
EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -template -const dim3 KernelConfigDefault::preferred_cluster(4, 4, 1); -template -const dim3 KernelConfigDefault::fallback_cluster(2, 1, 1); - -struct KernelConfigFp32 { - using OutputType = float; - using MmaTileShape = Shape<_128, _128, _256>; - using ClusterShape = Shape; - using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1); -const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1); - -// SM120 specific configurations -struct sm120_fp4_config_M256 { - using ClusterShape = Shape<_1, _1, _1>; - using MmaTileShape = Shape<_128, _128, _128>; - using PerSmTileShape_MNK = Shape<_128, _128, _128>; -}; - -struct sm120_fp4_config_default { - using ClusterShape = Shape<_1, _1, _1>; - using MmaTileShape = Shape<_256, _128, _128>; - using PerSmTileShape_MNK = Shape<_256, _128, _128>; -}; - -template -struct Fp4GemmSm100 { - using Config = KernelConfig; // For generating args - using OutputType = typename KernelConfig::OutputType; - // A matrix configuration - using ElementA = cutlass::nv_float4_t; - using LayoutATag = cutlass::layout::RowMajor; - static constexpr int AlignmentA = 32; - - // B matrix configuration - using ElementB = cutlass::nv_float4_t; - using LayoutBTag = cutlass::layout::ColumnMajor; - static constexpr int AlignmentB = 32; - - // C/D matrix configuration - using ElementD = OutputType; - using ElementC = OutputType; - using LayoutCTag = cutlass::layout::RowMajor; - using LayoutDTag = cutlass::layout::RowMajor; - static 
constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - // Kernel functional config - using ElementAccumulator = float; - using ArchTag = cutlass::arch::Sm100; - using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; - - // Kernel Perf config - using MmaTileShape = typename KernelConfig::MmaTileShape; - using ClusterShape = typename KernelConfig::ClusterShape; - using EpilogueTile = typename KernelConfig::EpilogueTile; - using EpilogueSchedule = typename KernelConfig::EpilogueSchedule; - using MainloopSchedule = typename KernelConfig::MainloopSchedule; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - MmaTileShape, - ClusterShape, - EpilogueTile, - ElementAccumulator, - ElementAccumulator, - void, - LayoutCTag, - AlignmentC, - ElementD, - LayoutDTag, - AlignmentD, - EpilogueSchedule, - cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementA, - LayoutATag, - AlignmentA, - ElementB, - LayoutBTag, - AlignmentB, - ElementAccumulator, - MmaTileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainloopSchedule>::CollectiveOp; - - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using StrideA = typename Gemm::GemmKernel::StrideA; - using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{})); - using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{})); - using LayoutSFB = typename 
Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{})); - using StrideD = typename Gemm::GemmKernel::StrideD; - using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); -}; - -// SM120 specific GEMM template -template -struct Fp4GemmSm120 { - using ElementA = cutlass::nv_float4_t; - using LayoutATag = cutlass::layout::RowMajor; - static constexpr int AlignmentA = 32; - - using ElementB = cutlass::nv_float4_t; - using LayoutBTag = cutlass::layout::ColumnMajor; - static constexpr int AlignmentB = 32; - - using ElementD = OutType; - using ElementC = OutType; - using LayoutCTag = cutlass::layout::RowMajor; - using LayoutDTag = cutlass::layout::RowMajor; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using ElementAccumulator = float; - using ArchTag = cutlass::arch::Sm120; - using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; - - using MmaTileShape = typename Config::MmaTileShape; - using ClusterShape = typename Config::ClusterShape; - using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - PerSmTileShape_MNK, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementAccumulator, - ElementC, - LayoutCTag, - AlignmentC, - ElementD, - LayoutDTag, - AlignmentD, - cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementA, - LayoutATag, - AlignmentA, - ElementB, - LayoutBTag, - AlignmentB, - ElementAccumulator, - MmaTileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename 
CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; - - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; - -template -typename T::Gemm::Arguments args_from_options( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t M, - int64_t N, - int64_t K) { - using ElementA = typename T::Gemm::ElementA; - using ElementB = typename T::Gemm::ElementB; - using ElementSFA = cutlass::float_ue4m3_t; - using ElementSFB = cutlass::float_ue4m3_t; - using ElementD = typename T::Gemm::ElementD; - using ElementCompute = float; - using StrideA = typename T::StrideA; - using StrideB = typename T::StrideB; - using StrideD = typename T::StrideD; - using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - - int m = static_cast(M); - int n = static_cast(N); - int k = static_cast(K); - auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); - auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); - auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); - - auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); - auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); - - typename T::Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {m, n, k, 1}, - {// Mainloop arguments - static_cast(A.data_ptr()), - stride_A, - static_cast(B.data_ptr()), - stride_B, - static_cast(A_sf.data_ptr()), - layout_SFA, - static_cast(B_sf.data_ptr()), - layout_SFB}, - { // Epilogue arguments - {}, // epilogue.thread - nullptr, - stride_D, - static_cast(D.data_ptr()), - stride_D}}; - auto& fusion_args = arguments.epilogue.thread; - 
fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); - using KernelConfig = typename T::Config; - arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster; - arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster; - return arguments; -} - -template -void runGemm( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - typename T::Gemm gemm; - auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, m, n, k); - - size_t workspace_size = T::Gemm::get_workspace_size(arguments); - int device_id = A.device().device_id; - void* workspace = get_cached_workspace(workspace_size, device_id, stream); - - CUTLASS_CHECK(gemm.can_implement(arguments)); - - CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); - - CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); -} - -// SM120 specific args_from_options function -template -typename Gemm::Arguments args_from_options_sm120( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int M, - int N, - int K) { - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementD = typename Gemm::ElementD; - using ElementSFA = cutlass::float_ue4m3_t; - using ElementSFB = cutlass::float_ue4m3_t; - using ElementCompute = float; - - using StrideA = typename Gemm::GemmKernel::StrideA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using StrideD = typename Gemm::GemmKernel::StrideD; - - using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - - auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); - auto stride_B = 
cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); - auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); - - auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); - auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K, 1}, - {static_cast(A.data_ptr()), - stride_A, - static_cast(B.data_ptr()), - stride_B, - static_cast(A_sf.data_ptr()), - layout_SFA, - static_cast(B_sf.data_ptr()), - layout_SFB}, - {{}, static_cast(D.data_ptr()), stride_D, static_cast(D.data_ptr()), stride_D}}; - auto& fusion_args = arguments.epilogue.thread; - fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); - - return arguments; -} - -// SM120 specific runGemm function -template -void runGemmSm120( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int M, - int N, - int K, - cudaStream_t stream) { - Gemm gemm; - - auto arguments = args_from_options_sm120(D, A, B, A_sf, B_sf, alpha, M, N, K); - - size_t workspace_size = Gemm::get_workspace_size(arguments); - int device_id = A.device().device_id; - void* workspace = get_cached_workspace(workspace_size, device_id, stream); - - CUTLASS_CHECK(gemm.can_implement(arguments)); - - CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); - - CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); -} - -// Dispatch function to select appropriate config based on M -template -void cutlassFp4GemmDispatch( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - if (m <= 128) { - // m in [1, 128] - runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else if (m 
<= 256) { - // m in (128, 256] - runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else { - // m in (256, inf) - runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } -} - -// Dispatch function to select appropriate config based on M -template <> -void cutlassFp4GemmDispatch( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - runGemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); -} - -// SM120 specific dispatch functions -void cutlass_fp4_bf16_gemm_dispatch_sm120( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int m, - int n, - int k, - cudaStream_t stream) { - uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); - if (mp2 <= 256) { - runGemmSm120::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else { - runGemmSm120::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } -} - -void cutlass_fp4_f16_gemm_dispatch_sm120( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int m, - int n, - int k, - cudaStream_t stream) { - uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); - if (mp2 <= 256) { - runGemmSm120::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else { - runGemmSm120::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } -} - -#else -template -void cutlassFp4GemmDispatch( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - RuntimeCheck( - false, - "Unsupported CUTLASS version. 
Set VLLM_CUTLASS_SRC_DIR to " - "a CUTLASS 3.8 source directory to enable support."); -} -#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || - // defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) - -inline int getSMVersion(int device_id) { - int sm_major = 0; - int sm_minor = 0; - RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id)); - RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id)); - return sm_major * 10 + sm_minor; -} +#include "nvfp4_scaled_mm_common.cuh" +#include "nvfp4_scaled_mm_sm100.cuh" +#include "nvfp4_scaled_mm_sm120.cuh" void cutlass_scaled_fp4_mm_sm100a_sm120a( tvm::ffi::TensorView D, @@ -718,11 +134,11 @@ void cutlass_scaled_fp4_mm_sm100a_sm120a( } } else { if (host::is_type(D.dtype())) { - cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatchSm100(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else if (host::is_type(D.dtype())) { - cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatchSm100(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else if (host::is_type(D.dtype())) { - cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatchSm100(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { Panic("Unsupported output data type of nvfp4 mm"); } diff --git a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm100.cuh b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm100.cuh new file mode 100644 index 000000000000..bd5927a23f09 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm100.cuh @@ -0,0 +1,284 @@ +/* Copyright 2026 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "nvfp4_scaled_mm_common.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// Config(half_t/bfloat16_t) for M <= 128 +template +struct KernelConfigM128 { + using OutputType = T; + using MmaTileShape = Shape<_128, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigM128::preferred_cluster(1, 4, 1); +template +const dim3 KernelConfigM128::fallback_cluster(1, 2, 1); + +// Config(half_t/bfloat16_t) for M <= 256 +template +struct KernelConfigM256 { + using OutputType = T; + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigM256::preferred_cluster(2, 4, 1); +template +const dim3 KernelConfigM256::fallback_cluster(2, 1, 1); + +// Default config(half_t/bfloat16_t) for M > 256 +template +struct KernelConfigDefault { + using OutputType = T; + using MmaTileShape = Shape<_256, _256, _256>; + using 
ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigDefault::preferred_cluster(4, 4, 1); +template +const dim3 KernelConfigDefault::fallback_cluster(2, 1, 1); + +struct KernelConfigFp32 { + using OutputType = float; + using MmaTileShape = Shape<_128, _128, _256>; + using ClusterShape = Shape; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1); +const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1); + +template +struct Fp4GemmSm100 { + using Config = KernelConfig; + using OutputType = typename KernelConfig::OutputType; + + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + using ElementD = OutputType; + using ElementC = OutputType; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using MmaTileShape = typename KernelConfig::MmaTileShape; + using ClusterShape = typename KernelConfig::ClusterShape; + using 
EpilogueTile = typename KernelConfig::EpilogueTile; + using EpilogueSchedule = typename KernelConfig::EpilogueSchedule; + using MainloopSchedule = typename KernelConfig::MainloopSchedule; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementAccumulator, + void, + LayoutCTag, + AlignmentC, + ElementD, + LayoutDTag, + AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutATag, + AlignmentA, + ElementB, + LayoutBTag, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); +}; + +template +typename T::Gemm::Arguments args_from_options( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + 
tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t M, + int64_t N, + int64_t K) { + using ElementA = typename T::Gemm::ElementA; + using ElementB = typename T::Gemm::ElementB; + using ElementSFA = cutlass::float_ue4m3_t; + using ElementSFB = cutlass::float_ue4m3_t; + using ElementD = typename T::Gemm::ElementD; + using ElementCompute = float; + using StrideA = typename T::StrideA; + using StrideB = typename T::StrideB; + using StrideD = typename T::StrideD; + using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + + typename T::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + nullptr, + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + using KernelConfig = typename T::Config; + arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster; + arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster; + return arguments; +} + +template +void runGemm( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView 
alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename T::Gemm gemm; + auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, m, n, k); + + size_t workspace_size = T::Gemm::get_workspace_size(arguments); + auto workspace_tensor = alloc_workspace_tensor(workspace_size, A.device()); + void* workspace = (workspace_size == 0) ? nullptr : workspace_tensor.data_ptr(); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); + + CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); +} + +template +void cutlassFp4GemmDispatchSm100( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + if (m <= 128) { + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (m <= 256) { + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +template <> +void cutlassFp4GemmDispatchSm100( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + runGemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm120.cuh b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm120.cuh new file mode 100644 index 000000000000..cdb159061eb9 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm120.cuh @@ -0,0 +1,228 @@ +/* Copyright 2026 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "nvfp4_scaled_mm_common.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +struct sm120_fp4_config_small_m { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _256>; + using PerSmTileShape_MNK = Shape<_128, _128, _256>; +}; + +struct sm120_fp4_config_M256 { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _128>; + using PerSmTileShape_MNK = Shape<_128, _128, _128>; +}; + +struct sm120_fp4_config_default { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_256, _128, _128>; + using PerSmTileShape_MNK = Shape<_256, _128, _128>; +}; + +template +struct Fp4GemmSm120 { + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + using ElementD = OutType; + using ElementC = OutType; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using MmaTileShape = typename Config::MmaTileShape; + using 
ClusterShape = typename Config::ClusterShape; + using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + PerSmTileShape_MNK, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + void, + LayoutCTag, + AlignmentC, + ElementD, + LayoutDTag, + AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutATag, + AlignmentA, + ElementB, + LayoutBTag, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments args_from_options_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int M, + int N, + int K) { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementD = typename Gemm::ElementD; + using ElementSFA = cutlass::float_ue4m3_t; + using ElementSFB = cutlass::float_ue4m3_t; + using ElementCompute = float; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + auto stride_A = 
cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + {{}, nullptr, stride_D, static_cast(D.data_ptr()), stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + + return arguments; +} + +template +void runGemmSm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int M, + int N, + int K, + cudaStream_t stream) { + Gemm gemm; + + auto arguments = args_from_options_sm120(D, A, B, A_sf, B_sf, alpha, M, N, K); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + auto workspace_tensor = alloc_workspace_tensor(workspace_size, A.device()); + void* workspace = (workspace_size == 0) ? 
nullptr : workspace_tensor.data_ptr(); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); + + CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); +} + +void cutlass_fp4_bf16_gemm_dispatch_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int m, + int n, + int k, + cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + if (mp2 <= 32) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (mp2 <= 256) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +void cutlass_fp4_f16_gemm_dispatch_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int m, + int n, + int k, + cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + if (mp2 <= 32) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (mp2 <= 256) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) diff --git a/python/sglang/jit_kernel/csrc/hicache.cuh b/python/sglang/jit_kernel/csrc/hicache.cuh index 04f093a02ba6..ae297061136e 100644 --- a/python/sglang/jit_kernel/csrc/hicache.cuh +++ b/python/sglang/jit_kernel/csrc/hicache.cuh @@ -14,6 +14,11 @@ namespace device { namespace details { +template +struct LocalStorage { + T data[N]; +}; + template inline constexpr auto get_mem_package() { if constexpr (kUnit == 16) { @@ -78,7 +83,7 @@ SGL_DEVICE auto 
load_vec(const void* __restrict__ src) { static_assert(128 % kNumThreads == 0, "kNumThreads must divide 128 bytes"); constexpr uint32_t kLoopCount = kBytes / 128; using Package = details::PackageType<128 / kNumThreads>; - using Storage = AlignedStorage; + using Storage = details::LocalStorage; const auto src_packed = static_cast(src); const auto lane_id = threadIdx.x % kNumThreads; @@ -129,7 +134,13 @@ struct HicacheKernelParams { uint32_t num_layers = 0; // only used in all_layer transfer }; -template +template < + typename T, + int64_t kElementSize, + uint32_t kUnroll, + uint32_t kBlockQuota, + uint32_t kBlockSize, + bool kIsMLA = false> SGL_HICACHE_KERNEL void hicache_transfer_per_layer(const __grid_constant__ HicacheKernelParams params) { using namespace device; static_assert(kBlockSize % kWarpThreads == 0); @@ -151,16 +162,24 @@ SGL_HICACHE_KERNEL void hicache_transfer_per_layer(const __grid_constant__ Hicac const auto pos_dst = static_cast(indices_dst)[i]; const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride); const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride); - const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride); - const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride); const auto vec_k = load_vec(src_k); - const auto vec_v = load_vec(src_v); store_vec(dst_k, vec_k); - store_vec(dst_v, vec_v); + if constexpr (!kIsMLA) { + const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride); + const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride); + const auto vec_v = load_vec(src_v); + store_vec(dst_v, vec_v); + } } } -template +template < + typename T, + int64_t kElementSize, + uint32_t kUnroll, + uint32_t kBlockQuota, + uint32_t kBlockSize, + bool kIsMLA = false> SGL_HICACHE_KERNEL void hicache_transfer_all_layer(const __grid_constant__ HicacheKernelParams params) { using namespace device; using src_ptr_t = const 
void*; @@ -185,17 +204,19 @@ SGL_HICACHE_KERNEL void hicache_transfer_all_layer(const __grid_constant__ Hicac const auto pos_dst = static_cast(indices_dst)[i]; for (uint32_t layer = 0; layer < num_layers; ++layer) { const auto k_cache_src = static_cast(k_ptr_src)[layer]; - const auto v_cache_src = static_cast(v_ptr_src)[layer]; const auto k_cache_dst = static_cast(k_ptr_dst)[layer]; - const auto v_cache_dst = static_cast(v_ptr_dst)[layer]; const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride); const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride); - const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride); - const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride); const auto vec_k = load_vec(src_k); - const auto vec_v = load_vec(src_v); store_vec(dst_k, vec_k); - store_vec(dst_v, vec_v); + if constexpr (!kIsMLA) { + const auto v_cache_src = static_cast(v_ptr_src)[layer]; + const auto v_cache_dst = static_cast(v_ptr_dst)[layer]; + const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride); + const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride); + const auto vec_v = load_vec(src_v); + store_vec(dst_v, vec_v); + } } } } @@ -206,6 +227,12 @@ struct HiCacheKernel { static constexpr auto kernel_one = hicache_transfer_per_layer; template static constexpr auto kernel_all = hicache_transfer_all_layer; + template + static constexpr auto kernel_one_mla = + hicache_transfer_per_layer; + template + static constexpr auto kernel_all_mla = + hicache_transfer_all_layer; static void run_one( const tvm::ffi::TensorView k_cache_dst, @@ -333,6 +360,119 @@ struct HiCacheKernel { const auto kernel = use_int32 ? 
kernel_all : kernel_all; LaunchKernel(num_blocks, kBlockSize, device)(kernel, params); } + + static void run_one_mla( + const tvm::ffi::TensorView cache_dst, + const tvm::ffi::TensorView indices_dst, + const tvm::ffi::TensorView cache_src, + const tvm::ffi::TensorView indices_src) { + using namespace host; + + auto D = SymbolicSize{"head dimension"}; + auto N = SymbolicSize{"src stride"}; + auto M = SymbolicSize{"dst stride"}; + auto L = SymbolicSize{"indices length"}; + auto cache_dtype = SymbolicDType{}; + auto indices_dtype = SymbolicDType{}; + auto indices_device = SymbolicDevice{}; + + TensorMatcher({-1, D}) // + .with_strides({N, 1}) + .with_dtype(cache_dtype) + .with_device() + .verify(cache_src); + TensorMatcher({-1, D}) // + .with_strides({M, 1}) + .with_dtype(cache_dtype) + .with_device() + .verify(cache_dst); + TensorMatcher({L}) // + .with_dtype(indices_dtype) + .with_device(indices_device) + .verify(indices_src) + .verify(indices_dst); + + const auto dtype_size = dtype_bytes(cache_dtype.unwrap()); + const auto element_bytes = D.unwrap() * dtype_size; + RuntimeCheck(kElementSize == element_bytes, "HicacheKernel MLA: cache dimension mismatch."); + + const auto cache_dst_ptr = cache_dst.data_ptr(); + const auto cache_src_ptr = cache_src.data_ptr(); + const auto indices_dst_ptr = indices_dst.data_ptr(); + const auto indices_src_ptr = indices_src.data_ptr(); + const auto length = static_cast(L.unwrap()); + const auto cache_src_stride = static_cast(N.unwrap() * dtype_size); + const auto cache_dst_stride = static_cast(M.unwrap() * dtype_size); + const auto use_int32 = indices_dtype.unwrap().bits == 32; + const auto device = indices_device.unwrap(); + + constexpr auto kWorkersPerBlock = kBlockSize / (device::kWarpThreads / kUnroll); + const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota); + const auto params = HicacheKernelParams{ + .k_cache_dst = cache_dst_ptr, + .v_cache_dst = nullptr, + .indices_dst = indices_dst_ptr, + 
.k_cache_src = cache_src_ptr, + .v_cache_src = nullptr, + .indices_src = indices_src_ptr, + .kv_cache_src_stride = cache_src_stride, + .kv_cache_dst_stride = cache_dst_stride, + .length = length, + }; + const auto kernel = use_int32 ? kernel_one_mla : kernel_one_mla; + LaunchKernel(num_blocks, kBlockSize, device)(kernel, params); + } + + static void run_all_mla( + const tvm::ffi::TensorView ptr_dst, + const tvm::ffi::TensorView indices_dst, + const tvm::ffi::TensorView ptr_src, + const tvm::ffi::TensorView indices_src, + const int64_t src_stride_bytes, + const int64_t dst_stride_bytes) { + using namespace host; + + auto N = SymbolicSize{"num_layers"}; + auto L = SymbolicSize{"indices length"}; + auto dtype_ = SymbolicDType{}; + auto device_ = SymbolicDevice{}; + + TensorMatcher({N}) // + .with_dtype() + .with_device(device_) + .verify(ptr_src) + .verify(ptr_dst); + TensorMatcher({L}) // + .with_dtype(dtype_) + .with_device(device_) + .verify(indices_src) + .verify(indices_dst); + + const auto cache_dst_ptr = ptr_dst.data_ptr(); + const auto cache_src_ptr = ptr_src.data_ptr(); + const auto indices_dst_ptr = indices_dst.data_ptr(); + const auto indices_src_ptr = indices_src.data_ptr(); + const auto length = static_cast(L.unwrap()); + const auto use_int32 = dtype_.unwrap().bits == 32; + const auto device = device_.unwrap(); + + constexpr auto kWorkersPerBlock = kBlockSize / (device::kWarpThreads / kUnroll); + const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota); + const auto params = HicacheKernelParams{ + .k_cache_dst = cache_dst_ptr, + .v_cache_dst = nullptr, + .indices_dst = indices_dst_ptr, + .k_cache_src = cache_src_ptr, + .v_cache_src = nullptr, + .indices_src = indices_src_ptr, + .kv_cache_src_stride = src_stride_bytes, + .kv_cache_dst_stride = dst_stride_bytes, + .length = length, + .num_layers = static_cast(N.unwrap()), + }; + const auto kernel = use_int32 ? 
kernel_all_mla : kernel_all_mla; + LaunchKernel(num_blocks, kBlockSize, device)(kernel, params); + } }; #undef SGL_HICACHE_KERNEL diff --git a/python/sglang/jit_kernel/csrc/hisparse.cuh b/python/sglang/jit_kernel/csrc/hisparse.cuh index 3cf12178f243..2919b59ba6a4 100644 --- a/python/sglang/jit_kernel/csrc/hisparse.cuh +++ b/python/sglang/jit_kernel/csrc/hisparse.cuh @@ -53,11 +53,24 @@ __device__ __forceinline__ int warp_inclusive_scan(int* s_data, int lane_id, int return accumulator; } +// Shared memory size calculation for dynamic allocation. +// Layout: int32_t region (4-byte aligned) followed by int16_t region (2-byte aligned). +template +struct SmemLayout { + static constexpr int HASH_SIZE = NUM_TOP_K * 2; + static constexpr int NUM_BUFFER_CHUNKS = (HOT_BUFFER_SIZE + WARP_SIZE - 1) / WARP_SIZE; + // int32_t region: top_k_tokens + chunk_offset + evict_chunk_offset + hash_keys + total_hits + newest_hit + static constexpr int TOTAL_INT32 = NUM_TOP_K + (NUM_BUFFER_CHUNKS + 1) + (NUM_BUFFER_CHUNKS + 1) + HASH_SIZE + 2; + // int16_t region: lru_slots_out + hash_vals + static constexpr int TOTAL_INT16 = HOT_BUFFER_SIZE + HASH_SIZE; + static constexpr size_t BYTES = TOTAL_INT32 * sizeof(int32_t) + TOTAL_INT16 * sizeof(int16_t); +}; + // Each block processes one request -// req_pool_indices are int64_t (pool indices can be large), seq_lens are int32_t +// req_pool_indices are int64_t (pool indices can be large), seq_lens can be int32_t or int64_t // Layout: [HOT_BUFFER_SIZE slots for LRU] + [page_size slots for newest token] // newest_slot is at HOT_BUFFER_SIZE (first position of extra page) -template +template __global__ void load_cache_to_device_buffer_kernel( const int32_t* __restrict__ top_k_tokens, int32_t* __restrict__ device_buffer_tokens, @@ -69,7 +82,7 @@ __global__ void load_cache_to_device_buffer_kernel( void* __restrict__ device_buffer_v, int32_t* __restrict__ top_k_device_locs, const int64_t* __restrict__ req_pool_indices, - const int32_t* __restrict__ 
seq_lens, + const SeqLensT* __restrict__ seq_lens, int16_t* __restrict__ lru_slots, const int32_t* __restrict__ num_real_reqs, int64_t buffer_stride_0, @@ -118,21 +131,29 @@ __global__ void load_cache_to_device_buffer_kernel( return; } + // Dynamic shared memory layout: int32_t arrays first, then int16_t arrays. + extern __shared__ char smem_raw[]; + using Layout = SmemLayout; + constexpr int HASH_SIZE = Layout::HASH_SIZE; + + int32_t* smem_i32 = reinterpret_cast(smem_raw); // Top-k token positions; reused as miss-token scratch in the copy phase - __shared__ int32_t s_top_k_tokens[NUM_TOP_K]; + int32_t* s_top_k_tokens = smem_i32; // Prefix-sum offsets for hit counting and miss counting - __shared__ int32_t s_chunk_offset[NUM_BUFFER_CHUNKS + 1]; + int32_t* s_chunk_offset = s_top_k_tokens + NUM_TOP_K; // Prefix-sum offsets for evictable counting - __shared__ int32_t s_evict_chunk_offset[NUM_BUFFER_CHUNKS + 1]; + int32_t* s_evict_chunk_offset = s_chunk_offset + (NUM_BUFFER_CHUNKS + 1); + // Open-addressing hash table: top-k token_id → top-k index (keys) + int32_t* s_hash_keys = s_evict_chunk_offset + (NUM_BUFFER_CHUNKS + 1); + // Scalar counters + int32_t& s_total_hits = s_hash_keys[HASH_SIZE]; + int32_t& s_newest_hit = s_hash_keys[HASH_SIZE + 1]; + + int16_t* smem_i16 = reinterpret_cast(smem_i32 + Layout::TOTAL_INT32); // Compacted slot ordering: [hits fwd→ ... ←evictables bwd] - __shared__ int16_t s_lru_slots_out[HOT_BUFFER_SIZE]; - // Open-addressing hash table: top-k token_id → top-k index - constexpr int HASH_SIZE = NUM_TOP_K * 2; - __shared__ int32_t s_hash_keys[HASH_SIZE]; - __shared__ int16_t s_hash_vals[HASH_SIZE]; - - __shared__ int32_t s_total_hits; - __shared__ int32_t s_newest_hit; + int16_t* s_lru_slots_out = smem_i16; + // Open-addressing hash table: top-k token_id → top-k index (values) + int16_t* s_hash_vals = s_lru_slots_out + HOT_BUFFER_SIZE; // Initialize shared memory: counters, hash table, prefix-sum offsets. 
if (tid == 0) { @@ -363,28 +384,47 @@ void load_cache_to_device_buffer( const int64_t top_k_device_locs_stride = top_k_device_locs.strides()[0]; const auto device = LaunchKernel::resolve_device(top_k_tokens.device()); - LaunchKernel(bs, BLOCK_SIZE, device)( - load_cache_to_device_buffer_kernel, - static_cast(top_k_tokens.data_ptr()), - static_cast(device_buffer_tokens.data_ptr()), - static_cast(host_cache_locs.data_ptr()), - static_cast(device_buffer_locs.data_ptr()), - host_cache_k.data_ptr(), - (IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(), - device_buffer_k.data_ptr(), - (IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(), - static_cast(top_k_device_locs.data_ptr()), - static_cast(req_pool_indices.data_ptr()), - static_cast(seq_lens.data_ptr()), - static_cast(lru_slots.data_ptr()), - static_cast(num_real_reqs.data_ptr()), - buffer_stride_0, - host_stride, - lru_slot_stride_0, - top_k_tokens_stride, - top_k_device_locs_stride, - page_size, - item_size_bytes); + // Generic lambda: both int32 and int64 kernel variants are compiled; + // the correct one is selected at runtime based on seq_lens dtype. + auto launch = [&](auto kernel_fn, const auto* seq_lens_ptr) { + constexpr size_t smem_bytes = SmemLayout::BYTES; + if constexpr (smem_bytes > 48u * 1024u) { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + LaunchKernel(bs, BLOCK_SIZE, device, smem_bytes)( + kernel_fn, + static_cast(top_k_tokens.data_ptr()), + static_cast(device_buffer_tokens.data_ptr()), + static_cast(host_cache_locs.data_ptr()), + static_cast(device_buffer_locs.data_ptr()), + host_cache_k.data_ptr(), + (IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(), + device_buffer_k.data_ptr(), + (IsMLA || device_buffer_v.ndim() == 0) ? 
(void*)nullptr : device_buffer_v.data_ptr(), + static_cast(top_k_device_locs.data_ptr()), + static_cast(req_pool_indices.data_ptr()), + seq_lens_ptr, + static_cast(lru_slots.data_ptr()), + static_cast(num_real_reqs.data_ptr()), + buffer_stride_0, + host_stride, + lru_slot_stride_0, + top_k_tokens_stride, + top_k_device_locs_stride, + page_size, + item_size_bytes); + }; + + const auto dtype = seq_lens.dtype(); + if (dtype.code == kDLInt && dtype.bits == 64) { + launch( + load_cache_to_device_buffer_kernel, + static_cast(seq_lens.data_ptr())); + } else { + launch( + load_cache_to_device_buffer_kernel, + static_cast(seq_lens.data_ptr())); + } } } // namespace diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp similarity index 67% rename from python/sglang/srt/speculative/cpp_ngram/ngram.cpp rename to python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp index b1d54b964400..0b90fa812f34 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp @@ -1,11 +1,10 @@ #include "ngram.h" +#include "trie.h" #include #include #include -#include "trie.h" - namespace ngram { Ngram::Ngram(size_t capacity, const Param& param) : param_(param) { @@ -13,23 +12,6 @@ Ngram::Ngram(size_t capacity, const Param& param) : param_(param) { throw std::runtime_error( "param_.max_trie_depth must be greater than 1, current value: " + std::to_string(param_.max_trie_depth)); } - if (!(param_.min_match_window_size > 0)) { - throw std::runtime_error( - "min_match_window_size must be greater than 0, current value: " + std::to_string(param_.min_match_window_size)); - } - if (!(param_.min_match_window_size <= param_.max_match_window_size)) { - throw std::runtime_error( - "min_match_window_size must be less than or equal to " - "max_match_window_size, current min_match_window_size: " + - std::to_string(param_.min_match_window_size) + - ", max_match_window_size: " + 
std::to_string(param_.max_match_window_size)); - } - if (!(param_.max_match_window_size < param_.max_trie_depth)) { - throw std::runtime_error( - "max_match_window_size must be less than max_trie_depth, current " - "max_match_window_size: " + - std::to_string(param_.max_match_window_size) + ", max_trie_depth: " + std::to_string(param_.max_trie_depth)); - } if (!(param_.min_bfs_breadth > 0)) { throw std::runtime_error( "min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth)); @@ -53,20 +35,6 @@ Ngram::Ngram(size_t capacity, const Param& param) : param_(param) { } } } - for (auto config : param_.batch_min_match_window_size) { - if (config != std::numeric_limits::max()) { - if (!(config >= param_.min_match_window_size)) { - throw std::runtime_error( - "batch_min_match_window_size config value " + std::to_string(config) + - " must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size)); - } - if (!(config <= param_.max_match_window_size)) { - throw std::runtime_error( - "batch_min_match_window_size config value " + std::to_string(config) + - " must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size)); - } - } - } trie_ = std::make_unique(capacity, param_); diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.h b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h similarity index 99% rename from python/sglang/srt/speculative/cpp_ngram/ngram.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h index 377b481ae3fe..fb1461d9ac92 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram.h +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h @@ -1,5 +1,9 @@ #pragma once +#include "param.h" +#include "queue.h" +#include "result.h" +#include "trie.h" #include #include #include @@ -8,11 +12,6 @@ #include #include -#include "param.h" -#include "queue.h" -#include "result.h" -#include "trie.h" - namespace ngram { class Ngram { 
diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp new file mode 100644 index 000000000000..e1797e1fc0f3 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp @@ -0,0 +1,104 @@ +#pragma once + +#include +#include + +#include + +#include "ngram.h" +#include +#include +#include +#include +#include + +struct NgramCorpusObj : public tvm::ffi::Object { + public: + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.NgramCorpus", NgramCorpusObj, tvm::ffi::Object); + static constexpr bool _type_mutable = true; + + NgramCorpusObj( + int64_t capacity, + int64_t max_trie_depth, + int64_t min_bfs_breadth, + int64_t max_bfs_breadth, + int64_t draft_token_num, + int64_t match_type) { + ngram::Param param; + param.enable = true; + param.enable_router_mode = false; + param.max_trie_depth = static_cast(max_trie_depth); + param.min_bfs_breadth = static_cast(min_bfs_breadth); + param.max_bfs_breadth = static_cast(max_bfs_breadth); + param.draft_token_num = static_cast(draft_token_num); + param.match_type = (match_type == 0) ? 
"BFS" : "PROB"; + ngram_ = std::make_unique(static_cast(capacity), param); + } + + void async_insert(const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) { + auto* data = static_cast(tokens_flat.data_ptr()); + auto* offs = static_cast(offsets.data_ptr()); + int64_t batch_size = offsets.size(0) - 1; + + std::vector> tokens(batch_size); + for (int64_t i = 0; i < batch_size; ++i) { + tokens[i].assign(data + offs[i], data + offs[i + 1]); + } + ngram_->asyncInsert(std::move(tokens)); + } + + void batch_match( + const tvm::ffi::TensorView tokens_flat, + const tvm::ffi::TensorView offsets, + const tvm::ffi::TensorView out_tokens, + const tvm::ffi::TensorView out_mask) { + auto* data = static_cast(tokens_flat.data_ptr()); + auto* offs = static_cast(offsets.data_ptr()); + int64_t batch_size = offsets.size(0) - 1; + + std::vector> tokens(batch_size); + for (int64_t i = 0; i < batch_size; ++i) { + tokens[i].assign(data + offs[i], data + offs[i + 1]); + } + + auto result = ngram_->batchMatch(tokens); + + auto* out_tok = static_cast(out_tokens.data_ptr()); + auto* out_msk = static_cast(out_mask.data_ptr()); + if (result.token.size() > static_cast(out_tokens.size(0))) { + throw std::runtime_error( + "out_tokens buffer too small: " + std::to_string(out_tokens.size(0)) + " < " + + std::to_string(result.token.size())); + } + if (result.mask.size() > static_cast(out_mask.size(0))) { + throw std::runtime_error( + "out_mask buffer too small: " + std::to_string(out_mask.size(0)) + " < " + + std::to_string(result.mask.size())); + } + std::memcpy(out_tok, result.token.data(), result.token.size() * sizeof(int32_t)); + std::memcpy(out_msk, result.mask.data(), result.mask.size() * sizeof(uint8_t)); + } + + void synchronize() { + ngram_->synchronize(); + } + + void reset() { + ngram_->reset(); + } + + private: + std::unique_ptr ngram_; +}; + +void register_ngram_corpus() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def(refl::init(), "__init__") + 
.def("async_insert", &NgramCorpusObj::async_insert) + .def("batch_match", &NgramCorpusObj::batch_match) + .def("synchronize", &NgramCorpusObj::synchronize) + .def("reset", &NgramCorpusObj::reset); +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(register_once, register_ngram_corpus); diff --git a/python/sglang/srt/speculative/cpp_ngram/param.h b/python/sglang/jit_kernel/csrc/ngram_corpus/param.h similarity index 78% rename from python/sglang/srt/speculative/cpp_ngram/param.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/param.h index d31af64ba5b9..725f635db8cd 100644 --- a/python/sglang/srt/speculative/cpp_ngram/param.h +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/param.h @@ -17,13 +17,10 @@ struct Param { bool enable_router_mode; size_t min_bfs_breadth; size_t max_bfs_breadth; - size_t min_match_window_size; - size_t max_match_window_size; size_t max_trie_depth; size_t draft_token_num; std::string match_type; - std::vector batch_min_match_window_size; std::vector batch_draft_token_num; size_t get_draft_token_num(size_t batch_size) const { @@ -36,16 +33,6 @@ struct Param { return draft_token_num - 1; } - size_t get_min_match_window_size(size_t batch_size) const { - if (batch_size < batch_min_match_window_size.size()) { - if (batch_min_match_window_size[batch_size] != - std::numeric_limits::max()) { - return batch_min_match_window_size[batch_size]; - } - } - return min_match_window_size; - } - std::vector parse(const std::string& value) { // 0-1|10,2-3|20, std::vector result; @@ -96,10 +83,6 @@ struct Param { return result; } - void resetBatchMinMatchWindowSize(const std::string& value) { - batch_min_match_window_size = parse(value); - } - void resetBatchReturnTokenNum(const std::string& value) { batch_draft_token_num = parse(value); } @@ -108,13 +91,8 @@ struct Param { std::stringstream ss; ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode << ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth - << ", 
min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size << ", max_trie_depth = " << max_trie_depth << ", draft_token_num = " << draft_token_num << ", match_type = " << match_type; - ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = "; - for (int i = 0; i < batch_min_match_window_size.size(); ++i) { - ss << i << "|" << batch_min_match_window_size[i] << ","; - } ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = "; for (int i = 0; i < batch_draft_token_num.size(); ++i) { ss << i << "|" << batch_draft_token_num[i] << ","; diff --git a/python/sglang/srt/speculative/cpp_ngram/queue.h b/python/sglang/jit_kernel/csrc/ngram_corpus/queue.h similarity index 100% rename from python/sglang/srt/speculative/cpp_ngram/queue.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/queue.h diff --git a/python/sglang/srt/speculative/cpp_ngram/result.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp similarity index 100% rename from python/sglang/srt/speculative/cpp_ngram/result.cpp rename to python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp diff --git a/python/sglang/srt/speculative/cpp_ngram/result.h b/python/sglang/jit_kernel/csrc/ngram_corpus/result.h similarity index 100% rename from python/sglang/srt/speculative/cpp_ngram/result.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/result.h diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp similarity index 86% rename from python/sglang/srt/speculative/cpp_ngram/trie.cpp rename to python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp index 8d9eec82b97e..67058eccb589 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.cpp +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp @@ -19,7 +19,7 @@ Trie::Trie(size_t capacity, const Param& param) : param_(param) { } void Trie::insert(const int32_t* tokens, size_t len) { - for (size_t i = 0; i + 
param_.min_match_window_size < len; ++i) { + for (size_t i = 0; i < len; ++i) { auto start = tokens + i; auto end = start + std::min(len - i, param_.max_trie_depth); @@ -100,14 +100,13 @@ void Trie::reset() { root_ = getNode(); } -std::vector> -Trie::match(const int32_t* context, size_t len, size_t min_window, size_t max_window) const { +std::vector> Trie::match(const int32_t* context, size_t len) const { std::vector> result; - result.reserve(max_window - min_window); - for (int32_t match_window_size = std::min(len, max_window); match_window_size >= static_cast(min_window); - --match_window_size) { - auto start = context + len - match_window_size; - auto end = start + match_window_size; + const auto max_match_depth = std::min(len, param_.max_trie_depth); + result.reserve(max_match_depth); + for (size_t match_depth = max_match_depth; match_depth > 0; --match_depth) { + auto start = context + len - match_depth; + auto end = start + match_depth; auto cursor = root_; while (start != end) { auto iter = cursor->child.find(*start); @@ -118,8 +117,8 @@ Trie::match(const int32_t* context, size_t len, size_t min_window, size_t max_wi ++start; cursor = iter->second; } - if (cursor) { - result.emplace_back(std::make_pair(cursor, match_window_size)); + if (cursor != nullptr && !cursor->child.empty()) { + result.emplace_back(cursor, static_cast(match_depth)); } } return result; @@ -127,10 +126,10 @@ Trie::match(const int32_t* context, size_t len, size_t min_window, size_t max_wi Result Trie::buildRecency( const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const { - auto anchors = match(context, len, param.min_match_window_size, param.max_match_window_size); + auto anchors = match(context, len); - double bfs_breadth_scale = double(param.max_bfs_breadth - param.min_bfs_breadth) / - (param.max_match_window_size - param.min_match_window_size + 1); + const auto max_match_depth = std::max(1, static_cast(param.max_trie_depth - 1)); + 
double bfs_breadth_scale = double(param.max_bfs_breadth - param.min_bfs_breadth) / max_match_depth; std::vector tree(draft_token_num + 1); int root = 0; @@ -138,7 +137,7 @@ Result Trie::buildRecency( for (auto [node, depth] : anchors) { std::queue> queue; - queue.push({root, (param.max_match_window_size - depth) * bfs_breadth_scale + param.min_bfs_breadth, node}); + queue.push({root, (max_match_depth - depth) * bfs_breadth_scale + param.min_bfs_breadth, node}); while (queue.size() && cursor <= static_cast(draft_token_num)) { auto front = queue.front(); queue.pop(); @@ -168,7 +167,7 @@ Result Trie::buildRecency( Result Trie::buildFrequency( const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const { - auto anchors = match(context, len, param.min_match_window_size, param.max_match_window_size); + auto anchors = match(context, len); struct CompareByLastDouble { bool operator()( diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.h b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h similarity index 92% rename from python/sglang/srt/speculative/cpp_ngram/trie.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/trie.h index 30db5b29400c..bd555597dd46 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.h +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h @@ -1,5 +1,7 @@ #pragma once +#include "param.h" +#include "result.h" #include #include #include @@ -10,9 +12,6 @@ #include #include -#include "param.h" -#include "result.h" - namespace ngram { struct TrieNode { @@ -49,8 +48,7 @@ class Trie { void reset(); private: - std::vector> - match(const int32_t* context, size_t len, size_t min_window, size_t max_window) const; + std::vector> match(const int32_t* context, size_t len) const; TrieNode* getNode() { auto node = node_pool_[--free_node_count_]; diff --git a/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py b/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py index 
9cb65a9d0b2f..534ed47eac67 100644 --- a/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py +++ b/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py @@ -1,4 +1,8 @@ import torch +import torch_npu + +NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000 +NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896 # TODO: remove this when triton ascend bug is fixed @@ -18,6 +22,23 @@ def apply_rotary_embedding_native( ) -> torch.Tensor: cos = cos.unsqueeze(-2).to(x.dtype) sin = sin.unsqueeze(-2).to(x.dtype) + + if ( + cos.dim() == 3 + and x.dim() == 3 + and x.shape[1] < NPU_ROTARY_MUL_MAX_NUM_HEADS + and x.shape[2] < NPU_ROTARY_MUL_MAX_HEAD_SIZE + ): + if cos.size(-1) * 2 == x.size(-1): + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([sin, sin], dim=-1) + cos = cos.unsqueeze(0) + sin = sin.unsqueeze(0) + x = x.unsqueeze(0) + x_embed = torch_npu.npu_rotary_mul(x, cos, sin) + x_embed = x_embed.squeeze(0) + return x_embed + x1 = x[..., ::2] x2 = x[..., 1::2] o1 = x1 * cos - x2 * sin diff --git a/python/sglang/jit_kernel/diffusion/triton/scale_shift.py b/python/sglang/jit_kernel/diffusion/triton/scale_shift.py index 9768b06c34db..1c9ca007d1ec 100644 --- a/python/sglang/jit_kernel/diffusion/triton/scale_shift.py +++ b/python/sglang/jit_kernel/diffusion/triton/scale_shift.py @@ -79,13 +79,20 @@ def _fused_layernorm_scale_shift_gate_select01_kernel( shift1_ptrs = shift1_ptr + batch_idx * stride_sh1_b + cols * stride_sh1_c gate1_ptrs = gate1_ptr + batch_idx * stride_g1_b + cols * stride_g1_c - scale_ptrs = tl.where(idx, scale1_ptrs, scale0_ptrs) - shift_ptrs = tl.where(idx, shift1_ptrs, shift0_ptrs) - gate_ptrs = tl.where(idx, gate1_ptrs, gate0_ptrs) - - scale = tl.load(scale_ptrs, mask=mask, other=0.0).to(tl.float32) - shift = tl.load(shift_ptrs, mask=mask, other=0.0).to(tl.float32) - gate = tl.load(gate_ptrs, mask=mask, other=0.0) + # Branch on scalar idx instead of using tl.where on pointers. 
+ # tl.where on pointers triggers an assertion in AMD Triton's + # CanonicalizePointers pass (ConvertArithSelectOp) on gfx950. + # This keeps it at 3 loads (not 6), avoids the pointer-level + # tl.where entirely, and since idx is uniform across all threads + # the branch has no divergence cost. + if idx: + scale = tl.load(scale1_ptrs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift1_ptrs, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate1_ptrs, mask=mask, other=0.0) + else: + scale = tl.load(scale0_ptrs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift0_ptrs, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate0_ptrs, mask=mask, other=0.0) y = x_hat * (1.0 + scale) + shift tl.store(out_row_ptr + cols, y, mask=mask) @@ -180,13 +187,20 @@ def _fused_residual_layernorm_scale_shift_gate_select01_kernel( shift1_ptrs = shift1_ptr + batch_idx * stride_sh1_b + cols * stride_sh1_c gate1_ptrs = gate1_ptr + batch_idx * stride_g1_b + cols * stride_g1_c - scale_ptrs = tl.where(idx, scale1_ptrs, scale0_ptrs) - shift_ptrs = tl.where(idx, shift1_ptrs, shift0_ptrs) - gate_ptrs = tl.where(idx, gate1_ptrs, gate0_ptrs) - - scale = tl.load(scale_ptrs, mask=mask, other=0.0).to(tl.float32) - shift = tl.load(shift_ptrs, mask=mask, other=0.0).to(tl.float32) - gate = tl.load(gate_ptrs, mask=mask, other=0.0) + # Branch on scalar idx instead of using tl.where on pointers. + # tl.where on pointers triggers an assertion in AMD Triton's + # CanonicalizePointers pass (ConvertArithSelectOp) on gfx950. + # This keeps it at 3 loads (not 6), avoids the pointer-level + # tl.where entirely, and since idx is uniform across all threads + # the branch has no divergence cost. 
+ if idx: + scale = tl.load(scale1_ptrs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift1_ptrs, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate1_ptrs, mask=mask, other=0.0) + else: + scale = tl.load(scale0_ptrs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift0_ptrs, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate0_ptrs, mask=mask, other=0.0) y = x_hat * (1.0 + scale) + shift tl.store(out_row_ptr + cols, y, mask=mask) diff --git a/python/sglang/jit_kernel/fused_qknorm_rope.py b/python/sglang/jit_kernel/fused_qknorm_rope.py index 92ea1f4350ad..00e872020709 100644 --- a/python/sglang/jit_kernel/fused_qknorm_rope.py +++ b/python/sglang/jit_kernel/fused_qknorm_rope.py @@ -13,17 +13,20 @@ @cache_once -def _jit_fused_qknorm_rope_module(head_dim: int, is_neox: bool) -> Module: - interleave = "false" if is_neox else "true" +def _jit_fused_qknorm_rope_module(head_dim: int, is_neox: bool, yarn: bool) -> Module: return load_jit( "fused_qknorm_rope", head_dim, int(is_neox), + int(yarn), cuda_files=["elementwise/fused_qknorm_rope.cuh"], - cuda_wrappers=[ - ("fused_qk_norm_rope", f"fused_qk_norm_rope<{head_dim}, {interleave}>") + cuda_wrappers=[("fused_qk_norm_rope", "fused_qk_norm_rope")], + extra_cuda_cflags=[ + "--use_fast_math", + f"-DJIT_HEAD_DIM={head_dim}", + f"-DJIT_INTERLEAVE={0 if is_neox else 1}", + f"-DJIT_YARN={1 if yarn else 0}", ], - extra_cuda_cflags=["--use_fast_math"], ) @@ -55,9 +58,9 @@ def fused_qk_norm_rope_out( Matches the call signature of ``sgl_kernel.fused_qk_norm_rope``. 
Args: - qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 -modified in-place - q_weight: [head_dim] bfloat16 -RMSNorm weights for Q - k_weight: [head_dim] bfloat16 -RMSNorm weights for K + qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 — modified in-place + q_weight: [head_dim] bfloat16 — RMSNorm weights for Q + k_weight: [head_dim] bfloat16 — RMSNorm weights for K position_ids: [num_tokens] int32 num_heads_q: number of query heads num_heads_k: number of key heads @@ -65,14 +68,15 @@ def fused_qk_norm_rope_out( head_dim: head dimension; must be 64, 128, or 256 eps: epsilon for RMSNorm base: RoPE base frequency - is_neox: True ->NeoX style, False ->interleave (GPT-J) style + is_neox: True → NeoX style, False → interleave (GPT-J) style factor: YaRN scaling factor (1.0 = standard RoPE) low: YaRN low-frequency threshold high: YaRN high-frequency threshold attention_factor: scale applied to the rotary component rotary_dim: number of elements per head to apply RoPE to """ - module = _jit_fused_qknorm_rope_module(head_dim, is_neox) + yarn = factor != 1.0 + module = _jit_fused_qknorm_rope_module(head_dim, is_neox, yarn) module.fused_qk_norm_rope( qkv, q_weight, @@ -81,8 +85,10 @@ def fused_qk_norm_rope_out( num_heads_q, num_heads_k, num_heads_v, + head_dim, eps, base, + 1 if is_neox else 0, factor, low, high, @@ -93,13 +99,16 @@ def fused_qk_norm_rope_out( @cache_once def can_use_fused_qk_norm_rope( - head_dim: int, is_neox: bool, dtype: torch.dtype + head_dim: int, is_neox: bool, dtype: torch.dtype, yarn: bool = False ) -> bool: """Return True if the JIT fused QK-Norm + RoPE kernel can be used. Args: head_dim: head dimension; supported values are 64, 128, 256 dtype: tensor dtype; only bfloat16 is supported + yarn: whether YaRN scaling is active (factor != 1.0); prebuilds the + correct kernel variant so no extra JIT compile occurs on the + first real call. 
""" logger = logging.getLogger(__name__) if head_dim not in (64, 128, 256): @@ -111,7 +120,7 @@ def can_use_fused_qk_norm_rope( logger.warning(f"Unsupported dtype={dtype} for JIT fused_qk_norm_rope kernel") return False try: - _jit_fused_qknorm_rope_module(head_dim, is_neox) + _jit_fused_qknorm_rope_module(head_dim, is_neox, yarn) return True except Exception as e: logger.warning(f"Failed to load JIT fused_qk_norm_rope kernel: {e}") @@ -142,16 +151,16 @@ def fused_qk_norm_rope( Matches the call signature of ``sgl_kernel.fused_qk_norm_rope``. Args: - qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 -modified in-place + qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 — modified in-place num_heads_q: number of query heads num_heads_k: number of key heads num_heads_v: number of value heads head_dim: head dimension; must be 64, 128, or 256 eps: epsilon for RMSNorm - q_weight: [head_dim] bfloat16 -RMSNorm weights for Q - k_weight: [head_dim] bfloat16 -RMSNorm weights for K + q_weight: [head_dim] bfloat16 — RMSNorm weights for Q + k_weight: [head_dim] bfloat16 — RMSNorm weights for K base: RoPE base frequency - is_neox: True ->NeoX style, False ->interleave (GPT-J) style + is_neox: True → NeoX style, False → interleave (GPT-J) style position_ids: [num_tokens] int32 factor: YaRN scaling factor (1.0 = standard RoPE) low: YaRN low-frequency threshold diff --git a/python/sglang/jit_kernel/hicache.py b/python/sglang/jit_kernel/hicache.py index 7a357790147b..0e5ed5802fc2 100644 --- a/python/sglang/jit_kernel/hicache.py +++ b/python/sglang/jit_kernel/hicache.py @@ -28,6 +28,8 @@ def _jit_hicache_module(*, element_size: int, unroll: int, block_quota: int) -> cuda_wrappers=[ ("launch_one", f"&HiCacheKernel<{args}>::run_one"), ("launch_all", f"&HiCacheKernel<{args}>::run_all"), + ("launch_one_mla", f"&HiCacheKernel<{args}>::run_one_mla"), + ("launch_all_mla", f"&HiCacheKernel<{args}>::run_all_mla"), ], ) @@ -139,3 +141,65 @@ def transfer_hicache_all_layer( kv_cache_src_stride_bytes, 
kv_cache_dst_stride_bytes, ) + + +def transfer_hicache_one_layer_mla( + cache_dst: torch.Tensor, + indices_dst: torch.Tensor, + cache_src: torch.Tensor, + indices_src: torch.Tensor, + *, + element_dim: int | None = None, + unroll: int | None = None, + block_quota: int | None = None, +) -> None: + element_dim = element_dim or cache_dst.size(-1) + cache_src = cache_src.view(-1, element_dim) + cache_dst = cache_dst.view(-1, element_dim) + element_size = element_dim * cache_dst.element_size() + block_quota = block_quota or DEFAULT_BLOCK_QUOTA + unroll = unroll or _default_unroll(element_size) + module = _jit_hicache_module( + element_size=element_size, + unroll=unroll, + block_quota=block_quota, + ) + module.launch_one_mla( + cache_dst, + indices_dst, + cache_src, + indices_src, + ) + + +def transfer_hicache_all_layer_mla( + ptr_dst: torch.Tensor, + indices_dst: torch.Tensor, + ptr_src: torch.Tensor, + indices_src: torch.Tensor, + *, + cache_src_stride_bytes: int, + cache_dst_stride_bytes: int, + element_size: int | None = None, + unroll: int | None = None, + block_quota: int | None = None, +) -> None: + if element_size is None: + assert cache_dst_stride_bytes == cache_src_stride_bytes + element_size = cache_dst_stride_bytes + + block_quota = block_quota or DEFAULT_BLOCK_QUOTA + unroll = unroll or _default_unroll(element_size) + module = _jit_hicache_module( + element_size=element_size, + unroll=unroll, + block_quota=block_quota, + ) + module.launch_all_mla( + ptr_dst, + indices_dst, + ptr_src, + indices_src, + cache_src_stride_bytes, + cache_dst_stride_bytes, + ) diff --git a/python/sglang/jit_kernel/moe_wna16_marlin.py b/python/sglang/jit_kernel/moe_wna16_marlin.py index e9a8cd25372b..0ddd6ef717d5 100644 --- a/python/sglang/jit_kernel/moe_wna16_marlin.py +++ b/python/sglang/jit_kernel/moe_wna16_marlin.py @@ -31,6 +31,24 @@ def _jit_moe_wna16_marlin_module(dtype: torch.dtype) -> Module: ) +@cache_once +def _jit_moe_wna16_marlin_fp4_module(dtype: torch.dtype) -> 
Module: + """Separate JIT module with NVFP4 (kFE2M1f) kernel instantiations enabled.""" + args = make_cpp_args(dtype) + return load_jit( + "moe_wna16_marlin_fp4", + *args, + cuda_files=["gemm/marlin_moe/moe_wna16_marlin.cuh"], + extra_cuda_cflags=["-DSGL_MOE_MARLIN_FP4"], + cuda_wrappers=[ + ( + "moe_wna16_marlin_gemm", + f"moe_wna16_marlin_gemm<{args}>", + ) + ], + ) + + def _or_empty( t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype ) -> torch.Tensor: @@ -134,7 +152,11 @@ def moe_wna16_marlin_gemm( b_bias_t = _or_empty(b_bias_or_none, device, a.dtype) global_scale_t = _or_empty(global_scale_or_none, device, a.dtype) - module = _jit_moe_wna16_marlin_module(a.dtype) + is_fp4 = global_scale_or_none is not None and global_scale_or_none.numel() > 0 + if is_fp4: + module = _jit_moe_wna16_marlin_fp4_module(a.dtype) + else: + module = _jit_moe_wna16_marlin_module(a.dtype) module.moe_wna16_marlin_gemm( a, c, diff --git a/python/sglang/jit_kernel/ngram_corpus.py b/python/sglang/jit_kernel/ngram_corpus.py new file mode 100644 index 000000000000..42b6babab0dd --- /dev/null +++ b/python/sglang/jit_kernel/ngram_corpus.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import List, Tuple + +import numpy as np +import torch +import tvm_ffi + +from sglang.jit_kernel.utils import cache_once, load_jit + +_MATCH_TYPE_MAP = {"BFS": 0, "PROB": 1} + + +def _to_csr(batch_tokens: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]: + flat = [] + offsets = [0] + for seq in batch_tokens: + flat.extend(seq) + offsets.append(len(flat)) + tokens_flat = torch.tensor(flat, dtype=torch.int32) + offsets_t = torch.tensor(offsets, dtype=torch.int64) + return tokens_flat, offsets_t + + +@cache_once +def get_ngram_corpus_cls(): + module = load_jit( + "ngram_corpus", + cpp_files=[ + "ngram_corpus/result.cpp", + "ngram_corpus/trie.cpp", + "ngram_corpus/ngram.cpp", + "ngram_corpus/ngram_corpus_ffi.cpp", + ], + header_only=False, + ) + module.register_once() + + 
@tvm_ffi.register_object("sgl.NgramCorpus") + class NgramCorpusFFI(tvm_ffi.Object): + __slots__ = ("__dict__",) + + def __init__( + self, + capacity: int, + max_trie_depth: int, + min_bfs_breadth: int, + max_bfs_breadth: int, + draft_token_num: int, + match_type: str, + ) -> None: + mt = _MATCH_TYPE_MAP.get(match_type) + if mt is None: + raise ValueError( + f"Unknown match_type: '{match_type}'. Must be 'BFS' or 'PROB'." + ) + self.__ffi_init__( + capacity, + max_trie_depth, + min_bfs_breadth, + max_bfs_breadth, + draft_token_num, + mt, + ) + self._draft_token_num = draft_token_num + + def insert(self, batch_tokens: List[List[int]]) -> None: + tokens_flat, offsets = _to_csr(batch_tokens) + self.async_insert(tokens_flat, offsets) # type: ignore + + def match( + self, + batch_tokens: List[List[int]], + ) -> Tuple[np.ndarray, np.ndarray]: + tokens_flat, offsets = _to_csr(batch_tokens) + batch_size = len(batch_tokens) + d = self._draft_token_num + + out_tokens = torch.zeros(batch_size * d, dtype=torch.int32) + out_mask = torch.zeros(batch_size * d * d, dtype=torch.uint8) + + self.batch_match(tokens_flat, offsets, out_tokens, out_mask) # type: ignore + + return out_tokens.numpy().astype(np.int64), out_mask.numpy().astype( + np.int64 + ) + + return NgramCorpusFFI diff --git a/python/sglang/jit_kernel/norm.py b/python/sglang/jit_kernel/norm.py index 606358dd1d97..25b4a5f2c1b2 100644 --- a/python/sglang/jit_kernel/norm.py +++ b/python/sglang/jit_kernel/norm.py @@ -32,20 +32,23 @@ def _jit_qknorm_module(head_dim: int, dtype: torch.dtype) -> Module: _RMSNORM_WARP_SIZES = frozenset({64, 128, 256}) -_RMSNORM_MAX_HIDDEN_SIZE = 8192 +_RMSNORM_MAX_HIDDEN_SIZE = 16384 +_RMSNORM_HALF_BLOCK_MIN_SIZE = 2048 -def _is_supported_rmsnorm_hidden_size(hidden_size: int) -> bool: - return hidden_size in _RMSNORM_WARP_SIZES or ( - hidden_size > 256 - and hidden_size % 256 == 0 - and hidden_size <= _RMSNORM_MAX_HIDDEN_SIZE +def _is_supported_rmsnorm_hidden_size(d: int) -> bool: + return d in 
_RMSNORM_WARP_SIZES or ( + (d > 256 and d % 256 == 0 and d <= 8192) + or (d >= 8192 and d % 512 == 0 and d <= 16384) ) def _rmsnorm_kernel_class(hidden_size: int) -> str: if hidden_size in _RMSNORM_WARP_SIZES: return "RMSNormWarpKernel" + if hidden_size >= _RMSNORM_HALF_BLOCK_MIN_SIZE: + if hidden_size % 512 == 0: + return "RMSNormHalfKernel" return "RMSNormKernel" @@ -118,10 +121,10 @@ def fused_inplace_qknorm( def rmsnorm( input: torch.Tensor, weight: torch.Tensor, - output: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, eps: float = 1e-6, ) -> None: - output = output if output is not None else input + out = out if out is not None else input hidden_size = input.size(-1) if not _is_supported_rmsnorm_hidden_size(hidden_size): raise RuntimeError( @@ -130,7 +133,7 @@ def rmsnorm( f"(256, {_RMSNORM_MAX_HIDDEN_SIZE}]." ) module = _jit_rmsnorm_module(hidden_size, input.dtype) - module.rmsnorm(input, weight, output, eps) + module.rmsnorm(input, weight, out, eps) @debug_kernel_api diff --git a/python/sglang/jit_kernel/tests/test_cast.py b/python/sglang/jit_kernel/tests/test_cast.py index 6a71dc194214..a63b4023c45f 100644 --- a/python/sglang/jit_kernel/tests/test_cast.py +++ b/python/sglang/jit_kernel/tests/test_cast.py @@ -2,6 +2,10 @@ import torch from sglang.jit_kernel.cast import downcast_fp8 +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=15, suite="stage-b-kernel-unit-1-gpu-large") +register_cuda_ci(est_time=120, suite="nightly-kernel-1-gpu", nightly=True) DTYPES = [torch.bfloat16, torch.float16] diff --git a/python/sglang/jit_kernel/tests/test_custom_all_reduce.py b/python/sglang/jit_kernel/tests/test_custom_all_reduce.py index 365761ddfe88..2d7e0253eb43 100644 --- a/python/sglang/jit_kernel/tests/test_custom_all_reduce.py +++ b/python/sglang/jit_kernel/tests/test_custom_all_reduce.py @@ -16,29 +16,33 @@ import itertools import logging +import multiprocessing as mp import os import subprocess import sys 
-from typing import Optional +from typing import Dict, Optional, Tuple import pytest import torch import torch.distributed as dist -from tqdm import tqdm import sglang.srt.distributed.parallel_state as ps -from sglang.jit_kernel.all_reduce import AllReduceAlgo +from sglang.jit_kernel.all_reduce import ( + AllReduceAlgo, + _jit_custom_all_reduce_pull_module, + _jit_custom_all_reduce_push_module, +) from sglang.srt.distributed.device_communicators.custom_all_reduce_v2 import ( CustomAllReduceV2, ) from sglang.test.ci.ci_register import register_cuda_ci register_cuda_ci( - est_time=500, + est_time=300, suite="stage-b-kernel-unit-8-gpu-h200", ) register_cuda_ci( - est_time=500, + est_time=300, suite="nightly-kernel-8-gpu-h200", nightly=True, ) @@ -67,7 +71,7 @@ ] USE_GRAPH_OPTIONS = [True, False] TEST_CONFIG = itertools.product(TEST_SIZES, TEST_DTYPES, SHOTS, USE_GRAPH_OPTIONS) -TEST_LAYERS = 2 +TEST_LAYERS = 4 TEST_LOOP = 16 # --------------------------------------------------------------------------- @@ -75,14 +79,13 @@ # --------------------------------------------------------------------------- -def run_torchrun(nproc: int, timeout: int = 300) -> None: +def _run_torchrun(nproc: int, timeout: int = 300) -> None: """Launch this script as a torchrun worker and assert success.""" cmd = [ "torchrun", f"--nproc_per_node={nproc}", __file__, ] - os.environ["DISABLE_PBAR"] = "1" result = subprocess.run( cmd, stdout=subprocess.PIPE, @@ -96,14 +99,37 @@ def run_torchrun(nproc: int, timeout: int = 300) -> None: ) -@pytest.mark.parametrize("nproc", [2, 3, 4, 5, 6, 7, 8]) +def _compile_one(dtype: torch.dtype, world_size: int): + _jit_custom_all_reduce_push_module(dtype, world_size) + _jit_custom_all_reduce_pull_module(dtype, world_size) + + +def _precompile_kernels() -> None: + # NOTE: even when device count < 8, we should be able to compile all + process_map: Dict[Tuple[torch.dtype, int], mp.Process] = {} + COMPILE_SPACE = itertools.product(TEST_DTYPES, [2, 3, 4, 5, 6, 7, 8]) + 
mp.set_start_method("spawn") + for config in COMPILE_SPACE: + process_map[config] = mp.Process(target=_compile_one, args=config) + for process in process_map.values(): + process.start() + for (dtype, world_size), process in process_map.items(): + process.join() + if process.exitcode != 0: + raise RuntimeError(f"Custom All Reduce {world_size=} {dtype=} failed") + + +@pytest.mark.parametrize("nproc", [1, 2, 3, 4, 5, 6, 7, 8]) def test_custom_allreduce(nproc: int) -> None: + if nproc == 1: # NOTE: special case to speed up tests + return _precompile_kernels() + device_count = torch.cuda.device_count() if device_count < nproc: pytest.skip( f"Requires at least {nproc} GPUs, but only {device_count} available" ) - run_torchrun(nproc) + _run_torchrun(nproc) # --------------------------------------------------------------------------- @@ -192,8 +218,6 @@ def run_eager(x: torch.Tensor) -> torch.Tensor: dist.all_reduce(out_ref, group=nccl_group) out_jit = run_fn(inp) num_errors += not torch.all(out_jit == out_ref) - torch.cuda.synchronize() - nccl_group.barrier().wait() if num_errors > 0: return RuntimeError( f"Test failed for {size=}, {dtype=}, {algo=}, " @@ -211,9 +235,7 @@ def worker_main() -> None: logging.disable(logging.INFO) # Suppress internal logging for cleaner test output items = list(enumerate(TEST_CONFIG)) - disable_pbar = os.environ.get("DISABLE_PBAR", "0") == "1" or rank != 0 - pbar = tqdm(items, desc=f"Testing {world_size} GPUs", disable=disable_pbar) - for i, (size, dtype, algo, use_graph) in pbar: + for i, (size, dtype, algo, use_graph) in items: error = worker_test(device, nccl_group, comm, size, dtype, use_graph, algo) if error is not None: print( @@ -222,7 +244,7 @@ def worker_main() -> None: f"Error: {error}" ) # communicate the result to rank 0 for logging - result = torch.tensor([int(error is not None)], device=device) + result = torch.tensor([int(error is not None)]) dist.all_reduce(result, group=cpu_group) failed = bool(result.item()) if failed: @@ 
-239,4 +261,4 @@ def worker_main() -> None: if "LOCAL_RANK" in os.environ: worker_main() else: - sys.exit(pytest.main([__file__, "-v", "-s"])) + sys.exit(pytest.main([__file__, "-x", "-vv", "-s"])) diff --git a/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py b/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py index 10c6572900d3..0843db13f217 100644 --- a/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py +++ b/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py @@ -9,6 +9,10 @@ import torch from sglang.jit_kernel.fused_qknorm_rope import fused_qk_norm_rope +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=35, suite="stage-b-kernel-unit-1-gpu-large") +register_cuda_ci(est_time=256, suite="nightly-kernel-1-gpu", nightly=True) try: from sgl_kernel import fused_qk_norm_rope as fused_qk_norm_rope_aot @@ -118,7 +122,7 @@ def apply_interleave(x): q = apply_interleave(q) k = apply_interleave(k) else: - # NeoX style: first half * cos - second half * sin (and vice versa) + # NeoX style: first half Ɨ cos āˆ’ second half Ɨ sin (and vice versa) def apply_neox(x): # x: [num_tokens, n_heads, head_dim] x1 = x[:, :, : rotary_dim // 2] @@ -227,7 +231,7 @@ def test_fused_qknorm_rope_partial_rotary(head_dim, is_neox): # NeoX requires half_rotary_lanes to be power of 2. # half_rotary_lanes = rotary_dim / (head_dim / 32) / 2 = (head_dim//2) / (head_dim/32) / 2 - # = 16 / 2 = 8 -> power of 2, OK for all supported head_dims. + # = 16 / 2 = 8 → power of 2, OK for all supported head_dims. 
qkv = torch.randn( (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device diff --git a/python/sglang/jit_kernel/tests/test_hicache.py b/python/sglang/jit_kernel/tests/test_hicache.py new file mode 100644 index 000000000000..b6059b2c1d50 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_hicache.py @@ -0,0 +1,247 @@ +import sys + +import pytest +import torch + +from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool +from sglang.srt.mem_cache.memory_pool_host import ( + ALLOC_MEMORY_FUNCS, + MHATokenToKVPoolHost, + MLATokenToKVPoolHost, + alloc_with_pin_memory, +) +from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=10, suite="stage-b-kernel-unit-1-gpu-large") +register_cuda_ci(est_time=120, suite="nightly-kernel-1-gpu", nightly=True) + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() + or is_npu() + or is_xpu() + or not (is_cuda() or is_hip()), + reason="HiCache JIT tests require CUDA/ROCm.", +) + +DEVICE = "cuda" +PAGE_SIZE = 1 if is_hip() else 16 +NUM_LAYERS = 2 +POOL_SIZE = PAGE_SIZE * 8 +MHA_ELEMENT_DIMS = [128, 256, 512, 1024] +MLA_ELEMENT_DIMS = [576] +LAYOUTS = ["layer_first", "page_first"] + + +def _token_indices_for_pages( + pages: torch.Tensor, page_size: int = PAGE_SIZE, device: str = DEVICE +) -> torch.Tensor: + parts = [ + torch.arange( + int(page) * page_size, + (int(page) + 1) * page_size, + device=device, + dtype=torch.int64, + ) + for page in pages.tolist() + ] + return torch.cat(parts, dim=0) + + +def _pinned_host_pool(host_pool_cls, **kwargs): + original_alloc = ALLOC_MEMORY_FUNCS[DEVICE] + ALLOC_MEMORY_FUNCS[DEVICE] = alloc_with_pin_memory + try: + return host_pool_cls( + host_to_device_ratio=2.0, + host_size=0, + page_size=PAGE_SIZE, + pin_memory=True, + device="cpu", + **kwargs, + ) + finally: + ALLOC_MEMORY_FUNCS[DEVICE] = original_alloc + + +def _copy_tensor_with_offset(tensor: 
torch.Tensor, offset: int) -> None: + data = torch.arange( + tensor.numel(), device=tensor.device, dtype=tensor.dtype + ).view_as(tensor) + tensor.copy_(data + offset) + + +def _run_transfer_roundtrip_mha(layout: str, element_dim: int) -> None: + device_pool = MHATokenToKVPool( + size=POOL_SIZE, + page_size=PAGE_SIZE, + head_num=element_dim // 128, + head_dim=128, + dtype=torch.bfloat16, + layer_num=NUM_LAYERS, + device=DEVICE, + enable_memory_saver=False, + ) + host_pool = _pinned_host_pool( + MHATokenToKVPoolHost, + device_pool=device_pool, + layout=layout, + ) + assert ( + host_pool.can_use_jit + ), f"Expected JIT HiCache kernel for MHA dim={element_dim}" + + for layer_id in range(NUM_LAYERS): + _copy_tensor_with_offset(device_pool.k_buffer[layer_id], layer_id) + _copy_tensor_with_offset(device_pool.v_buffer[layer_id], layer_id + 100) + + device_pages = torch.tensor([1, 2, 3], device=DEVICE, dtype=torch.int64) + host_pages = torch.tensor([0, 1, 2], device=DEVICE, dtype=torch.int64) + device_indices = _token_indices_for_pages(device_pages) + host_indices = _token_indices_for_pages(host_pages) + + host_pool.backup_from_device_all_layer( + device_pool, host_indices, device_indices, "kernel" + ) + torch.cuda.synchronize() + + for layer_id in range(NUM_LAYERS): + for host_page, device_page in zip(host_pages.tolist(), device_pages.tolist()): + host_start = host_page * PAGE_SIZE + device_start = device_page * PAGE_SIZE + assert torch.equal( + host_pool.k_data_refs[layer_id][ + host_start : host_start + PAGE_SIZE + ].cpu(), + device_pool.k_buffer[layer_id][ + device_start : device_start + PAGE_SIZE + ].cpu(), + ) + assert torch.equal( + host_pool.v_data_refs[layer_id][ + host_start : host_start + PAGE_SIZE + ].cpu(), + device_pool.v_buffer[layer_id][ + device_start : device_start + PAGE_SIZE + ].cpu(), + ) + + for layer_id in range(NUM_LAYERS): + device_pool.k_buffer[layer_id].zero_() + device_pool.v_buffer[layer_id].zero_() + + load_pages = torch.tensor([4, 5, 6], 
device=DEVICE, dtype=torch.int64) + load_indices = _token_indices_for_pages(load_pages) + for layer_id in range(NUM_LAYERS): + host_pool.load_to_device_per_layer( + device_pool, host_indices, load_indices, layer_id, "kernel" + ) + torch.cuda.synchronize() + + for layer_id in range(NUM_LAYERS): + for host_page, device_page in zip(host_pages.tolist(), load_pages.tolist()): + host_start = host_page * PAGE_SIZE + device_start = device_page * PAGE_SIZE + assert torch.equal( + device_pool.k_buffer[layer_id][ + device_start : device_start + PAGE_SIZE + ].cpu(), + host_pool.k_data_refs[layer_id][ + host_start : host_start + PAGE_SIZE + ].cpu(), + ) + assert torch.equal( + device_pool.v_buffer[layer_id][ + device_start : device_start + PAGE_SIZE + ].cpu(), + host_pool.v_data_refs[layer_id][ + host_start : host_start + PAGE_SIZE + ].cpu(), + ) + + +def _run_transfer_roundtrip_mla(layout: str, element_dim: int) -> None: + device_pool = MLATokenToKVPool( + size=POOL_SIZE, + page_size=PAGE_SIZE, + kv_lora_rank=element_dim - 64, + qk_rope_head_dim=64, + dtype=torch.bfloat16, + layer_num=NUM_LAYERS, + device=DEVICE, + enable_memory_saver=False, + ) + host_pool = _pinned_host_pool( + MLATokenToKVPoolHost, + device_pool=device_pool, + layout=layout, + ) + assert ( + host_pool.can_use_jit + ), f"Expected JIT HiCache kernel for MLA dim={element_dim}" + + for layer_id in range(NUM_LAYERS): + _copy_tensor_with_offset(device_pool.kv_buffer[layer_id], layer_id) + + device_pages = torch.tensor([1, 2, 3], device=DEVICE, dtype=torch.int64) + host_pages = torch.tensor([0, 1, 2], device=DEVICE, dtype=torch.int64) + device_indices = _token_indices_for_pages(device_pages) + host_indices = _token_indices_for_pages(host_pages) + + host_pool.backup_from_device_all_layer( + device_pool, host_indices, device_indices, "kernel" + ) + torch.cuda.synchronize() + + for layer_id in range(NUM_LAYERS): + for host_page, device_page in zip(host_pages.tolist(), device_pages.tolist()): + host_start = host_page 
* PAGE_SIZE + device_start = device_page * PAGE_SIZE + assert torch.equal( + host_pool.data_refs[layer_id][ + host_start : host_start + PAGE_SIZE + ].cpu(), + device_pool.kv_buffer[layer_id][ + device_start : device_start + PAGE_SIZE + ].cpu(), + ) + + for layer_id in range(NUM_LAYERS): + device_pool.kv_buffer[layer_id].zero_() + + load_pages = torch.tensor([4, 5, 6], device=DEVICE, dtype=torch.int64) + load_indices = _token_indices_for_pages(load_pages) + for layer_id in range(NUM_LAYERS): + host_pool.load_to_device_per_layer( + device_pool, host_indices, load_indices, layer_id, "kernel" + ) + torch.cuda.synchronize() + + for layer_id in range(NUM_LAYERS): + for host_page, device_page in zip(host_pages.tolist(), load_pages.tolist()): + host_start = host_page * PAGE_SIZE + device_start = device_page * PAGE_SIZE + assert torch.equal( + device_pool.kv_buffer[layer_id][ + device_start : device_start + PAGE_SIZE + ].cpu(), + host_pool.data_refs[layer_id][ + host_start : host_start + PAGE_SIZE + ].cpu(), + ) + + +@pytest.mark.parametrize("layout", LAYOUTS) +@pytest.mark.parametrize("element_dim", MHA_ELEMENT_DIMS) +def test_hicache_transfer_mha(layout: str, element_dim: int) -> None: + _run_transfer_roundtrip_mha(layout, element_dim) + + +@pytest.mark.parametrize("layout", LAYOUTS) +@pytest.mark.parametrize("element_dim", MLA_ELEMENT_DIMS) +def test_hicache_transfer_mla(layout: str, element_dim: int) -> None: + _run_transfer_roundtrip_mla(layout, element_dim) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-v", "-s"])) diff --git a/python/sglang/jit_kernel/tests/test_norm_jit.py b/python/sglang/jit_kernel/tests/test_norm_jit.py deleted file mode 100644 index ebd0d3034cd9..000000000000 --- a/python/sglang/jit_kernel/tests/test_norm_jit.py +++ /dev/null @@ -1,145 +0,0 @@ -# Adapted from sgl-kernel/tests/test_norm.py - -import sys - -import pytest -import torch - -from sglang.test.ci.ci_register import register_cuda_ci - -register_cuda_ci(est_time=125, 
suite="stage-b-kernel-unit-1-gpu-large") -register_cuda_ci(est_time=500, suite="nightly-kernel-1-gpu", nightly=True) - -# JIT rmsnorm: fp16/bf16 only -# - Warp norm path (one warp per token): hidden_size in {64, 128, 256} -# - CTA norm path (multi-warp per token): hidden_size is a multiple of 256, > 256, and <=8192 -RMSNORM_HIDDEN_SIZES = [64, 128, 256, 512, 1024, 3072, 3584, 4096, 8192] - -# JIT fused_add_rmsnorm: fp16/bf16 only; hidden_size % 8 == 0, <=8192 -FUSED_ADD_RMSNORM_HIDDEN_SIZES = [1024, 3072, 3584, 4096, 8192] - -BS_LIST = [ - 1, - 19, - 99, - 989, - 8192, -] # 8192 ensures num_tokens > max_occupancy * kNumSM on any GPU - - -def _jit_rmsnorm(input, weight, output, eps): - from sglang.jit_kernel.norm import rmsnorm - - rmsnorm(input, weight, output=output, eps=eps) - - -def _fi_rmsnorm(input, weight, out, eps): - from flashinfer.norm import rmsnorm - - rmsnorm(input, weight, out=out, eps=eps) - - -def _jit_fused_add_rmsnorm(input, residual, weight, eps): - from sglang.jit_kernel.norm import fused_add_rmsnorm - - fused_add_rmsnorm(input, residual, weight, eps) - - -def _fi_fused_add_rmsnorm(input, residual, weight, eps): - from flashinfer.norm import fused_add_rmsnorm - - fused_add_rmsnorm(input, residual, weight, eps=eps) - - -@pytest.mark.parametrize("batch_size", BS_LIST) -@pytest.mark.parametrize("hidden_size", RMSNORM_HIDDEN_SIZES) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("specify_out", [True, False]) -def test_rmsnorm_jit(batch_size, hidden_size, dtype, specify_out): - eps = 1e-6 - x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) - w = torch.randn(hidden_size, device="cuda", dtype=dtype) - - # flashinfer reference - x_ref = x.clone() - _fi_rmsnorm(x_ref, w, out=x_ref, eps=eps) - - if specify_out: - y = torch.empty_like(x) - _jit_rmsnorm(x, w, output=y, eps=eps) - else: - y = x.clone() - _jit_rmsnorm(y, w, output=y, eps=eps) - - torch.testing.assert_close(y, x_ref, rtol=1e-2, 
atol=1e-2) - - -@pytest.mark.parametrize("batch_size", BS_LIST) -@pytest.mark.parametrize("hidden_size", FUSED_ADD_RMSNORM_HIDDEN_SIZES) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_add_rmsnorm_jit(batch_size, hidden_size, dtype): - eps = 1e-6 - x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") - residual = torch.randn_like(x) - weight = torch.randn(hidden_size, dtype=dtype, device="cuda") - - # flashinfer reference - x_ref = x.clone() - r_ref = residual.clone() - _fi_fused_add_rmsnorm(x_ref, r_ref, weight, eps=eps) - - x_jit = x.clone() - r_jit = residual.clone() - _jit_fused_add_rmsnorm(x_jit, r_jit, weight, eps) - - torch.testing.assert_close(x_jit, x_ref, rtol=1e-2, atol=1e-2) - torch.testing.assert_close(r_jit, r_ref, rtol=1e-2, atol=1e-2) - - -@pytest.mark.parametrize( - ("hidden_size", "expected"), - [ - (0, False), - (64, True), - (128, True), - (256, True), - (512, True), - (8192, True), - (16384, False), - ], -) -def test_rmsnorm_hidden_size_support(hidden_size, expected): - from sglang.jit_kernel.norm import _is_supported_rmsnorm_hidden_size - - assert _is_supported_rmsnorm_hidden_size(hidden_size) is expected - - -@pytest.mark.parametrize( - ("hidden_size", "expected"), - [ - (64, "RMSNormWarpKernel"), - (128, "RMSNormWarpKernel"), - (256, "RMSNormWarpKernel"), - (512, "RMSNormKernel"), - (8192, "RMSNormKernel"), - ], -) -def test_rmsnorm_kernel_dispatch(hidden_size, expected): - from sglang.jit_kernel.norm import _rmsnorm_kernel_class - - assert _rmsnorm_kernel_class(hidden_size) == expected - - -@pytest.mark.parametrize("hidden_size", [0, 16384]) -def test_rmsnorm_rejects_unsupported_hidden_size(hidden_size): - from sglang.jit_kernel.norm import rmsnorm - - x = torch.randn(1, hidden_size) - w = torch.randn(hidden_size) - - with pytest.raises(RuntimeError, match=f"unsupported hidden_size={hidden_size}"): - rmsnorm(x, w) - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff 
--git a/python/sglang/jit_kernel/tests/test_renorm.py b/python/sglang/jit_kernel/tests/test_renorm.py index 4def31326749..d3ef6ce196bb 100644 --- a/python/sglang/jit_kernel/tests/test_renorm.py +++ b/python/sglang/jit_kernel/tests/test_renorm.py @@ -82,44 +82,5 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p): ) -@pytest.mark.parametrize("batch_size", [1, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) -@pytest.mark.parametrize("k", [10, 100, 500]) -@pytest.mark.parametrize("neginf_input", [False, True]) -def test_top_k_mask_logits(batch_size, vocab_size, k, neginf_input): - """Test top_k_mask_logits kernel for correctness. - - This test validates that the kernel correctly: - 1. Identifies the top-k logits - 2. Masks non-top-k values to -inf - 3. Preserves the top-k values - 4. Handles negative infinity inputs gracefully - - The test verifies correctness by comparing softmax(top_k_mask_logits(logits)) - with top_k_renorm_prob(probs), which should be equivalent. 
- """ - if k > vocab_size: - pytest.skip("k should be less than vocab_size") - torch.manual_seed(42) - logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5 - if neginf_input: - # Randomly assign some logits to -inf to test edge cases - num_neginf = torch.randint(1, vocab_size * batch_size, (1,)).item() - idxs = torch.randperm(batch_size * vocab_size, device="cuda:0")[:num_neginf] - logits[idxs // vocab_size, idxs % vocab_size] = -float("inf") - - probs = torch.softmax(logits, dim=-1) - masked_logits = sgl_kernel.top_k_mask_logits(logits, k) - renormed_probs = torch.softmax(masked_logits, dim=-1) - renormed_probs_ref = sgl_kernel.top_k_renorm_prob(probs, k) - - torch.testing.assert_close( - renormed_probs, - renormed_probs_ref, - rtol=1e-3, - atol=1e-3, - ) - - if __name__ == "__main__": sys.exit(pytest.main([__file__])) diff --git a/python/sglang/jit_kernel/tests/test_rmsnorm.py b/python/sglang/jit_kernel/tests/test_rmsnorm.py index ac31a792747d..59ce90f299f2 100644 --- a/python/sglang/jit_kernel/tests/test_rmsnorm.py +++ b/python/sglang/jit_kernel/tests/test_rmsnorm.py @@ -3,49 +3,104 @@ import pytest import torch -import triton from sglang.jit_kernel.utils import get_ci_test_range from sglang.test.ci.ci_register import register_cuda_ci -register_cuda_ci(est_time=18, suite="stage-b-kernel-unit-1-gpu-large") -register_cuda_ci(est_time=120, suite="nightly-kernel-1-gpu", nightly=True) +register_cuda_ci(est_time=45, suite="stage-b-kernel-unit-1-gpu-large") +register_cuda_ci(est_time=240, suite="nightly-kernel-1-gpu", nightly=True) -def sglang_jit_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None: +EPS = 1e-6 +DEVICE = "cuda" +DTYPES = [torch.float16, torch.bfloat16] + + +def sglang_jit_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + *, + output: torch.Tensor | None = None, + eps: float = EPS, +) -> None: from sglang.jit_kernel.norm import rmsnorm - rmsnorm(input, weight, output=input) + rmsnorm(input, weight, out=output, eps=eps) -def 
flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None: +def flashinfer_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + *, + output: torch.Tensor, + eps: float = EPS, +) -> None: from flashinfer.norm import rmsnorm - rmsnorm(input, weight, out=input) + rmsnorm(input, weight, out=output, eps=eps) BS_LIST = [2**n for n in range(0, 14)] BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] BS_LIST = get_ci_test_range(BS_LIST, [1, 9, 256, 4109]) -HIDDEN_SIZE_LIST = get_ci_test_range( - [512, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 7168, 8192], - [512, 2048, 8192], +SUPPORTED_HIDDEN_SIZE_LIST = get_ci_test_range( + [64, 128, 256, 512, *range(1024, 8192 + 1, 1024), 2304, 2560, 12288, 16384], + [256, 1024, 16384], ) -DEVICE = "cuda" -DTYPE = torch.bfloat16 @pytest.mark.parametrize( - "batch_size,hidden_size", list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST)) + "batch_size,hidden_size", + list(itertools.product(BS_LIST, SUPPORTED_HIDDEN_SIZE_LIST)), ) -def test_rmsnorm(batch_size: int, hidden_size: int) -> None: - input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=DTYPE) - weight = torch.randn(hidden_size, device=DEVICE, dtype=DTYPE) - input_sglang = input.clone() +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_rmsnorm( + batch_size: int, hidden_size: int, dtype: torch.dtype, specify_out: bool +) -> None: + input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype) + weight = torch.randn(hidden_size, device=DEVICE, dtype=dtype) + input_flashinfer = input.clone() - sglang_jit_rmsnorm(input_sglang, weight) - flashinfer_rmsnorm(input_flashinfer, weight) - triton.testing.assert_close(input_sglang, input_flashinfer, atol=1e-2, rtol=1e-2) + output_flashinfer = torch.empty_like(input) + flashinfer_rmsnorm(input_flashinfer, weight, output=output_flashinfer) + + if specify_out: + output_sglang = torch.empty_like(input) + sglang_jit_rmsnorm(input, weight, 
output=output_sglang) + else: + output_sglang = input.clone() + sglang_jit_rmsnorm(output_sglang, weight, output=output_sglang) + + torch.testing.assert_close(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("hidden_size", [64, 128, 256, 512, 8192, 8704, 16384]) +def test_rmsnorm_hidden_size_support(hidden_size: int) -> None: + from sglang.jit_kernel.norm import _is_supported_rmsnorm_hidden_size + + assert _is_supported_rmsnorm_hidden_size(hidden_size) + + +@pytest.mark.parametrize( + ("hidden_size", "expected"), + [ + (64, "RMSNormWarpKernel"), + (128, "RMSNormWarpKernel"), + (256, "RMSNormWarpKernel"), + (512, "RMSNormKernel"), + (1536, "RMSNormKernel"), + (2048, "RMSNormHalfKernel"), + (2304, "RMSNormKernel"), # NOTE: not 512 aligned + (8192, "RMSNormHalfKernel"), + (8704, "RMSNormHalfKernel"), + (16384, "RMSNormHalfKernel"), + ], +) +def test_rmsnorm_kernel_dispatch(hidden_size: int, expected: str) -> None: + from sglang.jit_kernel.norm import _rmsnorm_kernel_class + + assert _rmsnorm_kernel_class(hidden_size) == expected if __name__ == "__main__": diff --git a/python/sglang/jit_kernel/utils.py b/python/sglang/jit_kernel/utils.py index 63a0ba99d041..ec3ebd5abc80 100644 --- a/python/sglang/jit_kernel/utils.py +++ b/python/sglang/jit_kernel/utils.py @@ -140,6 +140,7 @@ def load_jit( extra_include_paths: List[str] | None = None, extra_dependencies: List[str] | None = None, build_directory: str | None = None, + header_only: bool = True, ) -> Module: """ Loading a JIT module from C++/CUDA source files. @@ -169,47 +170,64 @@ def load_jit( :type extra_dependencies: List[str] | None :param build_directory: The build directory for JIT compilation. :type build_directory: str | None + :param header_only: Whether the module is header-only. + If true, apply the wrappers to export given class/functions. + Otherwise, we must export from C++/CUDA side. :return: A just-in-time(JIT) compiled module. 
:rtype: Module """ - from tvm_ffi.cpp import load_inline + from tvm_ffi.cpp import load, load_inline cpp_files = cpp_files or [] cuda_files = cuda_files or [] - cpp_wrappers = cpp_wrappers or [] - cuda_wrappers = cuda_wrappers or [] extra_cflags = extra_cflags or [] extra_cuda_cflags = extra_cuda_cflags or [] extra_ldflags = extra_ldflags or [] extra_include_paths = extra_include_paths or [] + cpp_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cpp_files] + cuda_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cuda_files] + for dep in set(extra_dependencies or []): if dep not in _REGISTERED_DEPENDENCIES: raise ValueError(f"Dependency {dep} is not registered.") extra_include_paths += _REGISTERED_DEPENDENCIES[dep]() - # include cpp files - cpp_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cpp_files] - cpp_sources = [f'#include "{path}"' for path in cpp_paths] - cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers] - - # include cuda files - cuda_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cuda_files] - cuda_sources = [f'#include "{path}"' for path in cuda_paths] - cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers] - - with _jit_compile_context(): - return load_inline( - "sgl_kernel_jit_" + "_".join(str(arg) for arg in args), - cpp_sources=cpp_sources, - cuda_sources=cuda_sources, - extra_cflags=DEFAULT_CFLAGS + extra_cflags, - extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, - extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, - extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, - build_directory=build_directory, - ) + module_name = "sgl_kernel_jit_" + "_".join(str(arg) for arg in args) + if header_only: + cpp_wrappers = cpp_wrappers or [] + cuda_wrappers = cuda_wrappers or [] + cpp_sources = [f'#include "{path}"' for path in cpp_files] + cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers] + + # include cuda files + cuda_sources = [f'#include "{path}"' for path in cuda_files] + 
cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers] + with _jit_compile_context(): + return load_inline( + module_name, + cpp_sources=cpp_sources, + cuda_sources=cuda_sources, + extra_cflags=DEFAULT_CFLAGS + extra_cflags, + extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, + extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, + extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, + build_directory=build_directory, + ) + else: + assert cpp_wrappers is None and cuda_wrappers is None + with _jit_compile_context(): + return load( + module_name, + cpp_files=cpp_files, + cuda_files=cuda_files, + extra_cflags=DEFAULT_CFLAGS + extra_cflags, + extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, + extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, + extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, + build_directory=build_directory, + ) @dataclass diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 41a48103123f..8732e401fe71 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -67,7 +67,7 @@ def flush_cache(self): def get_server_info(self): res = http_request( - self.base_url + "/get_server_info", + self.base_url + "/server_info", api_key=self.api_key, verify=self.verify, ) @@ -531,7 +531,7 @@ def encode( async def get_server_info(self): async with aiohttp.ClientSession() as session: - async with session.get(f"{self.url}/get_server_info") as response: + async with session.get(f"{self.url}/server_info") as response: if response.status == 200: return await response.json() else: diff --git a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/benchmark-and-profile.md b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/benchmark-and-profile.md index a48bca3edc99..5d54535cc7b5 100644 --- 
a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/benchmark-and-profile.md +++ b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/benchmark-and-profile.md @@ -63,8 +63,6 @@ Download input images required by some models: ```bash wget -O "${ASSET_DIR}/cat.png" \ https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png -wget -O "${ASSET_DIR}/astronaut.jpg" \ - https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg wget -O "${ASSET_DIR}/mova_single_person.jpg" \ https://github.com/OpenMOSS/MOVA/raw/main/assets/single_person.jpg ``` @@ -75,7 +73,27 @@ wget -O "${ASSET_DIR}/mova_single_person.jpg" \ All commands include `--warmup` and `--enable-torch-compile` for real production performance. Add `--perf-dump-path .json` for machine-readable output. -If you want a checked-in preset runner instead of copying commands manually, use `scripts/bench_diffusion_denoise.py --model --label `. It writes the same perf dump JSONs used by `compare_perf.py`. +Nightly diffusion comparison is server/API based (`sglang serve` + OpenAI-compatible requests). The commands below stay on `sglang generate` for local profiling, but the first 8 presets are aligned to nightly on model, prompt, reference image, steps, guidance scale, GPU count, and parallelism flags. + +If you want a checked-in preset runner instead of copying commands manually, use `scripts/bench_diffusion_denoise.py --model --label ` or `--list-models`. It writes the same perf dump JSONs used by `compare_perf.py`. + +### Preset Catalog + +Nightly-aligned presets come first; skill-only presets stay available after them. 
+ +| Preset | Model | Nightly | Notes | +| --- | --- | --- | --- | +| `flux` | `black-forest-labs/FLUX.1-dev` | Yes: `flux1_dev_t2i_1024` | Aligned to nightly prompt + `--dit-layerwise-offload false` | +| `flux2` | `black-forest-labs/FLUX.2-dev` | Yes: `flux2_dev_t2i_1024` | Aligned to nightly prompt, 50 steps, guidance 4.0 | +| `qwen` | `Qwen/Qwen-Image-2512` | Yes: `qwen_image_2512_t2i_1024` | Aligned to nightly prompt/steps; no extra offload overrides | +| `qwen-edit` | `Qwen/Qwen-Image-Edit-2511` | Yes: `qwen_image_edit_2511` | Uses nightly cat image + edit prompt | +| `zimage` | `Tongyi-MAI/Z-Image-Turbo` | Yes: `zimage_turbo_t2i_1024` | Aligned to nightly prompt + guidance 4.0 | +| `wan-t2v` | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | Yes: `wan22_t2v_a14b_720p` | Aligned to nightly CFG-parallel 4-GPU launch | +| `wan-ti2v` | `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | Yes: `wan22_ti2v_5b_720p` | Uses nightly cat image + motion prompt | +| `wan-i2v` | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | Yes: `wan22_i2v_a14b_720p` | Added to match nightly; aligned to CFG-parallel 4-GPU launch | +| `hunyuanvideo` | `hunyuanvideo-community/HunyuanVideo` | No | Skill-only extra preset | +| `mova-720p` | `OpenMOSS-Team/MOVA-720p` | No | Skill-only extra preset | +| `helios` | `BestWishYsh/Helios-Base` | No | Skill-only extra preset | ### Perf dump & before/after compare @@ -98,86 +116,95 @@ python3 python/sglang/multimodal_gen/benchmarks/compare_perf.py \ "${BENCH_DIR}/baseline.json" "${BENCH_DIR}/new.json" ``` -### Qwen-Image-2512 (1024Ɨ1024, 50 steps) +### FLUX.1-dev (Nightly: `flux1_dev_t2i_1024`) ```bash sglang generate \ - --model-path=Qwen/Qwen-Image-2512 \ - --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" \ - '--negative-prompt= ' \ + --model-path=black-forest-labs/FLUX.1-dev \ + --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets" \ --width=1024 --height=1024 --num-inference-steps=50 
--guidance-scale=4.0 \ --seed=42 --save-output --enable-torch-compile --warmup \ - --dit-cpu-offload false --text-encoder-cpu-offload false + --dit-layerwise-offload false ``` -### Qwen-Image-Edit-2511 (image editing, 1024Ɨ1024, 50 steps) +### FLUX.2-dev (Nightly: `flux2_dev_t2i_1024`) ```bash sglang generate \ - --model-path=Qwen/Qwen-Image-Edit-2511 \ - '--prompt=Transform into anime style' '--negative-prompt= ' \ - --image-path="${ASSET_DIR}/cat.png" \ + --model-path=black-forest-labs/FLUX.2-dev \ + --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets" \ --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \ --seed=42 --save-output --enable-torch-compile --warmup \ - --dit-cpu-offload false --text-encoder-cpu-offload false + --dit-layerwise-offload false ``` -### FLUX.1-dev (1024Ɨ1024, 50 steps) +### Qwen-Image-2512 (Nightly: `qwen_image_2512_t2i_1024`) ```bash sglang generate \ - --model-path=black-forest-labs/FLUX.1-dev \ - --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" \ + --model-path=Qwen/Qwen-Image-2512 \ + --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets" \ --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \ --seed=42 --save-output --enable-torch-compile --warmup ``` -### FLUX.2-dev (1024Ɨ1024) +### Qwen-Image-Edit-2511 (Nightly: `qwen_image_edit_2511`) ```bash sglang generate \ - --model-path black-forest-labs/FLUX.2-dev \ - --prompt "A Logo With Bold Large Text: SGL Diffusion" \ - --width=1024 --height=1024 \ - --dit-layerwise-offload false --enable-torch-compile --warmup \ - --dit-cpu-offload false --text-encoder-cpu-offload true --vae-cpu-offload false + --model-path=Qwen/Qwen-Image-Edit-2511 \ + --prompt="Make the cat wear a red hat" \ + --image-path="${ASSET_DIR}/cat.png" \ + --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \ + --seed=42 --save-output 
--enable-torch-compile --warmup ``` -### Z-Image-Turbo (1024Ɨ1024, 9 steps) +### Z-Image-Turbo (Nightly: `zimage_turbo_t2i_1024`) ```bash sglang generate \ --model-path=Tongyi-MAI/Z-Image-Turbo \ - --prompt='A fantasy landscape with mountains and a river, detailed, vibrant colors' \ - --width=1024 --height=1024 --num-inference-steps=9 --guidance-scale=0.0 \ - --seed=42 --save-output --enable-torch-compile --warmup \ - --dit-cpu-offload false --text-encoder-cpu-offload false + --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets" \ + --width=1024 --height=1024 --num-inference-steps=9 --guidance-scale=4.0 \ + --seed=42 --save-output --enable-torch-compile --warmup ``` -### Wan2.2-T2V-A14B 720P (4 GPUs, 81 frames, 2 steps) +### Wan2.2-T2V-A14B 720P (Nightly: `wan22_t2v_a14b_720p`) ```bash # Select four idle GPUs first: # export CUDA_VISIBLE_DEVICES=$(python3 "$ENV_PY" print-idle-gpus --count 4) sglang generate \ --model-path=Wan-AI/Wan2.2-T2V-A14B-Diffusers \ - --prompt="A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon." \ - --negative-prompt=" " --720p --num-inference-steps=2 --num-frames=81 \ + --prompt="A cat and a dog baking a cake together in a kitchen." \ + --720p --num-inference-steps=2 --num-frames=81 \ --guidance-scale=5.0 --seed=42 --save-output \ - --num-gpus=4 --ulysses-degree=4 \ + --num-gpus=4 --enable-cfg-parallel --ulysses-degree=2 \ --text-encoder-cpu-offload --pin-cpu-memory \ --warmup --enable-torch-compile ``` -### Wan2.2-TI2V-5B 720P (single GPU, 81 frames, 50 steps) +### Wan2.2-TI2V-5B 720P (Nightly: `wan22_ti2v_5b_720p`) ```bash sglang generate \ - --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \ - --prompt "An astronaut hatching from an egg, on the surface of the moon..." \ - --negative-prompt "Bright tones, overexposed, static, blurred details..." 
\ - --image-path="${ASSET_DIR}/astronaut.jpg" \ - --num-frames 81 --720p --num-inference-steps 50 --guidance-scale 5.0 \ - --seed 42 --save-output \ - --dit-layerwise-offload false --dit-cpu-offload false \ - --vae-cpu-offload false --text-encoder-cpu-offload false \ + --model-path=Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --prompt="The cat starts walking slowly towards the camera." \ + --image-path="${ASSET_DIR}/cat.png" \ + --num-frames=81 --720p --num-inference-steps=50 --guidance-scale=5.0 \ + --seed=42 --save-output \ --enable-torch-compile --warmup ``` -### HunyuanVideo (848Ɨ480, 65 frames, 30 steps) +### Wan2.2-I2V-A14B 720P (Nightly: `wan22_i2v_a14b_720p`) +```bash +# Select four idle GPUs first: +# export CUDA_VISIBLE_DEVICES=$(python3 "$ENV_PY" print-idle-gpus --count 4) +sglang generate \ + --model-path=Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --prompt="The cat starts walking slowly towards the camera." \ + --image-path="${ASSET_DIR}/cat.png" \ + --720p --num-inference-steps=2 --num-frames=81 \ + --guidance-scale=5.0 --seed=42 --save-output \ + --num-gpus=4 --enable-cfg-parallel --ulysses-degree=2 \ + --text-encoder-cpu-offload --pin-cpu-memory \ + --warmup --enable-torch-compile +``` + +### HunyuanVideo (Skill-only, not nightly) ```bash sglang generate \ --model-path=hunyuanvideo-community/HunyuanVideo \ @@ -188,7 +215,7 @@ sglang generate \ --warmup --enable-torch-compile ``` -### MOVA-720p (4 GPUs, 193 frames, 2 steps) +### MOVA-720p (Skill-only, not nightly) ```bash # Select four idle GPUs first: # export CUDA_VISIBLE_DEVICES=$(python3 "$ENV_PY" print-idle-gpus --count 4) @@ -203,6 +230,17 @@ sglang generate \ --enable-torch-compile --save-output --warmup ``` +### Helios-Base (Skill-only, not nightly) +```bash +sglang generate \ + --model-path=BestWishYsh/Helios-Base \ + --prompt="A curious raccoon" \ + --width=640 --height=384 --num-frames=33 \ + --dit-layerwise-offload false --dit-cpu-offload false \ + --text-encoder-cpu-offload false --vae-cpu-offload false 
\ + --seed=42 --save-output --enable-torch-compile --warmup +``` + **Key metrics** (all models): denoise latency ā˜…, end-to-end latency, peak GPU memory. --- diff --git a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/scripts/bench_diffusion_denoise.py b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/scripts/bench_diffusion_denoise.py index 86381cd1d557..bc52f4a43c97 100755 --- a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/scripts/bench_diffusion_denoise.py +++ b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/scripts/bench_diffusion_denoise.py @@ -12,9 +12,12 @@ # Tag the run for later compare_perf.py usage python3 python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/scripts/bench_diffusion_denoise.py --model flux --label tuned - # All 10 preset models + # All 11 preset models python3 python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/scripts/bench_diffusion_denoise.py --all + # Show preset order, model path, and nightly mapping + python3 python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/scripts/bench_diffusion_denoise.py --list-models + For gated Hugging Face repos such as FLUX, export HF_TOKEN first: export HF_TOKEN= @@ -22,8 +25,6 @@ ASSET_DIR=$(python3 python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-benchmark-profile/scripts/diffusion_skill_env.py print-assets-dir --mkdir) wget -O "${ASSET_DIR}/cat.png" \ https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png - wget -O "${ASSET_DIR}/astronaut.jpg" \ - https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg wget -O "${ASSET_DIR}/mova_single_person.jpg" \ https://github.com/OpenMOSS/MOVA/raw/main/assets/single_person.jpg """ @@ -55,47 +56,43 @@ # 
--------------------------------------------------------------------------- # Model configs — kept in exact sync with benchmark-and-profile.md +# Nightly-aligned presets come first, followed by skill-only extras. # Each entry produces the same `sglang generate` command as shown in that doc. # --------------------------------------------------------------------------- MODELS = { - # 1. Qwen/Qwen-Image-2512 — Text-to-Image, 1024Ɨ1024, 50 steps - "qwen": { - "path": "Qwen/Qwen-Image-2512", - "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k", - "negative_prompt": " ", + # 1. Nightly: flux1_dev_t2i_1024 + "flux": { + "nightly_case_id": "flux1_dev_t2i_1024", + "path": "black-forest-labs/FLUX.1-dev", + "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets", "extra_args": [ "--width=1024", "--height=1024", "--num-inference-steps=50", "--guidance-scale=4.0", - "--dit-cpu-offload", - "false", - "--text-encoder-cpu-offload", + "--dit-layerwise-offload", "false", ], }, - # 2. Qwen/Qwen-Image-Edit-2511 — Image Editing, 1024Ɨ1024, 50 steps - # Requires: /inputs/diffusion_benchmark/figs/cat.png - "qwen-edit": { - "path": "Qwen/Qwen-Image-Edit-2511", - "prompt": "Transform into anime style", - "negative_prompt": " ", - "image_path": str(ASSET_DIR / "cat.png"), + # 2. Nightly: flux2_dev_t2i_1024 + "flux2": { + "nightly_case_id": "flux2_dev_t2i_1024", + "path": "black-forest-labs/FLUX.2-dev", + "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets", "extra_args": [ "--width=1024", "--height=1024", "--num-inference-steps=50", "--guidance-scale=4.0", - "--dit-cpu-offload", - "false", - "--text-encoder-cpu-offload", + "--dit-layerwise-offload", "false", ], }, - # 3. 
black-forest-labs/FLUX.1-dev — Text-to-Image, 1024Ɨ1024, 50 steps - "flux": { - "path": "black-forest-labs/FLUX.1-dev", - "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k", + # 3. Nightly: qwen_image_2512_t2i_1024 + "qwen": { + "nightly_case_id": "qwen_image_2512_t2i_1024", + "path": "Qwen/Qwen-Image-2512", + "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets", "extra_args": [ "--width=1024", "--height=1024", @@ -103,80 +100,83 @@ "--guidance-scale=4.0", ], }, - # 4. black-forest-labs/FLUX.2-dev — Text-to-Image, 1024Ɨ1024 - "flux2": { - "path": "black-forest-labs/FLUX.2-dev", - "prompt": "A Logo With Bold Large Text: SGL Diffusion", + # 4. Nightly: qwen_image_edit_2511 + # Requires: /inputs/diffusion_benchmark/figs/cat.png + "qwen-edit": { + "nightly_case_id": "qwen_image_edit_2511", + "path": "Qwen/Qwen-Image-Edit-2511", + "prompt": "Make the cat wear a red hat", + "image_path": str(ASSET_DIR / "cat.png"), "extra_args": [ "--width=1024", "--height=1024", - "--dit-layerwise-offload", - "false", - "--dit-cpu-offload", - "false", - "--text-encoder-cpu-offload", - "true", - "--vae-cpu-offload", - "false", + "--num-inference-steps=50", + "--guidance-scale=4.0", ], }, - # 5. Tongyi-MAI/Z-Image-Turbo — Turbo Text-to-Image, 1024Ɨ1024, 9 steps + # 5. Nightly: zimage_turbo_t2i_1024 "zimage": { + "nightly_case_id": "zimage_turbo_t2i_1024", "path": "Tongyi-MAI/Z-Image-Turbo", - "prompt": "A fantasy landscape with mountains and a river, detailed, vibrant colors", + "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets", "extra_args": [ "--width=1024", "--height=1024", "--num-inference-steps=9", - "--guidance-scale=0.0", - "--dit-cpu-offload", - "false", - "--text-encoder-cpu-offload", - "false", + "--guidance-scale=4.0", ], }, - # 6. Wan-AI/Wan2.2-T2V-A14B-Diffusers — Text-to-Video, 720P, 4 GPUs, 81 frames, 2 steps + # 6. 
Nightly: wan22_t2v_a14b_720p "wan-t2v": { + "nightly_case_id": "wan22_t2v_a14b_720p", "path": "Wan-AI/Wan2.2-T2V-A14B-Diffusers", - "prompt": "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon.", - "negative_prompt": " ", + "prompt": "A cat and a dog baking a cake together in a kitchen.", "extra_args": [ "--720p", "--num-inference-steps=2", "--num-frames=81", "--guidance-scale=5.0", "--num-gpus=4", - "--ulysses-degree=4", + "--enable-cfg-parallel", + "--ulysses-degree=2", "--text-encoder-cpu-offload", "--pin-cpu-memory", ], }, - # 7. Wan-AI/Wan2.2-TI2V-5B-Diffusers — Text-Image-to-Video, 720P, 1 GPU, 81 frames, 50 steps - # Requires: /inputs/diffusion_benchmark/figs/astronaut.jpg + # 7. Nightly: wan22_ti2v_5b_720p + # Requires: /inputs/diffusion_benchmark/figs/cat.png "wan-ti2v": { + "nightly_case_id": "wan22_ti2v_5b_720p", "path": "Wan-AI/Wan2.2-TI2V-5B-Diffusers", - "prompt": "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. 
High quality, ultrarealistic detail and breath-taking movie-like camera shot.", - "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", - "image_path": str(ASSET_DIR / "astronaut.jpg"), + "prompt": "The cat starts walking slowly towards the camera.", + "image_path": str(ASSET_DIR / "cat.png"), "extra_args": [ - "--num-frames", - "81", "--720p", - "--num-inference-steps", - "50", - "--guidance-scale", - "5.0", - "--dit-layerwise-offload", - "false", - "--dit-cpu-offload", - "false", - "--vae-cpu-offload", - "false", + "--num-frames=81", + "--num-inference-steps=50", + "--guidance-scale=5.0", + ], + }, + # 8. Nightly: wan22_i2v_a14b_720p + # Requires: /inputs/diffusion_benchmark/figs/cat.png + "wan-i2v": { + "nightly_case_id": "wan22_i2v_a14b_720p", + "path": "Wan-AI/Wan2.2-I2V-A14B-Diffusers", + "prompt": "The cat starts walking slowly towards the camera.", + "image_path": str(ASSET_DIR / "cat.png"), + "extra_args": [ + "--720p", + "--num-inference-steps=2", + "--num-frames=81", + "--guidance-scale=5.0", + "--num-gpus=4", + "--enable-cfg-parallel", + "--ulysses-degree=2", "--text-encoder-cpu-offload", - "false", + "--pin-cpu-memory", ], }, - # 8. hunyuanvideo-community/HunyuanVideo — Text-to-Video, 848Ɨ480, 65 frames, 30 steps + # 9. Skill-only extra preset "hunyuanvideo": { "path": "hunyuanvideo-community/HunyuanVideo", "prompt": "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.", @@ -189,7 +189,7 @@ "--num-inference-steps=30", ], }, - # 9. 
OpenMOSS-Team/MOVA-720p — Image-to-Video, 4 GPUs, 193 frames, 2 steps + # 10. Skill-only extra preset # Requires: /inputs/diffusion_benchmark/figs/mova_single_person.jpg "mova-720p": { "path": "OpenMOSS-Team/MOVA-720p", @@ -205,7 +205,7 @@ "--num-inference-steps=2", ], }, - # 10. BestWishYsh/Helios-Base — Text-to-Video, 640Ɨ384, 33 frames + # 11. Skill-only extra preset "helios": { "path": "BestWishYsh/Helios-Base", "prompt": "A curious raccoon", @@ -227,13 +227,35 @@ def required_gpus_for_model(model_key: str) -> int: - if model_key == "wan-t2v": + if model_key in {"wan-t2v", "wan-i2v"}: return 4 if model_key == "mova-720p": return 4 return 1 +def model_nightly_case_id(model_key: str) -> str: + return MODELS[model_key].get("nightly_case_id", "-") + + +def print_model_catalog(): + """Print preset order, model path, and whether each preset maps to nightly.""" + print() + print("=" * 95) + print("MODEL PRESETS — Nightly-aligned first, skill-only extras after") + print("=" * 95) + print(f"{'Preset':<14} {'Nightly':<28} {'Model Path':<46} {'GPUs':>4}") + print("-" * 95) + for model_key, cfg in MODELS.items(): + print( + f"{model_key:<14} {model_nightly_case_id(model_key):<28} {cfg['path']:<46} {required_gpus_for_model(model_key):>4}" + ) + print("-" * 112) + print( + "Nightly column shows the comparison_configs.json case id; '-' means skill-only." 
+ ) + + def build_sglang_cmd( model_key: str, perf_dump_path: Optional[str] = None, @@ -393,9 +415,9 @@ def print_results_table(results: list[dict]): print("=" * 80) print( - f"{'Model':<16} {'Label':<12} {'Denoise(s)':>12} {'E2E(s)':>10} {'Peak Mem(GB)':>14}" + f"{'Model':<14} {'Nightly':<24} {'Label':<12} {'Denoise(s)':>12} {'E2E(s)':>10} {'Peak Mem(GB)':>14}" ) - print("-" * 64) + print("-" * 92) for result in results: denoise_s = result.get("denoise_latency_s") @@ -405,10 +427,10 @@ def print_results_table(results: list[dict]): e2e_text = f"{e2e_s:.2f}" if isinstance(e2e_s, float) else "n/a" mem_text = f"{peak_mem:.1f}" if isinstance(peak_mem, float) else "n/a" print( - f"{result['model']:<16} {result['label']:<12} {denoise_text:>12} {e2e_text:>10} {mem_text:>14}" + f"{result['model']:<14} {model_nightly_case_id(result['model']):<24} {result['label']:<12} {denoise_text:>12} {e2e_text:>10} {mem_text:>14}" ) - print("-" * 64) + print("-" * 92) print() print("ā˜… Denoise latency = total DiT forward pass time across all inference steps.") print( @@ -425,7 +447,12 @@ def main(): choices=list(MODELS.keys()), help="Model to benchmark (default: flux)", ) - parser.add_argument("--all", action="store_true", help="Benchmark all 10 models") + parser.add_argument("--all", action="store_true", help="Benchmark all 11 models") + parser.add_argument( + "--list-models", + action="store_true", + help="List preset order, nightly mapping, and exit", + ) parser.add_argument( "--label", type=str, @@ -442,6 +469,10 @@ def main(): args = parser.parse_args() + if args.list_models: + print_model_catalog() + return + output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) warmup = not args.no_warmup diff --git a/python/sglang/multimodal_gen/benchmarks/bench_serving.py b/python/sglang/multimodal_gen/benchmarks/bench_serving.py index d4910473d77f..7baa0e129dd0 100644 --- a/python/sglang/multimodal_gen/benchmarks/bench_serving.py +++ 
b/python/sglang/multimodal_gen/benchmarks/bench_serving.py @@ -39,6 +39,7 @@ init_logger, ) from sglang.multimodal_gen.test.test_utils import print_divider, print_value_formatted +from sglang.srt.utils.network import NetworkAddress logger = init_logger(__name__) @@ -457,7 +458,7 @@ async def benchmark(args): # Construct base_url if not provided if args.base_url is None: - args.base_url = f"http://{args.host}:{args.port}" + args.base_url = NetworkAddress(args.host, args.port).to_url() # Wait for service wait_for_service(args.base_url) diff --git a/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py b/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py index af9fa9d2a0d8..fae449fe75e7 100644 --- a/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py +++ b/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py @@ -38,6 +38,8 @@ class QwenImageVAEConfig(VAEConfig): use_temporal_tiling: bool = False use_parallel_tiling: bool = False + use_parallel_decode: bool = False + def get_vae_scale_factor(self): return 2 ** len(self.arch_config.temperal_downsample) diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py index 80b7758b885a..d155068992ce 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py @@ -97,10 +97,6 @@ class STA_Mode(str, Enum): NONE = None -def preprocess_text(prompt: str) -> str: - return prompt - - def postprocess_text(output: BaseEncoderOutput, _text_inputs) -> torch.tensor: raise NotImplementedError @@ -206,8 +202,8 @@ class PipelineConfig: def postprocess_image(self, image): return image.last_hidden_state - preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( - default_factory=lambda: (preprocess_text,) + preprocess_text_funcs: tuple[Callable[[str], str] | None, ...] 
= field( + default_factory=lambda: (None,) ) # get prompt_embeds from encoder output diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py b/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py index 71d0c0128372..b24822154450 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py @@ -23,12 +23,10 @@ from sglang.multimodal_gen.configs.pipeline_configs.base import ( ImagePipelineConfig, ModelTaskType, - preprocess_text, shard_rotary_emb_for_sp, ) from sglang.multimodal_gen.configs.pipeline_configs.hunyuan import ( clip_postprocess_text, - clip_preprocess_text, ) from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import _pack_latents from sglang.multimodal_gen.runtime.distributed import get_local_torch_device @@ -65,8 +63,8 @@ class FluxPipelineConfig(ImagePipelineConfig): default_factory=lambda: ("bf16", "bf16") ) - preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( - default_factory=lambda: (clip_preprocess_text, preprocess_text), + preprocess_text_funcs: tuple[Callable[[str], str] | None, ...] = field( + default_factory=lambda: (None, None), ) postprocess_text_funcs: tuple[Callable[[str], str], ...] = field( @@ -437,6 +435,17 @@ class Flux2PipelineConfig(FluxPipelineConfig): default_factory=lambda: (flux2_postprocess_text,) ) vae_config: VAEConfig = field(default_factory=Flux2VAEConfig) + text_encoder_extra_args: list[dict] = field( + default_factory=lambda: [ + dict( + max_length=512, + padding="max_length", + truncation=True, + return_overflowing_tokens=False, + return_length=False, + ) + ] + ) def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict: # flatten to 1-d list @@ -650,8 +659,8 @@ class Flux2KleinPipelineConfig(Flux2PipelineConfig): default_factory=lambda: (Qwen3TextConfig(),) ) - preprocess_text_funcs: tuple[Callable[[str], str], ...] 
= field( - default_factory=lambda: (preprocess_text,), + preprocess_text_funcs: tuple[Callable[[str], str] | None, ...] = field( + default_factory=lambda: (None,), ) postprocess_text_funcs: tuple[Callable[[str], str], ...] = field( @@ -681,7 +690,9 @@ def _apply_chat_template(prompt: str) -> str: texts = [_apply_chat_template(prompt) for prompt in prompts] tok_kwargs = dict(tok_kwargs or {}) - max_length = tok_kwargs.pop("max_length", 512) + tok_kwargs.pop("max_length", None) + # Flux2 Klein uses max_length 512. + max_length = 512 padding = tok_kwargs.pop("padding", "max_length") truncation = tok_kwargs.pop("truncation", True) return_tensors = tok_kwargs.pop("return_tensors", "pt") diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py b/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py index d45dfadb2582..abe421c890b6 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py @@ -56,10 +56,6 @@ def llama_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.te return last_hidden_state -def clip_preprocess_text(prompt: str) -> str: - return prompt - - def clip_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.tensor: pooler_output: torch.tensor = outputs.pooler_output return pooler_output @@ -84,8 +80,8 @@ class HunyuanConfig(PipelineConfig): text_encoder_configs: tuple[EncoderConfig, ...] = field( default_factory=lambda: (LlamaConfig(), CLIPTextConfig()) ) - preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( - default_factory=lambda: (llama_preprocess_text, clip_preprocess_text) + preprocess_text_funcs: tuple[Callable[[str], str] | None, ...] = field( + default_factory=lambda: (llama_preprocess_text, None) ) postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.tensor], ...] 
= ( field(default_factory=lambda: (llama_postprocess_text, clip_postprocess_text)) diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py b/python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py index 301423753c66..54fafce0db2c 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py @@ -14,7 +14,6 @@ from sglang.multimodal_gen.configs.pipeline_configs.base import ( ModelTaskType, PipelineConfig, - preprocess_text, ) from sglang.multimodal_gen.runtime.distributed import ( get_sp_parallel_rank, @@ -189,8 +188,8 @@ def prepare_audio_latent_shape(self, batch, batch_size, num_frames): text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("bf16",)) text_encoder_extra_args: list[dict] = field(default_factory=lambda: [{}]) - preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( - default_factory=lambda: (preprocess_text,) + preprocess_text_funcs: tuple[Callable[[str], str] | None, ...] = field( + default_factory=lambda: (None,) ) postprocess_text_funcs: tuple[ Callable[[BaseEncoderOutput, dict], torch.Tensor], ... diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/sana.py b/python/sglang/multimodal_gen/configs/pipeline_configs/sana.py index 73f62fa00401..d008c0c62f70 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/sana.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/sana.py @@ -30,7 +30,6 @@ from sglang.multimodal_gen.configs.pipeline_configs.base import ( ModelTaskType, SpatialImagePipelineConfig, - preprocess_text, ) @@ -65,8 +64,8 @@ class SanaPipelineConfig(SpatialImagePipelineConfig): text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("bf16",)) - preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( - default_factory=lambda: (preprocess_text,), + preprocess_text_funcs: tuple[Callable[[str], str] | None, ...] 
= field( + default_factory=lambda: (None,), ) postprocess_text_funcs: tuple[Callable[[str], str], ...] = field( diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py b/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py index 6a824e67881f..c08851e93fe6 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py @@ -214,7 +214,7 @@ def __post_init__(self) -> None: @dataclass -class Wan2_2_I2V_A14B_Config(WanI2V480PConfig): +class Wan2_2_I2V_A14B_Config(WanI2V720PConfig): flow_shift: float | None = 5.0 boundary_ratio: float | None = 0.900 diff --git a/python/sglang/multimodal_gen/configs/sample/hunyuan.py b/python/sglang/multimodal_gen/configs/sample/hunyuan.py index ae69dbd62ccd..c60b856630f0 100644 --- a/python/sglang/multimodal_gen/configs/sample/hunyuan.py +++ b/python/sglang/multimodal_gen/configs/sample/hunyuan.py @@ -39,6 +39,7 @@ class HunyuanSamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( teacache_thresh=0.15, + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4HunyuanVideo/teacache_sample_video.py#L222 coefficients=[ 7.33226126e02, -4.01131952e02, diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index 7dcf9bf1dc88..f4d353874b44 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -29,11 +29,16 @@ def _json_safe(obj: Any): """ Recursively convert objects to JSON-serializable forms. 
- Enums -> their name + - Callables -> stable module-qualified name - Sets/Tuples -> lists - Dicts/Lists -> recursively processed """ if isinstance(obj, Enum): return obj.name + if callable(obj): + module = getattr(obj, "__module__", None) + qualname = getattr(obj, "__qualname__", getattr(obj, "__name__", repr(obj))) + return f"{module}.{qualname}" if module else qualname if isinstance(obj, dict): return {k: _json_safe(v) for k, v in obj.items()} if isinstance(obj, (list, tuple, set)): diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index ada71d0b3618..e20df6e67b93 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -1,43 +1,82 @@ # Copied and adapted from: https://github.com/hao-ai-lab/FastVideo # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from dataclasses import dataclass, field +from typing import Callable from sglang.multimodal_gen.configs.sample.sampling_params import CacheParams @dataclass class TeaCacheParams(CacheParams): + """ + Parameters for [TeaCache](https://arxiv.org/abs/2411.14324). + + Attributes: + cache_type: (`str`, defaults to `teacache`): + A string labeling these parameters as belonging to teacache. + teacache_thresh (`float`, defaults to `0.0`): + Threshold for accumulated relative L1 distance. When below this threshold, the + forward pass is skipped. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, + 0.6 for ~2.0x. + start_skipping (`int` or `float`, defaults to `5`): + The number of timesteps after which we may skip a forward pass. These early + steps define the global structure and are too critical to not skip. + int: The number of timesteps after which we can skip. If negative, + this is an offset from the end of the schedule. + float (0.0 - 1.0): A percentage of the total steps (e.g., 0.1 + computes the first 10%). 
+ end_skipping (`int` or `float`, defaults to `-1`): + The number of timesteps after which we are no longer able to skip + forward passes. The last steps refine fine textures and details. + int: The number of timesteps after which skipping ends. If negative, + this is an offset from the total number of steps. + float (0.0 - 1.0): A percentage of the total steps (e.g., 0.1 + computes the first 10%). + coefficients (`List[float]`, defaults to `[]`): + Polynomial coefficients for rescaling the raw relative L1 distance, + evaluated as `c[0]*x**4 + c[1]*x**3 + c[2]*x**2 + c[3]*x + c[4]`. + coefficients_callback (`Callable[[TeaCacheParams], List[float]]`, *optional*): + A function that receives this `TeaCacheParams` instance and returns + the polynomial coefficients to use. When set, it takes precedence over + the `coefficients` field, allowing dynamic coefficient selection based + on any property of the params (e.g., `use_ret_steps` for Wan models). + use_ret_steps: (`bool`, `None`, defaults to `None`): + Used exclusively for wanvideo models to select different modulated inputs. 
+ """ + cache_type: str = "teacache" teacache_thresh: float = 0.0 + start_skipping: int | float = 5 + end_skipping: int | float = -1 coefficients: list[float] = field(default_factory=list) + coefficients_callback: Callable[[TeaCacheParams], list[float]] | None = field( + default=None, repr=False + ) + use_ret_steps: bool | None = None + def get_coefficients(self) -> list[float]: + if self.coefficients_callback is not None: + return self.coefficients_callback(self) + return self.coefficients -@dataclass -class WanTeaCacheParams(CacheParams): - # Unfortunately, TeaCache is very different for Wan than other models - cache_type: str = "teacache" - teacache_thresh: float = 0.0 - use_ret_steps: bool = True - ret_steps_coeffs: list[float] = field(default_factory=list) - non_ret_steps_coeffs: list[float] = field(default_factory=list) - - @property - def coefficients(self) -> list[float]: - if self.use_ret_steps: - return self.ret_steps_coeffs - else: - return self.non_ret_steps_coeffs - - @property - def ret_steps(self) -> int: - if self.use_ret_steps: - return 5 * 2 - else: - return 1 * 2 - - def get_cutoff_steps(self, num_inference_steps: int) -> int: - if self.use_ret_steps: - return num_inference_steps * 2 - else: - return num_inference_steps * 2 - 2 + def get_skip_boundaries( + self, num_inference_steps: int, do_cfg: bool + ) -> tuple[int, int]: + def _resolve_boundary(value: int | float) -> int: + if isinstance(value, float): + return int(num_inference_steps * value) + if value < 0: + return num_inference_steps + value + return value + + start_skipping = _resolve_boundary(self.start_skipping) + end_skipping = _resolve_boundary(self.end_skipping) + + if do_cfg: + start_skipping *= 2 + end_skipping *= 2 + + return start_skipping, end_skipping diff --git a/python/sglang/multimodal_gen/configs/sample/wan.py b/python/sglang/multimodal_gen/configs/sample/wan.py index a5faf50214f0..0f147a9dcede 100644 --- a/python/sglang/multimodal_gen/configs/sample/wan.py +++ 
b/python/sglang/multimodal_gen/configs/sample/wan.py @@ -4,7 +4,41 @@ from dataclasses import dataclass, field from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams -from sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams +from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams + + +def _wan_1_3b_coefficients(p: TeaCacheParams) -> list[float]: + if p.use_ret_steps: + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L883 + return [ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02, + ] + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L890 + return [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, + ] + + +def _wan_14b_coefficients(p: TeaCacheParams) -> list[float]: + if p.use_ret_steps: + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L885 + return [ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01, + ] + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L892 + return [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] @dataclass @@ -30,23 +64,13 @@ class WanT2V_1_3B_SamplingParams(SamplingParams): ] ) - teacache_params: WanTeaCacheParams = field( - default_factory=lambda: WanTeaCacheParams( + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( teacache_thresh=0.08, - ret_steps_coeffs=[ - -5.21862437e04, - 9.23041404e03, - -5.28275948e02, - 1.36987616e01, - -4.99875664e-02, - ], - non_ret_steps_coeffs=[ - 2.39676752e03, - -1.31110545e03, - 2.01331979e02, - -8.29855975e00, - 1.37887774e-01, - ], + use_ret_steps=True, + 
coefficients_callback=_wan_1_3b_coefficients, + start_skipping=5, + end_skipping=1.0, ) ) @@ -76,24 +100,13 @@ class WanT2V_14B_SamplingParams(SamplingParams): ] ) - teacache_params: WanTeaCacheParams = field( - default_factory=lambda: WanTeaCacheParams( + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( teacache_thresh=0.20, use_ret_steps=False, - ret_steps_coeffs=[ - -3.03318725e05, - 4.90537029e04, - -2.65530556e03, - 5.87365115e01, - -3.15583525e-01, - ], - non_ret_steps_coeffs=[ - -5784.54975374, - 5449.50911966, - -1811.16591783, - 256.27178429, - -13.02252404, - ], + coefficients_callback=_wan_14b_coefficients, + start_skipping=1, + end_skipping=-1, ) ) @@ -113,23 +126,13 @@ class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams): ] ) - teacache_params: WanTeaCacheParams = field( - default_factory=lambda: WanTeaCacheParams( + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( teacache_thresh=0.26, - ret_steps_coeffs=[ - -3.03318725e05, - 4.90537029e04, - -2.65530556e03, - 5.87365115e01, - -3.15583525e-01, - ], - non_ret_steps_coeffs=[ - -5784.54975374, - 5449.50911966, - -1811.16591783, - 256.27178429, - -13.02252404, - ], + use_ret_steps=True, + coefficients_callback=_wan_14b_coefficients, + start_skipping=5, + end_skipping=1.0, ) ) @@ -151,23 +154,13 @@ class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams): ] ) - teacache_params: WanTeaCacheParams = field( - default_factory=lambda: WanTeaCacheParams( + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( teacache_thresh=0.3, - ret_steps_coeffs=[ - -3.03318725e05, - 4.90537029e04, - -2.65530556e03, - 5.87365115e01, - -3.15583525e-01, - ], - non_ret_steps_coeffs=[ - -5784.54975374, - 5449.50911966, - -1811.16591783, - 256.27178429, - -13.02252404, - ], + use_ret_steps=True, + coefficients_callback=_wan_14b_coefficients, + start_skipping=5, + end_skipping=1.0, ) ) diff --git 
a/python/sglang/multimodal_gen/registry.py b/python/sglang/multimodal_gen/registry.py index 493e519bcad9..b9e03c3d93df 100644 --- a/python/sglang/multimodal_gen/registry.py +++ b/python/sglang/multimodal_gen/registry.py @@ -122,9 +122,9 @@ ) from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( maybe_download_model_index, - verify_model_config_and_directory, ) from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.utils import KNOWN_NON_DIFFUSERS_DIFFUSION_MODEL_PATTERNS logger = init_logger(__name__) @@ -327,10 +327,7 @@ def _get_config_info( return _CONFIG_REGISTRY.get(model_id) # 3. Use detectors - if os.path.exists(model_path): - config = verify_model_config_and_directory(model_path) - else: - config = maybe_download_model_index(model_path) + config = maybe_download_model_index(model_path) pipeline_name = config.get("_class_name", "").lower() matched_model_names = [] @@ -349,10 +346,11 @@ def _get_config_info( model_id = matched_model_names[0] return _CONFIG_REGISTRY.get(model_id) else: - raise RuntimeError( + logger.debug( f"No model info found for model path: {model_path}. " f"Please check the model path or specify the model_id explicitly." 
) + return None # --- Part 3: Main Resolver --- @@ -498,10 +496,7 @@ def get_model_info( else: # Try to get from model_index.json try: - if os.path.exists(model_path): - config = verify_model_config_and_directory(model_path) - else: - config = maybe_download_model_index(model_path) + config = maybe_download_model_index(model_path) except Exception as e: logger.error(f"Could not read model config for '{model_path}': {e}") if backend == Backend.AUTO: @@ -875,25 +870,18 @@ def _register_configs(): _register_configs() -# Known non-diffusers multimodal model patterns -# Maps pattern -> pipeline_name for models that don't have model_index.json -_NON_DIFFUSERS_MULTIMODAL_PATTERNS: Dict[str, str] = { - "hunyuan3d": "Hunyuan3D2Pipeline", - "flux.2-dev-nvfp4": "Flux2NvfpPipeline", -} - - def is_known_non_diffusers_multimodal_model(model_path: str) -> bool: model_path_lower = model_path.lower() return any( - pattern in model_path_lower for pattern in _NON_DIFFUSERS_MULTIMODAL_PATTERNS + pattern in model_path_lower + for pattern in KNOWN_NON_DIFFUSERS_DIFFUSION_MODEL_PATTERNS ) def get_non_diffusers_pipeline_name(model_path: str) -> Optional[str]: """Get the pipeline name for a known non-diffusers model.""" model_path_lower = model_path.lower() - for pattern, pipeline_name in _NON_DIFFUSERS_MULTIMODAL_PATTERNS.items(): + for pattern, pipeline_name in KNOWN_NON_DIFFUSERS_DIFFUSION_MODEL_PATTERNS.items(): if pattern in model_path_lower: return pipeline_name return None diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 5cdafd08bc04..8830f7ec20c4 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -297,7 +297,7 @@ def _get_teacache_context(self) -> TeaCacheContext | None: do_cfg=do_cfg, is_cfg_negative=is_cfg_negative, teacache_thresh=teacache_params.teacache_thresh, - coefficients=teacache_params.coefficients, + 
coefficients=teacache_params.get_coefficients(), teacache_params=teacache_params, ) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py b/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py index a38a9cfc76ba..d47fa93dbb6c 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py @@ -112,7 +112,23 @@ def generate_cmd(args: argparse.Namespace, unknown_args: list[str] | None = None server_args = ServerArgs.from_cli_args(args, unknown_args) - sampling_params_kwargs = SamplingParams.get_cli_args(args) + sampling_params_kwargs = {} + config_file = getattr(args, "config", None) + # respect config file by overriding args with args parsed from it + if config_file: + config_args = ServerArgs.load_config_file(config_file) or {} + sampling_param_fields = { + field.name for field in dataclasses.fields(SamplingParams) + } + sampling_params_kwargs.update( + { + key: value + for key, value in config_args.items() + if key in sampling_param_fields and value is not None + } + ) + + sampling_params_kwargs.update(SamplingParams.get_cli_args(args)) sampling_params_kwargs["request_id"] = generate_request_id() # Handle diffusers-specific kwargs passed via CLI diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py index 2f71ad1d1fd5..6e21f13b4291 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py @@ -183,7 +183,10 @@ def generate( multiple prompts, or None when every request failed. """ # 1. 
prepare requests - prompts = self._resolve_prompts(sampling_params_kwargs.get("prompt")) + prompts = self._resolve_prompts( + sampling_params_kwargs.get("prompt"), + sampling_params_kwargs.get("prompt_path"), + ) user_output_file_name = sampling_params_kwargs.get("output_file_name") if len(prompts) > 1 and user_output_file_name is not None: @@ -334,10 +337,14 @@ def generate( return None return results[0] if len(results) == 1 else results - def _resolve_prompts(self, prompt: str | list[str] | None) -> list[str]: + def _resolve_prompts( + self, + prompt: str | list[str] | None, + prompt_path: str | None = None, + ) -> list[str]: """Collect prompts from the argument or from a prompt file.""" - if self.server_args.prompt_file_path is not None: - path = self.server_args.prompt_file_path + path = prompt_path or self.server_args.prompt_file_path + if path is not None: if not os.path.exists(path): raise FileNotFoundError(f"Prompt text file not found: {path}") with open(path, encoding="utf-8") as f: diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py index d07febdb12f0..303cea786a2a 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -9,7 +9,6 @@ import torch from fastapi import APIRouter, FastAPI, Request -from fastapi.responses import ORJSONResponse from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams from sglang.multimodal_gen.runtime.entrypoints.openai import image_api, video_api @@ -25,6 +24,7 @@ from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.srt.utils.json_response import orjson_response from sglang.version import __version__ if TYPE_CHECKING: @@ 
-235,7 +235,7 @@ async def forward_to_scheduler( @vertex_router.post(VERTEX_ROUTE) async def vertex_generate(vertex_req: VertexGenerateReqInput): if not vertex_req.instances: - return ORJSONResponse({"predictions": []}) + return orjson_response({"predictions": []}) server_args = get_global_server_args() params = vertex_req.parameters or {} @@ -263,7 +263,7 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput): results = await asyncio.gather(*futures) - return ORJSONResponse({"predictions": results}) + return orjson_response({"predictions": results}) def create_app(server_args: ServerArgs): diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py index 921f64410a62..328d9f6f1dd2 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py @@ -2,7 +2,6 @@ from typing import Any, List, Optional, Union from fastapi import APIRouter, Body, HTTPException -from fastapi.responses import ORJSONResponse from pydantic import BaseModel, Field from sglang.multimodal_gen.registry import get_model_info @@ -17,6 +16,7 @@ from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client from sglang.multimodal_gen.runtime.server_args import get_global_server_args from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.srt.utils.json_response import orjson_response router = APIRouter(prefix="/v1") logger = init_logger(__name__) @@ -173,7 +173,7 @@ async def list_loras(): raise HTTPException(status_code=500, detail=str(e)) -@router.get("/models", response_class=ORJSONResponse) +@router.get("/models") async def available_models(): """Show available models. 
OpenAI-compatible endpoint with extended diffusion info.""" server_args = get_global_server_args() @@ -206,7 +206,7 @@ async def available_models(): return {"object": "list", "data": [model_card.model_dump()]} -@router.get("/models/{model:path}", response_class=ORJSONResponse) +@router.get("/models/{model:path}") async def retrieve_model(model: str): """Retrieve a model instance. OpenAI-compatible endpoint with extended diffusion info.""" server_args = get_global_server_args() @@ -214,9 +214,8 @@ async def retrieve_model(model: str): raise HTTPException(status_code=500, detail="Server args not initialized") if model != server_args.model_path: - return ORJSONResponse( - status_code=404, - content={ + return orjson_response( + { "error": { "message": f"The model '{model}' does not exist", "type": "invalid_request_error", @@ -224,6 +223,7 @@ async def retrieve_model(model: str): "code": "model_not_found", } }, + status_code=404, ) model_info = get_model_info( diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py index 9831d368a028..b326a295097f 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py @@ -90,6 +90,8 @@ class VideoGenerationsRequest(BaseModel): seed: Optional[int] = 1024 generator_device: Optional[str] = "cuda" # SGLang extensions + width: Optional[int] = None + height: Optional[int] = None num_inference_steps: Optional[int] = None guidance_scale: Optional[float] = None guidance_scale_2: Optional[float] = None diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py index 2a444111d75b..abccf31bb2d6 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py @@ -60,6 +60,8 
@@ def _build_video_sampling_params(request_id: str, request: VideoGenerationsReque request_id, prompt=request.prompt, size=request.size, + width=request.width, + height=request.height, num_frames=num_frames, fps=fps, image_path=request.input_reference, diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py index 1b9312d8ea0f..7bc0054f7cb5 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py @@ -1,13 +1,13 @@ """Weight update API for the diffusion engine.""" from fastapi import APIRouter, Request -from fastapi.responses import ORJSONResponse from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( GetWeightsChecksumReqInput, UpdateWeightFromDiskReqInput, ) from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client +from sglang.srt.utils.json_response import orjson_response router = APIRouter() @@ -18,7 +18,7 @@ async def update_weights_from_disk(request: Request): body = await request.json() model_path = body.get("model_path") if not model_path: - return ORJSONResponse( + return orjson_response( {"success": False, "message": "model_path is required"}, status_code=400, ) @@ -32,7 +32,7 @@ async def update_weights_from_disk(request: Request): try: response = await async_scheduler_client.forward(req) except Exception as e: - return ORJSONResponse( + return orjson_response( {"success": False, "message": str(e)}, status_code=500, ) @@ -40,7 +40,7 @@ async def update_weights_from_disk(request: Request): result = response.output success = result.get("success", False) message = result.get("message", "Unknown status") - return ORJSONResponse( + return orjson_response( {"success": success, "message": message}, status_code=200 if success else 400, ) @@ -57,6 +57,6 @@ async def 
get_weights_checksum(request: Request): try: response = await async_scheduler_client.forward(req) except Exception as e: - return ORJSONResponse({"error": str(e)}, status_code=500) + return orjson_response({"error": str(e)}, status_code=500) - return ORJSONResponse(response.output, status_code=200) + return orjson_response(response.output, status_code=200) diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py new file mode 100644 index 000000000000..0fc3db29b531 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass +from typing import Any + +import torch + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +@dataclass +class AscendFAMetadata: + pass + + +class AscendFAMetadataBuilder(AttentionMetadataBuilder): + def __init__(self) -> None: + pass + + def prepare(self) -> None: + pass + + def build( + self, + **kwargs: dict[str, Any], + ) -> AttentionMetadata: + return AscendFAMetadata() + + +class AscendFABackend(AttentionBackend): + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.FA + + @staticmethod + def get_impl_cls() -> type["AscendFAImpl"]: + return AscendFAImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + return AscendFAMetadataBuilder + + +class AscendFAImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: 
float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads or num_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + return_softmax_lse: bool = False, + ) -> torch.Tensor: + mask = None + if self.causal: + seq_len = query.shape[1] + mask = torch.triu( + torch.ones(seq_len, seq_len, device=query.device), diagonal=1 + ).bool() + # transpose to bs, heads, seq_len, head_dim + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + output, lse = torch.ops.npu.npu_fused_infer_attention_score( + query, + key, + value, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + scale=self.softmax_scale, + input_layout="BNSD", + softmax_lse_flag=return_softmax_lse, + atten_mask=mask, + ) + output = output.transpose(1, 2) + if return_softmax_lse: + return output, lse + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/usp.py b/python/sglang/multimodal_gen/runtime/layers/usp.py index e822350091ae..3794605fde53 100644 --- a/python/sglang/multimodal_gen/runtime/layers/usp.py +++ b/python/sglang/multimodal_gen/runtime/layers/usp.py @@ -210,7 +210,7 @@ def attn_callable_adapter(q, k, v, *args, **kwargs): q = torch.permute(q, [0, 2, 1, 3]) k = torch.permute(k, [0, 2, 1, 3]) v = torch.permute(v, [0, 2, 1, 3]) - # logger.warning(f"Warning: return_sĀ·oftmax_lse is only supported for FlashAttentionImpl") + # logger.warning(f"Warning: return_softmax_lse is only supported for FlashAttentionImpl") output, softmax_lse, *rest = attn_impl.forward( q, k, diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py index 2e304daf0990..5bd93acb0cad 100644 --- 
a/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py @@ -156,7 +156,9 @@ def _prepare_weights( return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, source: "Source", to_cpu: bool + self, + source: "Source", + to_cpu: bool, ) -> Generator[tuple[str, torch.Tensor], None, None]: """get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( @@ -166,7 +168,8 @@ def _get_weights_iterator( ) if use_safetensors: weights_iterator = safetensors_weights_iterator( - hf_weights_files, to_cpu=to_cpu + hf_weights_files, + to_cpu=to_cpu, ) else: weights_iterator = pt_weights_iterator(hf_weights_files, to_cpu=to_cpu) @@ -186,17 +189,27 @@ def _get_all_weights( fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), ) - yield from self._get_weights_iterator(primary_weights, to_cpu) + yield from self._get_weights_iterator( + primary_weights, + to_cpu, + ) secondary_weights = cast( Iterable[TextEncoderLoader.Source], getattr(model, "secondary_weights", ()), ) for source in secondary_weights: - yield from self._get_weights_iterator(source, to_cpu) + yield from self._get_weights_iterator( + source, + to_cpu, + ) def load_customized( - self, component_model_path: str, server_args: ServerArgs, component_name: str + self, + component_model_path: str, + server_args: ServerArgs, + component_name: str, + cpu_offload_flag: bool | None = None, ): """Load the text encoders based on the model path, and inference args.""" diffusers_pretrained_config = get_config( @@ -227,6 +240,7 @@ def is_not_first_encoder(module_name): encoder_config, server_args, encoder_dtype, + cpu_offload_flag=cpu_offload_flag, ) def load_model( @@ -240,7 +254,10 @@ def load_model( # Determine CPU offload 
behavior and target device local_torch_device = get_local_torch_device() - should_offload = self.should_offload(server_args, model_config) + fsdp_cpu_offload = self.should_offload(server_args, model_config) + should_offload = ( + cpu_offload_flag if cpu_offload_flag is not None else fsdp_cpu_offload + ) if should_offload and not current_platform.is_mps(): model_device = torch.device("cpu") @@ -263,13 +280,13 @@ def load_model( weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( - self._get_all_weights(model, model_path, to_cpu=should_offload) + self._get_all_weights( + model, + model_path, + to_cpu=should_offload, + ) ) - # Explicitly move model to target device after loading weights - if not should_offload: - model = model.to(local_torch_device) - if should_offload: # Disable FSDP for MPS as it's not compatible if current_platform.is_mps(): @@ -277,7 +294,7 @@ def load_model( "Disabling FSDP sharding for MPS platform as it's not compatible" ) model = model.to(local_torch_device) - else: + elif fsdp_cpu_offload: mesh = init_device_mesh( current_platform.device_type, mesh_shape=(1, dist.get_world_size()), @@ -292,6 +309,8 @@ def load_model( or getattr(model, "_fsdp_shard_conditions", None), pin_cpu_memory=server_args.pin_cpu_memory, ) + else: + model = model.to("cpu") else: model = model.to(local_torch_device) # We only enable strict check for non-quantized models diff --git a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py index f2f7c7709783..385931b19e02 100644 --- a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py +++ b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py @@ -6,6 +6,7 @@ # Copyright 2024 The TorchTune Authors. # Copyright 2025 The sglang-diffusion Authors. 
+from collections import Counter, defaultdict from collections.abc import Callable, Generator from itertools import chain from typing import Any @@ -40,6 +41,28 @@ logger = init_logger(__name__) +_QUANTIZED_DTYPES = ( + torch.uint8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.int8, +) +_DTYPE_MISMATCH_EXAMPLE_LIMIT = 3 + + +def _format_dtype_mismatch_summary( + mismatch_counts: Counter[tuple[torch.dtype, torch.dtype]], + mismatch_examples: dict[tuple[torch.dtype, torch.dtype], list[str]], +) -> str: + parts: list[str] = [] + for (checkpoint_dtype, target_dtype), count in mismatch_counts.items(): + examples = mismatch_examples[(checkpoint_dtype, target_dtype)] + part = f"{checkpoint_dtype}->{target_dtype} x{count}" + if examples: + part += f" (e.g. {', '.join(examples)})" + parts.append(part) + return "; ".join(parts) + def _make_param_like( actual_param: torch.nn.Parameter, tensor: torch.Tensor @@ -272,6 +295,18 @@ def load_model_from_full_model_state_dict( sharded_sd = {} skipped_checkpoint_keys: list[str] = [] + non_quantized_dtype_mismatch_counts: Counter[tuple[torch.dtype, torch.dtype]] = ( + Counter() + ) + non_quantized_dtype_mismatch_examples: dict[ + tuple[torch.dtype, torch.dtype], list[str] + ] = defaultdict(list) + quantized_dtype_mismatch_counts: Counter[tuple[torch.dtype, torch.dtype]] = ( + Counter() + ) + quantized_dtype_mismatch_examples: dict[ + tuple[torch.dtype, torch.dtype], list[str] + ] = defaultdict(list) # shard from loaded state_dict, custom_param_sd -> sharded_sd for target_param_name in sorted_param_names: @@ -296,32 +331,29 @@ def load_model_from_full_model_state_dict( else: target_dtype = meta_sharded_param.dtype - _QUANTIZED_DTYPES = ( - torch.uint8, - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.int8, - ) if full_tensor.dtype != target_dtype: + mismatch_key = (full_tensor.dtype, target_dtype) if ( full_tensor.dtype in _QUANTIZED_DTYPES or target_dtype in _QUANTIZED_DTYPES ): - logger.warning( - "Dtype mismatch for 
quantized parameter %s: " - "checkpoint has %s, model expects %s", - target_param_name, - full_tensor.dtype, - target_dtype, - ) + quantized_dtype_mismatch_counts[mismatch_key] += 1 + if ( + len(quantized_dtype_mismatch_examples[mismatch_key]) + < _DTYPE_MISMATCH_EXAMPLE_LIMIT + ): + quantized_dtype_mismatch_examples[mismatch_key].append( + target_param_name + ) else: - logger.warning( - "Dtype mismatch for %s: checkpoint has %s, model expects %s. " - "Casting checkpoint tensor to the target dtype during load.", - target_param_name, - full_tensor.dtype, - target_dtype, - ) + non_quantized_dtype_mismatch_counts[mismatch_key] += 1 + if ( + len(non_quantized_dtype_mismatch_examples[mismatch_key]) + < _DTYPE_MISMATCH_EXAMPLE_LIMIT + ): + non_quantized_dtype_mismatch_examples[mismatch_key].append( + target_param_name + ) if not hasattr(meta_sharded_param, "device_mesh"): full_tensor = full_tensor.to(device=device, dtype=target_dtype) @@ -378,6 +410,28 @@ def load_model_from_full_model_state_dict( model.reverse_param_names_mapping = reverse_param_names_mapping + if non_quantized_dtype_mismatch_counts: + logger.debug( + "Casting checkpoint tensors to target dtype during load: %s", + _format_dtype_mismatch_summary( + non_quantized_dtype_mismatch_counts, + non_quantized_dtype_mismatch_examples, + ), + main_process_only=True, + local_main_process_only=True, + ) + + if quantized_dtype_mismatch_counts: + logger.warning( + "Dtype mismatches detected for quantized parameters during load: %s", + _format_dtype_mismatch_summary( + quantized_dtype_mismatch_counts, + quantized_dtype_mismatch_examples, + ), + main_process_only=True, + local_main_process_only=True, + ) + if skipped_checkpoint_keys: logger.warning( "Checkpoint keys not loaded (no matching model parameter) %s", diff --git a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py index bdeb2ac45c28..64ec3b052e2a 100644 --- 
a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py @@ -24,6 +24,7 @@ except ImportError: HAS_RUNAI_MODEL_STREAMER = False +from sglang.multimodal_gen import envs from sglang.multimodal_gen.runtime.distributed import get_local_torch_device from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -135,13 +136,17 @@ def _validate_safetensors_file(file_path: str) -> bool: def safetensors_weights_iterator( hf_weights_files: list[str], to_cpu: bool = True, - use_runai_model_streamer: bool = HAS_RUNAI_MODEL_STREAMER, + use_runai_model_streamer: bool | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" enable_tqdm = ( not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) device = "cpu" if to_cpu else str(get_local_torch_device()) + if use_runai_model_streamer is None: + use_runai_model_streamer = ( + HAS_RUNAI_MODEL_STREAMER and envs.SGLANG_USE_RUNAI_MODEL_STREAMER + ) # Validate files before loading corrupted_files = [ diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py index 1651cdba1eb7..5b2a69a32a1b 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py @@ -438,7 +438,10 @@ def __init__( self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) - # Fused attention output projection + MLP output projection + # Fused attention output + MLP output projection. + # Input is [attn_shard | mlp_shard] (independently sharded by + # MergedColumnParallelLinear), so patch weight loader to pick the + # correct non-contiguous columns per rank. 
self.to_out = RowParallelLinear( self.inner_dim + self.mlp_hidden_dim, self.out_dim, @@ -447,6 +450,8 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.to_out" if prefix else "to_out", ) + if self.tp_size > 1: + self._patch_to_out_weight_loader() self.attn = USPAttention( num_heads=self.local_heads, @@ -456,6 +461,24 @@ def __init__( causal=False, ) + def _patch_to_out_weight_loader(self) -> None: + inner_dim, mlp_dim = self.inner_dim, self.mlp_hidden_dim + tp_size, tp_rank = self.tp_size, self.to_out.tp_rank + + def _loader(param, loaded_weight): + input_dim = getattr(param, "input_dim", None) + if input_dim is not None: + a = inner_dim // tp_size + m = mlp_dim // tp_size + attn_cols = loaded_weight.narrow(input_dim, tp_rank * a, a) + mlp_cols = loaded_weight.narrow(input_dim, inner_dim + tp_rank * m, m) + param.data.copy_(torch.cat([attn_cols, mlp_cols], dim=input_dim)) + else: + param.data.copy_(loaded_weight) + + self.to_out.weight_loader = _loader + self.to_out.weight.weight_loader = _loader + def forward( self, hidden_states: torch.Tensor, diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py index 21de5b9b37bb..f6f520690e85 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py @@ -34,6 +34,7 @@ QuantizationConfig, ) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -521,8 +522,12 @@ def _init_freqs(self): def patchify( self, x: torch.Tensor, control_camera_latents_input: torch.Tensor | None = None ): - # NOTE(dhyu): avoid slow_conv - x = x.contiguous(memory_format=torch.channels_last_3d) + if 
current_platform.is_npu: + # torch.channels_last_3d is not supported on NPU + x = x.contiguous() + else: + # NOTE(dhyu): avoid slow_conv + x = x.contiguous(memory_format=torch.channels_last_3d) x = self.patch_embedding(x) grid_size = x.shape[2:] x = rearrange(x, "b c f h w -> b (f h w) c").contiguous() diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py index ebf2daa5acb5..1b3bc8a4ae4d 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -47,12 +47,16 @@ apply_flashinfer_rope_qk_inplace, ) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT -from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) # pylint: disable=invalid-name + try: from nunchaku.models.attention import NunchakuFeedForward # type: ignore[import] except Exception: @@ -826,30 +830,52 @@ def _modulate( shift, scale, gate = mod_params.chunk(3, dim=-1) if index is not None: - actual_batch = x.shape[0] - shift0, shift1 = ( - shift[:actual_batch], - shift[actual_batch : 2 * actual_batch], - ) - scale0, scale1 = ( - scale[:actual_batch], - scale[actual_batch : 2 * actual_batch], - ) - gate0, gate1 = gate[:actual_batch], gate[actual_batch : 2 * actual_batch] - if not x.is_contiguous(): - x = x.contiguous() - if not index.is_contiguous(): - index = index.contiguous() - if is_scale_residual: - if not residual_x.is_contiguous(): - residual_x = residual_x.contiguous() - if not gate_x.is_contiguous(): - gate_x = gate_x.contiguous() - x, residual_out, gate_result = ( - 
fuse_residual_layernorm_scale_shift_gate_select01_kernel( + # ROCm currently fails to compile the select01 Triton kernel, so + # keep using the torch.where fallback there. + if x.is_cuda and not current_platform.is_hip(): + actual_batch = x.shape[0] + shift0, shift1 = ( + shift[:actual_batch], + shift[actual_batch : 2 * actual_batch], + ) + scale0, scale1 = ( + scale[:actual_batch], + scale[actual_batch : 2 * actual_batch], + ) + gate0, gate1 = ( + gate[:actual_batch], + gate[actual_batch : 2 * actual_batch], + ) + if not x.is_contiguous(): + x = x.contiguous() + if not index.is_contiguous(): + index = index.contiguous() + if is_scale_residual: + if not residual_x.is_contiguous(): + residual_x = residual_x.contiguous() + if not gate_x.is_contiguous(): + gate_x = gate_x.contiguous() + x, residual_out, gate_result = ( + fuse_residual_layernorm_scale_shift_gate_select01_kernel( + x, + residual=residual_x, + residual_gate=gate_x, + weight=getattr(norm_module.norm, "weight", None), + bias=getattr(norm_module.norm, "bias", None), + scale0=scale0.contiguous(), + shift0=shift0.contiguous(), + gate0=gate0.contiguous(), + scale1=scale1.contiguous(), + shift1=shift1.contiguous(), + gate1=gate1.contiguous(), + index=index, + eps=norm_module.eps, + ) + ) + return x, residual_out, gate_result + else: + x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel( x, - residual=residual_x, - residual_gate=gate_x, weight=getattr(norm_module.norm, "weight", None), bias=getattr(norm_module.norm, "bias", None), scale0=scale0.contiguous(), @@ -861,39 +887,45 @@ def _modulate( index=index, eps=norm_module.eps, ) - ) - return x, residual_out, gate_result + return x, gate_result else: - x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel( - x, - weight=getattr(norm_module.norm, "weight", None), - bias=getattr(norm_module.norm, "bias", None), - scale0=scale0.contiguous(), - shift0=shift0.contiguous(), - gate0=gate0.contiguous(), - scale1=scale1.contiguous(), - 
shift1=shift1.contiguous(), - gate1=gate1.contiguous(), - index=index, - eps=norm_module.eps, + actual_batch = x.shape[0] + shift0, shift1 = ( + shift[:actual_batch], + shift[actual_batch : 2 * actual_batch], ) - return x, gate_result + scale0, scale1 = ( + scale[:actual_batch], + scale[actual_batch : 2 * actual_batch], + ) + gate0, gate1 = ( + gate[:actual_batch], + gate[actual_batch : 2 * actual_batch], + ) + index = index.to(dtype=torch.bool).unsqueeze(-1) + shift_result = torch.where( + index, shift1.unsqueeze(1), shift0.unsqueeze(1) + ) + scale_result = torch.where( + index, scale1.unsqueeze(1), scale0.unsqueeze(1) + ) + gate_result = torch.where(index, gate1.unsqueeze(1), gate0.unsqueeze(1)) else: shift_result = shift.unsqueeze(1) scale_result = scale.unsqueeze(1) gate_result = gate.unsqueeze(1) - if is_scale_residual: - modulated, residual_out = norm_module( - residual=residual_x, - x=x, - gate=gate_x, - shift=shift_result, - scale=scale_result, - ) - return modulated, residual_out, gate_result - else: - modulated = norm_module(x=x, shift=shift_result, scale=scale_result) - return modulated, gate_result + if is_scale_residual: + modulated, residual_out = norm_module( + residual=residual_x, + x=x, + gate=gate_x, + shift=shift_result, + scale=scale_result, + ) + return modulated, residual_out, gate_result + else: + modulated = norm_module(x=x, shift=shift_result, scale=scale_result) + return modulated, gate_result def forward( self, @@ -1127,8 +1159,8 @@ def build_modulate_index(self, img_shapes: tuple[int, int, int], device): first_size = sample[0][0] * sample[0][1] * sample[0][2] total_size = sum(s[0] * s[1] * s[2] for s in sample) if sp_world_size > 1: - first_local_size = _local_seq_len(first_size) - tail_local_size = _local_seq_len(total_size - first_size) + first_local_size = _local_seq_len(first_size, sp_world_size) + tail_local_size = _local_seq_len(total_size - first_size, sp_world_size) idx = torch.cat( [ torch.zeros(first_local_size, device=device, 
dtype=torch.int), diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 4a2798a4a934..b193bf808324 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -10,7 +10,6 @@ import torch.nn as nn from sglang.multimodal_gen.configs.models.dits import WanVideoConfig -from sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams from sglang.multimodal_gen.runtime.distributed import ( divide, get_sp_group, @@ -1170,22 +1169,15 @@ def should_skip_forward_for_cached_states(self, **kwargs) -> bool: if ctx is None: return False - # Wan uses WanTeaCacheParams with additional fields - teacache_params = ctx.teacache_params - assert isinstance( - teacache_params, WanTeaCacheParams - ), "teacache_params is not a WanTeaCacheParams" - # Initialize Wan-specific parameters + teacache_params = ctx.teacache_params use_ret_steps = teacache_params.use_ret_steps - cutoff_steps = teacache_params.get_cutoff_steps(ctx.num_inference_steps) - ret_steps = teacache_params.ret_steps + start_skipping, end_skipping = teacache_params.get_skip_boundaries( + ctx.num_inference_steps, ctx.do_cfg + ) - # Adjust ret_steps and cutoff_steps for non-CFG mode - # (WanTeaCacheParams uses *2 factor assuming CFG) - if not ctx.do_cfg: - ret_steps = ret_steps // 2 - cutoff_steps = cutoff_steps // 2 + # Determine boundary step + is_boundary_step = self.cnt < start_skipping or self.cnt >= end_skipping timestep_proj = kwargs["timestep_proj"] temb = kwargs["temb"] @@ -1193,9 +1185,6 @@ def should_skip_forward_for_cached_states(self, **kwargs) -> bool: self.is_cfg_negative = ctx.is_cfg_negative - # Wan uses ret_steps/cutoff_steps for boundary detection - is_boundary_step = self.cnt < ret_steps or self.cnt >= cutoff_steps - # Use shared helper to compute cache decision should_calc = self._compute_teacache_decision( 
modulated_inp=modulated_inp, diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py index 42c5426d7434..3178783c4362 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple, Union import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from diffusers.models.activations import get_activation @@ -13,7 +14,10 @@ from diffusers.models.modeling_outputs import AutoencoderKLOutput from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig -from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.distributed import ( + get_local_torch_device, + get_sp_world_size, +) from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -789,6 +793,7 @@ def __init__( self.input_channels = config.arch_config.input_channels self.latents_mean = config.arch_config.latents_mean self.config = config.arch_config + self.use_parallel_decode = config.use_parallel_decode self.encoder = QwenImageEncoder3d( base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, @@ -841,6 +846,8 @@ def __init__( .to(cuda_device, dtype) ) + + def enable_tiling( self, tile_sample_min_height: Optional[int] = None, @@ -956,30 +963,43 @@ def encode( return posterior - def _decode(self, z: torch.Tensor, return_dict: bool = True): + def _decode_with_parallel_dispatch(self, z: torch.Tensor) -> DecoderOutput: + if self.use_parallel_decode and get_sp_world_size() > 1: + num_frame = z.shape[2] + num_sample_frames = (num_frame - 1) * self.temporal_compression_ratio + 1 + decoded = 
super().parallel_tiled_decode(z)[:, :, :num_sample_frames] + return DecoderOutput(sample=decoded) + + return DecoderOutput(sample=self._decode(z)) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: _, _, num_frame, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): - return self.tiled_decode(z, return_dict=return_dict) + return self.tiled_decode(z).sample self.clear_cache() x = self.post_quant_conv(z) for i in range(num_frame): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) else: - out_ = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out_ = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) out = torch.cat([out, out_], 2) - out = torch.clamp(out, min=-1.0, max=1.0) self.clear_cache() - if not return_dict: - return (out,) - - return DecoderOutput(sample=out) + return out def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" @@ -996,29 +1016,121 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp returned. 
""" if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded_slices = [ + self._decode_with_parallel_dispatch(z_slice).sample + for z_slice in z.split(1) + ] decoded = torch.cat(decoded_slices) else: - decoded = self._decode(z).sample + decoded = self._decode_with_parallel_dispatch(z).sample return decoded - def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) - for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( - y / blend_extent - ) + if blend_extent <= 0: + return b + weight = ( + torch.arange(blend_extent, device=b.device, dtype=b.dtype) / blend_extent + ).view(1, 1, 1, blend_extent, 1) + b[:, :, :, :blend_extent, :] = ( + a[:, :, :, -blend_extent:, :] * (1 - weight) + + b[:, :, :, :blend_extent, :] * weight + ) return b - def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) - for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( - x / blend_extent - ) + if blend_extent <= 0: + return b + weight = ( + torch.arange(blend_extent, device=b.device, dtype=b.dtype) / blend_extent + ).view(1, 1, 1, 1, blend_extent) + b[:, :, :, :, :blend_extent] = ( + a[:, :, :, :, -blend_extent:] * (1 - weight) + + b[:, :, :, :, :blend_extent] * weight + ) return b + def _process_parallel_tiled_outputs( + self, + results: torch.Tensor, + local_dim_metadata: list[torch.Size], + z: torch.Tensor, + world_size: int, + rank: int, + num_t_tiles: int, + num_h_tiles: int, + num_w_tiles: int, + 
total_spatial_tiles: int, + blend_height: int, + blend_width: int, + ) -> torch.Tensor: + local_size = torch.tensor( + [results.size(0)], device=results.device, dtype=torch.int64 + ) + if rank == 0: + gathered_sizes = [ + torch.zeros(1, device=results.device, dtype=torch.int64) + for _ in range(world_size) + ] + else: + gathered_sizes = None + dist.gather(local_size, gather_list=gathered_sizes, dst=0) + + max_size = 0 + if rank == 0: + max_size = max(size.item() for size in gathered_sizes) + + max_size_tensor = torch.tensor( + [max_size], device=results.device, dtype=torch.int64 + ) + dist.broadcast(max_size_tensor, src=0) + max_size = int(max_size_tensor.item()) + + padded_results = torch.zeros( + max_size, device=results.device, dtype=results.dtype + ) + padded_results[: results.size(0)] = results + + gathered_dim_metadata = [None] * world_size + dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) + + if rank == 0: + gathered_results = [ + torch.empty_like(padded_results) for _ in range(world_size) + ] + else: + gathered_results = None + dist.gather(padded_results, gather_list=gathered_results, dst=0) + + if rank == 0: + gathered_results = torch.stack(gathered_results, dim=0).contiguous() + dec = super()._merge_parallel_tiled_results( + gathered_results, + gathered_dim_metadata, + num_t_tiles, + num_h_tiles, + num_w_tiles, + total_spatial_tiles, + blend_height, + blend_width, + ) + shape_tensor = torch.tensor(dec.shape, device=dec.device, dtype=torch.int64) + else: + dec = None + shape_tensor = torch.zeros(5, device=z.device, dtype=torch.int64) + + dist.broadcast(shape_tensor, src=0) + if rank != 0: + dec = z.new_empty(tuple(shape_tensor.tolist())) + dist.broadcast(dec, src=0) + return dec + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. 
diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/common.py b/python/sglang/multimodal_gen/runtime/models/vaes/common.py index 095ce49574f5..7a55c8c4a8a4 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/common.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/common.py @@ -220,14 +220,110 @@ def _parallel_data_generator( _start_shape += mul_shape global_idx += 1 + def _merge_parallel_tiled_results( + self, + gathered_results: torch.Tensor, + gathered_dim_metadata: list[list[torch.Size]], + num_t_tiles: int, + num_h_tiles: int, + num_w_tiles: int, + total_spatial_tiles: int, + blend_height: int, + blend_width: int, + ) -> torch.Tensor: + data: list = [ + [[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] + for _ in range(num_t_tiles) + ] + for current_data, global_idx in self._parallel_data_generator( + gathered_results, gathered_dim_metadata + ): + t_idx = global_idx // total_spatial_tiles + spatial_idx = global_idx % total_spatial_tiles + h_idx = spatial_idx // num_w_tiles + w_idx = spatial_idx % num_w_tiles + data[t_idx][h_idx][w_idx] = current_data + + result_slices = [] + last_slice_data = None + for i, tem_data in enumerate(data): + slice_data = self._merge_spatial_tiles( + tem_data, + blend_height, + blend_width, + self.tile_sample_stride_height, + self.tile_sample_stride_width, + ) + if i > 0: + slice_data = self.blend_t( + last_slice_data, slice_data, self.blend_num_frames + ) + result_slices.append( + slice_data[:, :, : self.tile_sample_stride_num_frames, :, :] + ) + else: + result_slices.append( + slice_data[:, :, : self.tile_sample_stride_num_frames + 1, :, :] + ) + last_slice_data = slice_data + return torch.cat(result_slices, dim=2) + + def _process_parallel_tiled_outputs( + self, + results: torch.Tensor, + local_dim_metadata: list[torch.Size], + z: torch.Tensor, + world_size: int, + rank: int, + num_t_tiles: int, + num_h_tiles: int, + num_w_tiles: int, + total_spatial_tiles: int, + blend_height: int, + 
blend_width: int, + ) -> torch.Tensor: + local_size = torch.tensor( + [results.size(0)], device=results.device, dtype=torch.int64 + ) + all_sizes = [ + torch.zeros(1, device=results.device, dtype=torch.int64) + for _ in range(world_size) + ] + dist.all_gather(all_sizes, local_size) + max_size = max(size.item() for size in all_sizes) + + padded_results = torch.zeros( + max_size, device=results.device, dtype=results.dtype + ) + padded_results[: results.size(0)] = results + + gathered_dim_metadata = [None] * world_size + gathered_results = ( + torch.zeros_like(padded_results) + .repeat(world_size, *[1] * len(padded_results.shape)) + .contiguous() + ) + dist.all_gather_into_tensor(gathered_results, padded_results) + dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) + gathered_dim_metadata = cast(list[list[torch.Size]], gathered_dim_metadata) + return self._merge_parallel_tiled_results( + gathered_results, + gathered_dim_metadata, + num_t_tiles, + num_h_tiles, + num_w_tiles, + total_spatial_tiles, + blend_height, + blend_width, + ) + def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: """ Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs """ world_size, rank = get_sp_world_size(), get_sp_parallel_rank() - B, C, T, H, W = z.shape + _, _, T, H, W = z.shape - # Calculate parameters tile_latent_min_height = ( self.tile_sample_min_height // self.spatial_compression_ratio ) @@ -259,26 +355,22 @@ def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: total_spatial_tiles = num_h_tiles * num_w_tiles total_tiles = num_t_tiles * total_spatial_tiles - # Calculate tiles per rank and padding tiles_per_rank = (total_tiles + world_size - 1) // world_size start_tile_idx = rank * tiles_per_rank end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles) local_results = [] local_dim_metadata = [] - # Process assigned tiles - for local_idx, global_idx in 
enumerate(range(start_tile_idx, end_tile_idx)): + for global_idx in range(start_tile_idx, end_tile_idx): t_idx = global_idx // total_spatial_tiles spatial_idx = global_idx % total_spatial_tiles h_idx = spatial_idx // num_w_tiles w_idx = spatial_idx % num_w_tiles - # Calculate positions t_start = t_idx * tile_latent_stride_num_frames h_start = h_idx * tile_latent_stride_height w_start = w_idx * tile_latent_stride_width - # Extract and process tile tile = z[ :, :, @@ -286,84 +378,31 @@ def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: h_start : h_start + tile_latent_min_height, w_start : w_start + tile_latent_min_width, ] - - # Process tile - tile = self._decode(tile) - + decoded_tile = self._decode(tile) if t_start > 0: - tile = tile[:, :, 1:, :, :] - - # Store metadata - shape = tile.shape - # Store decoded data (flattened) - decoded_flat = tile.reshape(-1) - local_results.append(decoded_flat) - local_dim_metadata.append(shape) + decoded_tile = decoded_tile[:, :, 1:, :, :] + local_results.append(decoded_tile.reshape(-1)) + local_dim_metadata.append(decoded_tile.shape) - results = torch.cat(local_results, dim=0).contiguous() + if local_results: + results = torch.cat(local_results, dim=0).contiguous() + else: + results = z.new_empty((0,), dtype=z.dtype) del local_results - # first gather size to pad the results - local_size = torch.tensor( - [results.size(0)], device=results.device, dtype=torch.int64 - ) - all_sizes = [ - torch.zeros(1, device=results.device, dtype=torch.int64) - for _ in range(world_size) - ] - dist.all_gather(all_sizes, local_size) - max_size = max(size.item() for size in all_sizes) - padded_results = torch.zeros(max_size, device=results.device) - padded_results[: results.size(0)] = results - del results - - # Gather all results - gathered_dim_metadata = [None] * world_size - gathered_results = ( - torch.zeros_like(padded_results) - .repeat(world_size, *[1] * len(padded_results.shape)) - .contiguous() - ) # use contiguous 
to make sure it won't copy data in the following operations - # TODO (PY): use sgl_diffusion distributed methods - dist.all_gather_into_tensor(gathered_results, padded_results) - dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) - # Process gathered results - data: list = [ - [[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] - for _ in range(num_t_tiles) - ] - for current_data, global_idx in self._parallel_data_generator( - gathered_results, gathered_dim_metadata - ): - t_idx = global_idx // total_spatial_tiles - spatial_idx = global_idx % total_spatial_tiles - h_idx = spatial_idx // num_w_tiles - w_idx = spatial_idx % num_w_tiles - data[t_idx][h_idx][w_idx] = current_data - # Merge results - result_slices = [] - last_slice_data = None - for i, tem_data in enumerate(data): - slice_data = self._merge_spatial_tiles( - tem_data, - blend_height, - blend_width, - self.tile_sample_stride_height, - self.tile_sample_stride_width, - ) - if i > 0: - slice_data = self.blend_t( - last_slice_data, slice_data, self.blend_num_frames - ) - result_slices.append( - slice_data[:, :, : self.tile_sample_stride_num_frames, :, :] - ) - else: - result_slices.append( - slice_data[:, :, : self.tile_sample_stride_num_frames + 1, :, :] - ) - last_slice_data = slice_data - dec = torch.cat(result_slices, dim=2) + dec = self._process_parallel_tiled_outputs( + results, + local_dim_metadata, + z, + world_size, + rank, + num_t_tiles, + num_h_tiles, + num_w_tiles, + total_spatial_tiles, + blend_height, + blend_width, + ) return dec def _merge_spatial_tiles( diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index 2611c2733615..2f348d68cfb8 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -1025,6 +1025,7 @@ def forward( logger=logger, 
metrics=batch.metrics, perf_dump_path_provided=batch.perf_dump_path is not None, + record_as_step=True, ): t_int = int(t_host.item()) t_device = timesteps[i] diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py index 55deb10f41ab..7bd5a7e95d74 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py @@ -407,6 +407,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req: logger=logger, metrics=batch.metrics, perf_dump_path_provided=batch.perf_dump_path is not None, + record_as_step=True, ): t_int = int(t_host.item()) t_device = timesteps[i] diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py index 504fc429e03b..1b0223d513ec 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py @@ -104,6 +104,7 @@ def forward( logger=logger, metrics=batch.metrics, perf_dump_path_provided=batch.perf_dump_path is not None, + record_as_step=True, ): t_int = int(t.item()) if self.transformer_2 is not None: diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/input_validation.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/input_validation.py index aa0b686b828a..fc27f657f424 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/input_validation.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/input_validation.py @@ -191,8 +191,30 @@ def preprocess_condition_image( server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial * server_args.pipeline_config.dit_config.arch_config.patch_size[1] ) + + # User-specified width/height controls the target area 
(scale), + # capped by max_area. Aspect ratio always comes from the + # condition image for I2V. + if batch.width is not None or batch.height is not None: + # If one dimension is provided, calculate the other based on the image's aspect ratio. + if batch.width is None: + batch.width = round(batch.height / aspect_ratio) + elif batch.height is None: + batch.height = round(batch.width * aspect_ratio) + + target_area = min(batch.width * batch.height, max_area) + if batch.width * batch.height > max_area: + logger.warning( + "Requested resolution %dx%d exceeds max_area %d, " + "clamping to max_area", + batch.width, + batch.height, + max_area, + ) + else: + target_area = max_area width, height = self._calculate_dimensions_from_area( - max_area, aspect_ratio, mod_value + target_area, aspect_ratio, mod_value ) batch.condition_image = batch.condition_image.resize((width, height)) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py index afd43238e20a..b8b0cfb62b02 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py @@ -69,6 +69,7 @@ from sglang.multimodal_gen.utils import PRECISION_TO_TYPE from sglang.srt.utils.common import get_compiler_backend +_is_npu = current_platform.is_npu() logger = init_logger(__name__) @@ -430,6 +431,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req: logger=logger, metrics=metrics, perf_dump_path_provided=perf_dump_path_provided, + record_as_step=True, ): pair_t = paired_timesteps[idx_step] if getattr(pair_t, "shape", None) == (2,): @@ -714,7 +716,14 @@ def inference_single_step( # Build visual freqs for full sequence visual_dit._init_freqs() - visual_freqs = tuple(freq.to(visual_x.device) for freq in visual_dit.freqs) + if _is_npu: + # TODO: remove 
this when torch.complex128 is supported for torch.cat on NPU + visual_freqs = tuple( + freq.to(device=visual_x.device, dtype=torch.complex64) + for freq in visual_dit.freqs + ) + else: + visual_freqs = tuple(freq.to(visual_x.device) for freq in visual_dit.freqs) visual_freqs = ( torch.cat( [ @@ -734,18 +743,24 @@ def inference_single_step( # Build audio freqs for full sequence self.audio_dit._init_freqs() - audio_freqs = ( - torch.cat( - [ - self.audio_dit.freqs[0][:f].view(f, -1).expand(f, -1), - self.audio_dit.freqs[1][:f].view(f, -1).expand(f, -1), - self.audio_dit.freqs[2][:f].view(f, -1).expand(f, -1), - ], - dim=-1, + if _is_npu: + # TODO: remove this when torch.complex128 is supported for torch.cat on NPU + audio_freqs = tuple( + freq.to(device=audio_x.device, dtype=torch.complex64) + for freq in self.audio_dit.freqs ) - .reshape(full_audio_seq_len, 1, -1) - .to(audio_x.device) - ) + else: + audio_freqs = tuple( + freq.to(audio_x.device) for freq in self.audio_dit.freqs + ) + audio_freqs = torch.cat( + [ + audio_freqs[0][:f].view(f, -1).expand(f, -1), + audio_freqs[1][:f].view(f, -1).expand(f, -1), + audio_freqs[2][:f].view(f, -1).expand(f, -1), + ], + dim=-1, + ).reshape(full_audio_seq_len, 1, -1) # Shard sequences for SP visual_x, visual_pad_len = self._shard_sequence_for_sp(visual_x, dim=1) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py index 7b51cff5f61d..e1184cbc0c38 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py @@ -236,10 +236,12 @@ def encode_text( else {} ) - processed_text_list: list[str] = [] - for prompt_str in texts: - preprocessed = preprocess_func(prompt_str) - processed_text_list.append(preprocessed) + if preprocess_func is not None: + processed_text_list: list[str] = [ + preprocess_func(prompt_str) for 
prompt_str in texts + ] + else: + processed_text_list = texts # Prepare tokenizer args tok_kwargs = self.prepare_tokenizer_kwargs( diff --git a/python/sglang/multimodal_gen/runtime/platforms/musa.py b/python/sglang/multimodal_gen/runtime/platforms/musa.py index 7d443be6b542..a368be696a75 100644 --- a/python/sglang/multimodal_gen/runtime/platforms/musa.py +++ b/python/sglang/multimodal_gen/runtime/platforms/musa.py @@ -150,10 +150,61 @@ def get_attn_backend_cls_str( head_size: int, dtype: torch.dtype, ) -> str: - logger.info("Using Torch SDPA backend.") - return ( - "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" - ) + target_backend: AttentionBackendEnum | None = None + + if selected_backend == AttentionBackendEnum.TORCH_SDPA: + logger.info("Using Torch SDPA backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + elif selected_backend in [ + AttentionBackendEnum.FA, + ]: + target_backend = AttentionBackendEnum.FA + elif selected_backend: + raise ValueError(f"Invalid attention backend for {cls.device_name}") + else: + target_backend = AttentionBackendEnum.FA + + # Ensure we have a target backend selected before validation/fallback. + if target_backend is None: + target_backend = AttentionBackendEnum.FA + + if dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention backend for dtype other than " + "torch.float16 or torch.bfloat16." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + + # FlashAttn is valid for the model, checking if the package is + # installed. 
+ if target_backend == AttentionBackendEnum.FA: + try: + from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend, + ) + + supported_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention backend for head size %d.", + head_size, + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + except ImportError: + logger.info( + "Cannot use FlashAttention backend because the " + "flash_attn package is not found. " + "Make sure that flash_attn was built and installed " + "(on by default)." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + + if target_backend == AttentionBackendEnum.TORCH_SDPA: + logger.info("Using Torch SDPA backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + + logger.info("Using FlashAttention (FA3) backend on MUSA") + return "sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn.FlashAttentionBackend" @classmethod def get_device_communicator_cls(cls) -> str: diff --git a/python/sglang/multimodal_gen/runtime/platforms/npu.py b/python/sglang/multimodal_gen/runtime/platforms/npu.py index c73733409b80..3a3303b61d2f 100644 --- a/python/sglang/multimodal_gen/runtime/platforms/npu.py +++ b/python/sglang/multimodal_gen/runtime/platforms/npu.py @@ -116,6 +116,10 @@ def get_attn_backend_cls_str( head_size: int, dtype: torch.dtype, ) -> str: + if selected_backend == AttentionBackendEnum.FA: + logger.info("Using Ascend Flash Attention backend.") + return "sglang.multimodal_gen.runtime.layers.attention.backends.ascend_fa.AscendFABackend" + logger.info("Using Torch SDPA backend.") return ( "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 4fc3e2964ea7..62edb3dcaf1e 100644 --- 
a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -233,6 +233,7 @@ class ServerArgs: # Logging log_level: str = "info" + uvicorn_access_log_exclude_prefixes: list[str] = field(default_factory=list) @property def broker_port(self) -> int: @@ -860,6 +861,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=ServerArgs.log_level, help="The logging level of all loggers.", ) + parser.add_argument( + "--uvicorn-access-log-exclude-prefixes", + type=str, + nargs="*", + default=[], + help="Exclude uvicorn access logs whose request path starts with any of these prefixes. " + "Defaults to empty (disabled). " + "Example: --uvicorn-access-log-exclude-prefixes /metrics /health", + ) parser.add_argument( "--backend", type=str, diff --git a/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py b/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py index 47c1a7b85815..f12f304110a0 100644 --- a/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py @@ -44,6 +44,10 @@ from sglang.multimodal_gen.runtime.loader.weight_utils import get_lock from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.utils.model_overlay import ( + maybe_load_overlay_model_index, + maybe_resolve_overlay_model_path, +) from sglang.srt.environ import envs from sglang.utils import is_in_ci @@ -485,7 +489,15 @@ def maybe_download_model_index(model_name_or_path: str) -> dict[str, Any]: from huggingface_hub.errors import EntryNotFoundError - # If it's a local path, verify it directly + overlay_config = maybe_load_overlay_model_index( + model_name_or_path, + snapshot_download_fn=snapshot_download, + hf_hub_download_fn=hf_hub_download, + ) + if overlay_config is not None: + return overlay_config 
+ + # If it's a local path, verify it directly. if os.path.exists(model_name_or_path): try: return verify_model_config_and_directory(model_name_or_path) @@ -560,6 +572,7 @@ def maybe_download_model( is_lora: bool = False, allow_patterns: list[str] | None = None, force_diffusers_model: bool = False, + skip_overlay_resolution: bool = False, ) -> str: """ Check if the model path is a Hugging Face Hub model ID and download it if needed. @@ -573,6 +586,20 @@ def maybe_download_model( Returns: Local path to the model """ + if force_diffusers_model and not skip_overlay_resolution: + # return overlay model path if applicable + overlay_model_path = maybe_resolve_overlay_model_path( + model_name_or_path, + local_dir=local_dir, + download=download, + allow_patterns=allow_patterns, + snapshot_download_fn=snapshot_download, + hf_hub_download_fn=hf_hub_download, + verify_diffusers_model_complete_fn=_verify_diffusers_model_complete, + base_model_download_fn=maybe_download_model, + ) + if overlay_model_path is not None: + return overlay_model_path # 1. 
Local path check: if path exists locally, verify it's complete (skip for LoRA) if os.path.exists(model_name_or_path): diff --git a/python/sglang/multimodal_gen/runtime/utils/logging_utils.py b/python/sglang/multimodal_gen/runtime/utils/logging_utils.py index 6214bacd91a3..1b13c05c6d70 100644 --- a/python/sglang/multimodal_gen/runtime/utils/logging_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/logging_utils.py @@ -451,7 +451,7 @@ def enable_trace_function_call(log_file_path: str, root_dir: str | None = None): sys.settrace(partial(_trace_calls, log_file_path, root_dir)) -def set_uvicorn_logging_configs(): +def set_uvicorn_logging_configs(server_args=None): from uvicorn.config import LOGGING_CONFIG LOGGING_CONFIG["formatters"]["default"][ @@ -463,6 +463,59 @@ def set_uvicorn_logging_configs(): ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s' LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" + # Install access log path filter into LOGGING_CONFIG so it survives + # uvicorn's internal dictConfig() call during startup. + prefixes = getattr(server_args, "uvicorn_access_log_exclude_prefixes", None) + if prefixes: + _install_access_log_filter(LOGGING_CONFIG, prefixes) + + +def _install_access_log_filter(config: dict, prefixes: list[str]): + """Register a path-based access log filter into uvicorn's LOGGING_CONFIG dict. + + Only attaches to the ``access`` handler (not the ``uvicorn.access`` logger) + to avoid filtering the same record twice. + """ + # Sanitize: drop empty strings (would match all paths) and deduplicate. 
+ prefixes = [str(p) for p in prefixes if p] + prefixes = list(dict.fromkeys(prefixes)) + if not prefixes: + return + + name = "sglang_diffusion_path_filter" + config.setdefault("filters", {})[name] = { + "()": "sglang.multimodal_gen.runtime.utils.logging_utils._UvicornAccessLogFilter", + "prefixes": prefixes, + } + + handler_cfg = config.get("handlers", {}).get("access") + if handler_cfg is not None: + fl = handler_cfg.setdefault("filters", []) + if name not in fl: + fl.append(name) + + +class _UvicornAccessLogFilter(logging.Filter): + """Suppress uvicorn access logs whose path starts with an excluded prefix. + + uvicorn's ``AccessFormatter`` injects ``request_line`` during ``format()``, + which runs *after* filters. We therefore extract the path from + ``record.args`` which uvicorn populates as:: + + (client_addr, method, full_path, http_version, status_code) + """ + + def __init__(self, prefixes: list[str] | None = None): + super().__init__() + self.prefixes = tuple(str(p) for p in (prefixes or ()) if p) + + def filter(self, record: logging.LogRecord) -> bool: + args = record.args + if isinstance(args, tuple) and len(args) >= 3: + path = str(args[2]).split("?", 1)[0] + return not path.startswith(self.prefixes) + return True + def configure_logger(server_args, prefix: str = ""): log_format = f"[%(asctime)s{prefix}] %(message)s" @@ -477,7 +530,7 @@ def configure_logger(server_args, prefix: str = ""): root.addHandler(handler) root.setLevel(getattr(logging, server_args.log_level.upper())) - set_uvicorn_logging_configs() + set_uvicorn_logging_configs(server_args) @lru_cache(maxsize=1) diff --git a/python/sglang/multimodal_gen/runtime/utils/model_overlay.py b/python/sglang/multimodal_gen/runtime/utils/model_overlay.py new file mode 100644 index 000000000000..eabc2e9dd3a7 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/utils/model_overlay.py @@ -0,0 +1,650 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import glob +import 
hashlib +import importlib.util +import json +import os +import shutil +from typing import Any, Callable, cast + +from huggingface_hub.errors import ( + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import RequestException + +from sglang.multimodal_gen.runtime.loader.weight_utils import get_lock +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.utils import load_diffusion_overlay_registry_from_env + +logger = init_logger(__name__) + +# Built-in diffusion model overlay registry. +# Keep this empty until concrete overlay repos are ready to ship. +BUILTIN_MODEL_OVERLAY_REGISTRY: dict[str, dict[str, Any]] = {} + + +MODEL_OVERLAY_METADATA_PATTERNS = [ + "*.json", + "*.md", + "*.py", + "*.txt", + "**/*.json", + "**/*.md", + "**/*.py", + "**/*.txt", +] + +_MODEL_OVERLAY_REGISTRY_CACHE: dict[str, dict[str, Any]] | None = None + + +def get_diffusion_cache_root() -> str: + return os.path.expanduser( + os.getenv("SGLANG_DIFFUSION_CACHE_ROOT", "~/.cache/sgl_diffusion") + ) + + +def clear_model_overlay_registry_cache() -> None: + global _MODEL_OVERLAY_REGISTRY_CACHE + _MODEL_OVERLAY_REGISTRY_CACHE = None + + +def _load_model_overlay_registry() -> dict[str, dict[str, Any]]: + global _MODEL_OVERLAY_REGISTRY_CACHE + if _MODEL_OVERLAY_REGISTRY_CACHE is not None: + return _MODEL_OVERLAY_REGISTRY_CACHE + + # Built-in registry is the stable default path; env only overrides it. 
+ normalized = _normalize_model_overlay_registry(BUILTIN_MODEL_OVERLAY_REGISTRY) + + env_registry = load_diffusion_overlay_registry_from_env() + if not env_registry: + _MODEL_OVERLAY_REGISTRY_CACHE = normalized + return _MODEL_OVERLAY_REGISTRY_CACHE + + normalized.update(_normalize_model_overlay_registry(env_registry)) + _MODEL_OVERLAY_REGISTRY_CACHE = normalized + return _MODEL_OVERLAY_REGISTRY_CACHE + + +def _normalize_model_overlay_registry( + payload: dict[str, Any], +) -> dict[str, dict[str, Any]]: + normalized: dict[str, dict[str, Any]] = {} + for source_model_id, spec in payload.items(): + if isinstance(spec, str): + normalized[source_model_id] = {"overlay_repo_id": spec} + continue + if not isinstance(spec, dict): + raise ValueError( + "Overlay registry values must be either strings or JSON objects" + ) + overlay_repo_id = spec.get("overlay_repo_id") + if not overlay_repo_id: + raise ValueError( + f"Overlay registry entry for {source_model_id!r} is missing overlay_repo_id" + ) + normalized[source_model_id] = dict(spec) + return normalized + + +def resolve_model_overlay(model_name_or_path: str) -> dict[str, Any] | None: + registry = _load_model_overlay_registry() + return registry.get(model_name_or_path) + + +def resolve_model_overlay_target( + model_name_or_path: str, +) -> tuple[str, dict[str, Any]] | None: + registry = _load_model_overlay_registry() + + exact = registry.get(model_name_or_path) + if exact is not None: + return model_name_or_path, exact + + if os.path.exists(model_name_or_path): + # Local source dirs do not have a repo id, so match them by basename. 
+ base_name = os.path.basename(os.path.normpath(model_name_or_path)) + for source_model_id, spec in registry.items(): + if base_name == source_model_id.rsplit("/", 1)[-1]: + return source_model_id, spec + + return None + + +def load_overlay_manifest_if_present(overlay_dir: str) -> dict[str, Any] | None: + overlay_manifest_path = os.path.join( + overlay_dir, "_overlay", "overlay_manifest.json" + ) + if not os.path.exists(overlay_manifest_path): + return None + with open(overlay_manifest_path, encoding="utf-8") as f: + manifest = cast(dict[str, Any], json.load(f)) + return manifest + + +def load_model_index_from_dir(model_dir: str) -> dict[str, Any]: + model_index_path = os.path.join(model_dir, "model_index.json") + if not os.path.exists(model_index_path): + raise ValueError(f"model_index.json not found under {model_dir}") + with open(model_index_path, encoding="utf-8") as f: + config = cast(dict[str, Any], json.load(f)) + if "_class_name" not in config or "_diffusers_version" not in config: + raise ValueError(f"Invalid model_index.json under {model_dir}") + config["pipeline_name"] = config["_class_name"] + return config + + +def _ensure_dir(path: str) -> None: + os.makedirs(path, exist_ok=True) + + +def _find_missing_required_paths( + root_dir: str, required_paths: list[str] | tuple[str, ...] 
+) -> list[str]: + missing: list[str] = [] + for rel_path in required_paths: + if not os.path.exists(os.path.join(root_dir, rel_path)): + missing.append(rel_path) + return missing + + +def _link_or_copy_file(src: str, dst: str) -> None: + src = os.path.realpath(src) + _ensure_dir(os.path.dirname(dst)) + if os.path.lexists(dst): + os.remove(dst) + try: + os.link(src, dst) + return + except OSError: + pass + try: + os.symlink(src, dst) + return + except OSError: + pass + shutil.copy2(src, dst) + + +def _copytree_link_or_copy(src_dir: str, dst_dir: str) -> None: + for root, _, files in os.walk(src_dir): + rel_root = os.path.relpath(root, src_dir) + target_root = dst_dir if rel_root == "." else os.path.join(dst_dir, rel_root) + _ensure_dir(target_root) + for file_name in files: + src_file = os.path.join(root, file_name) + dst_file = os.path.join(target_root, file_name) + _link_or_copy_file(src_file, dst_file) + + +def ensure_overlay_source_dir_complete( + *, + source_model_id: str, + source_dir: str, + manifest: dict[str, Any], + local_dir: str | None, + allow_patterns: list[str] | None, + download: bool, + snapshot_download_fn: Callable[..., str], +) -> str: + required_source_files = cast( + list[str], list(manifest.get("required_source_files", [])) + ) + if not required_source_files: + return source_dir + + # Metadata-only overlays often need a partial source snapshot. Re-download + # only when the current source dir is missing required files. + missing_paths = _find_missing_required_paths(source_dir, required_source_files) + if not missing_paths: + return source_dir + + if not download: + raise ValueError( + f"Overlay source model {source_model_id} is missing required files " + f"{missing_paths} and download=False." + ) + + logger.warning( + "Overlay source model %s is missing required files %s. 
" + "Re-downloading source snapshot.", + source_model_id, + missing_paths, + ) + source_allow_patterns = manifest.get("source_allow_patterns") + effective_allow_patterns = ( + cast(list[str] | None, source_allow_patterns) + if source_allow_patterns is not None + else allow_patterns + ) + with get_lock(source_model_id).acquire(poll_interval=2): + source_dir = snapshot_download_fn( + repo_id=source_model_id, + ignore_patterns=["*.onnx", "*.msgpack"], + allow_patterns=effective_allow_patterns, + local_dir=local_dir, + max_workers=8, + force_download=True, + ) + missing_after_redownload = _find_missing_required_paths( + source_dir, required_source_files + ) + if missing_after_redownload: + raise ValueError( + f"Overlay source model {source_model_id} is still missing required files " + f"{missing_after_redownload} after re-download." + ) + return str(source_dir) + + +def resolve_direct_overlay_repo( + model_name_or_path: str, + *, + hf_hub_download_fn: Callable[..., str], +) -> tuple[dict[str, Any], str, dict[str, Any]] | None: + if os.path.exists(model_name_or_path): + manifest = load_overlay_manifest_if_present(model_name_or_path) + if manifest is None: + return None + source_model_id = manifest.get("source_model_id") + if not source_model_id: + raise ValueError( + f"Overlay repo {model_name_or_path} is missing source_model_id in _overlay/overlay_manifest.json" + ) + overlay_spec = { + "overlay_repo_id": model_name_or_path, + "overlay_revision": "local", + } + return overlay_spec, model_name_or_path, manifest + + try: + manifest_path = hf_hub_download_fn( + repo_id=model_name_or_path, + filename="_overlay/overlay_manifest.json", + ) + overlay_dir = os.path.dirname(os.path.dirname(manifest_path)) + except ( + RepositoryNotFoundError, + RevisionNotFoundError, + LocalEntryNotFoundError, + RequestsConnectionError, + RequestException, + ): + return None + except Exception: + return None + + manifest = load_overlay_manifest_if_present(overlay_dir) + if manifest is None: + 
return None + source_model_id = manifest.get("source_model_id") + if not source_model_id: + raise ValueError( + f"Overlay repo {model_name_or_path} is missing source_model_id in _overlay/overlay_manifest.json" + ) + overlay_spec = { + "overlay_repo_id": model_name_or_path, + "overlay_revision": "main", + } + return overlay_spec, overlay_dir, manifest + + +def download_overlay_metadata( + source_model_id: str, + overlay_spec: dict[str, Any], + *, + snapshot_download_fn: Callable[..., str], +) -> str: + overlay_repo_id = str(overlay_spec["overlay_repo_id"]) + if os.path.exists(overlay_repo_id): + logger.info( + "Using local overlay metadata for %s from %s", + source_model_id, + overlay_repo_id, + ) + return overlay_repo_id + revision = overlay_spec.get("overlay_revision") + logger.info( + "Downloading overlay metadata for %s from %s", + source_model_id, + overlay_repo_id, + ) + return str( + snapshot_download_fn( + repo_id=overlay_repo_id, + allow_patterns=MODEL_OVERLAY_METADATA_PATTERNS, + revision=revision, + max_workers=4, + ) + ) + + +def _apply_overlay_file_mappings( + *, + source_dir: str, + output_dir: str, + file_mappings: list[dict[str, Any]], +) -> None: + for mapping in file_mappings: + mapping_type = mapping.get("type", "file") + src_rel = mapping.get("src") + if not src_rel: + raise ValueError(f"Overlay file mapping is missing src: {mapping}") + src_path = os.path.join(source_dir, src_rel) + if mapping_type == "tree": + if not os.path.isdir(src_path): + raise ValueError(f"Tree mapping source does not exist: {src_path}") + dst_dir = os.path.join(output_dir, str(mapping.get("dst_dir", src_rel))) + _copytree_link_or_copy(src_path, dst_dir) + continue + if mapping_type == "glob": + matched = glob.glob(src_path, recursive=True) + if not matched: + raise ValueError(f"Glob mapping matched no files: {src_path}") + for matched_path in matched: + if os.path.isdir(matched_path): + continue + rel_path = os.path.relpath(matched_path, source_dir) + dst_path = 
os.path.join(output_dir, rel_path) + _link_or_copy_file(matched_path, dst_path) + continue + + if not os.path.isfile(src_path): + raise ValueError(f"File mapping source does not exist: {src_path}") + dst_rel = str(mapping.get("dst", os.path.basename(src_rel))) + dst_path = os.path.join(output_dir, dst_rel) + _link_or_copy_file(src_path, dst_path) + + +def _run_overlay_custom_materializer( + *, + overlay_dir: str, + source_dir: str, + output_dir: str, + manifest: dict[str, Any], +) -> None: + custom_materializer = manifest.get("custom_materializer") + if not custom_materializer: + return + script_path = os.path.join(overlay_dir, str(custom_materializer)) + if not os.path.exists(script_path): + raise ValueError(f"Custom materializer script not found: {script_path}") + + spec = importlib.util.spec_from_file_location( + "_sglang_overlay_materializer", script_path + ) + if spec is None or spec.loader is None: + raise ValueError(f"Failed to import custom materializer: {script_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + materialize_fn = getattr(module, "materialize", None) + if materialize_fn is None: + raise ValueError( + f"Custom materializer {script_path} must define materialize(...)" + ) + + materialize_fn( + overlay_dir=overlay_dir, + source_dir=source_dir, + output_dir=output_dir, + manifest=manifest, + ) + + +def materialize_overlay_model( + *, + source_model_id: str, + overlay_spec: dict[str, Any], + overlay_dir: str, + source_dir: str, + verify_diffusers_model_complete_fn: Callable[[str], bool], +) -> str: + overlay_manifest_path = os.path.join( + overlay_dir, "_overlay", "overlay_manifest.json" + ) + if not os.path.exists(overlay_manifest_path): + raise ValueError( + f"Overlay repo for {source_model_id} is missing _overlay/overlay_manifest.json" + ) + + with open(overlay_manifest_path, encoding="utf-8") as f: + manifest = cast(dict[str, Any], json.load(f)) + + materializer_version = 
str(manifest.get("materializer_version", "v1")) + overlay_repo_id = str(overlay_spec["overlay_repo_id"]) + overlay_revision = str(overlay_spec.get("overlay_revision", "main")) + cache_key = hashlib.sha256( + json.dumps( + { + "source_model_id": source_model_id, + "overlay_repo_id": overlay_repo_id, + "overlay_revision": overlay_revision, + "materializer_version": materializer_version, + }, + sort_keys=True, + ).encode("utf-8") + ).hexdigest()[:16] + cache_root = os.path.join(get_diffusion_cache_root(), "materialized_models") + _ensure_dir(cache_root) + safe_name = source_model_id.replace("/", "__") + final_dir = os.path.join(cache_root, f"{safe_name}-{cache_key}") + marker_path = os.path.join(final_dir, ".sglang_overlay_materialized.json") + if verify_diffusers_model_complete_fn(final_dir) and os.path.exists(marker_path): + return final_dir + + lock_name = ( + f"overlay-materialize::{source_model_id}::{overlay_repo_id}::{overlay_revision}" + ) + with get_lock(lock_name).acquire(poll_interval=2): + if verify_diffusers_model_complete_fn(final_dir) and os.path.exists( + marker_path + ): + return final_dir + + logger.info( + "Materializing overlay model for %s into %s", + source_model_id, + final_dir, + ) + logger.info( + "Overlay source repo: %s, overlay repo: %s@%s", + source_model_id, + overlay_repo_id, + overlay_revision, + ) + tmp_dir = final_dir + ".tmp" + if os.path.exists(tmp_dir): + shutil.rmtree(tmp_dir) + if os.path.exists(final_dir): + shutil.rmtree(final_dir) + logger.info("Copying overlay metadata into temporary materialized directory") + shutil.copytree( + overlay_dir, + tmp_dir, + ignore=shutil.ignore_patterns("*.safetensors", "*.bin", "*.pth", "*.pt"), + ) + + overlay_hidden_dir = os.path.join(tmp_dir, "_overlay") + if os.path.isdir(overlay_hidden_dir): + shutil.rmtree(overlay_hidden_dir) + + file_mappings = manifest.get("file_mappings", []) + if file_mappings: + logger.info("Applying %d overlay file mappings", len(file_mappings)) + 
_apply_overlay_file_mappings( + source_dir=source_dir, + output_dir=tmp_dir, + file_mappings=cast(list[dict[str, Any]], file_mappings), + ) + if manifest.get("custom_materializer"): + logger.info( + "Running custom overlay materializer: %s", + manifest["custom_materializer"], + ) + _run_overlay_custom_materializer( + overlay_dir=overlay_dir, + source_dir=source_dir, + output_dir=tmp_dir, + manifest=manifest, + ) + + with open(marker_path.replace(final_dir, tmp_dir), "w", encoding="utf-8") as f: + json.dump( + { + "source_model_id": source_model_id, + "source_dir": source_dir, + "overlay_repo_id": overlay_repo_id, + "overlay_revision": overlay_revision, + "materializer_version": materializer_version, + }, + f, + indent=2, + sort_keys=True, + ) + + os.replace(tmp_dir, final_dir) + logger.info("Overlay materialization finished: %s", final_dir) + + return final_dir + + +def maybe_load_overlay_model_index( + model_name_or_path: str, + *, + snapshot_download_fn: Callable[..., str], + hf_hub_download_fn: Callable[..., str], +) -> dict[str, Any] | None: + if os.path.exists(model_name_or_path): + # A local overlay repo already contains the model_index we need. + if load_overlay_manifest_if_present(model_name_or_path) is not None: + return load_model_index_from_dir(model_name_or_path) + return None + + overlay_target = resolve_model_overlay_target(model_name_or_path) + if overlay_target is not None: + # Registry-mapped source model ids first resolve to overlay metadata. 
+ source_model_id, overlay_spec = overlay_target + overlay_dir = download_overlay_metadata( + source_model_id, + overlay_spec, + snapshot_download_fn=snapshot_download_fn, + ) + return load_model_index_from_dir(overlay_dir) + + direct_overlay = resolve_direct_overlay_repo( + model_name_or_path, hf_hub_download_fn=hf_hub_download_fn + ) + if direct_overlay is None: + return None + + _, overlay_dir, _ = direct_overlay + return load_model_index_from_dir(overlay_dir) + + +def maybe_resolve_overlay_model_path( + model_name_or_path: str, + *, + local_dir: str | None, + download: bool, + allow_patterns: list[str] | None, + snapshot_download_fn: Callable[..., str], + hf_hub_download_fn: Callable[..., str], + verify_diffusers_model_complete_fn: Callable[[str], bool], + base_model_download_fn: Callable[..., str], +) -> str | None: + overlay_target = resolve_model_overlay_target(model_name_or_path) + if overlay_target is not None: + source_model_id, overlay_spec = overlay_target + overlay_dir = download_overlay_metadata( + source_model_id, + overlay_spec, + snapshot_download_fn=snapshot_download_fn, + ) + manifest = load_overlay_manifest_if_present(overlay_dir) + if manifest is None: + # Full diffusers overlays do not need materialization. + return base_model_download_fn( + str(overlay_spec["overlay_repo_id"]), + local_dir=local_dir, + download=download, + allow_patterns=allow_patterns, + force_diffusers_model=True, + skip_overlay_resolution=True, + ) + source_allow_patterns = cast( + list[str] | None, manifest.get("source_allow_patterns") + ) + # For local source paths, reuse the directory directly instead of + # round-tripping through snapshot_download. 
+ source_dir = ( + model_name_or_path + if os.path.exists(model_name_or_path) + else base_model_download_fn( + source_model_id, + local_dir=local_dir, + download=download, + allow_patterns=source_allow_patterns or allow_patterns, + force_diffusers_model=False, + skip_overlay_resolution=True, + ) + ) + source_dir = ensure_overlay_source_dir_complete( + source_model_id=source_model_id, + source_dir=source_dir, + manifest=manifest, + local_dir=local_dir, + allow_patterns=allow_patterns, + download=download, + snapshot_download_fn=snapshot_download_fn, + ) + return materialize_overlay_model( + source_model_id=source_model_id, + overlay_spec=overlay_spec, + overlay_dir=overlay_dir, + source_dir=source_dir, + verify_diffusers_model_complete_fn=verify_diffusers_model_complete_fn, + ) + + direct_overlay = resolve_direct_overlay_repo( + model_name_or_path, hf_hub_download_fn=hf_hub_download_fn + ) + if direct_overlay is None: + return None + + overlay_spec, overlay_dir, manifest = direct_overlay + source_model_id = str(manifest["source_model_id"]) + # Direct overlay repos are always metadata-only; they need the original + # source weights before they can be materialized into a diffusers-like dir. 
+ source_allow_patterns = cast( + list[str] | None, manifest.get("source_allow_patterns") + ) + source_dir = base_model_download_fn( + source_model_id, + local_dir=local_dir, + download=download, + allow_patterns=source_allow_patterns or allow_patterns, + force_diffusers_model=False, + skip_overlay_resolution=True, + ) + source_dir = ensure_overlay_source_dir_complete( + source_model_id=source_model_id, + source_dir=source_dir, + manifest=manifest, + local_dir=local_dir, + allow_patterns=allow_patterns, + download=download, + snapshot_download_fn=snapshot_download_fn, + ) + return materialize_overlay_model( + source_model_id=source_model_id, + overlay_spec=overlay_spec, + overlay_dir=overlay_dir, + source_dir=source_dir, + verify_diffusers_model_complete_fn=verify_diffusers_model_complete_fn, + ) diff --git a/python/sglang/multimodal_gen/runtime/utils/perf_logger.py b/python/sglang/multimodal_gen/runtime/utils/perf_logger.py index d44a50d37f19..9b1eb68cf4d3 100644 --- a/python/sglang/multimodal_gen/runtime/utils/perf_logger.py +++ b/python/sglang/multimodal_gen/runtime/utils/perf_logger.py @@ -64,9 +64,8 @@ def record_stage(self, stage_name: str, duration_s: float): """Records the duration of a pipeline stage""" self.stages[stage_name] = duration_s * 1000 # Store as milliseconds - def record_steps(self, index: int, duration_s: float): - """Records the duration of a denoising step""" - assert index == len(self.steps) + def record_step(self, duration_s: float): + """Records the duration of a denoising step in execution order.""" self.steps.append(duration_s * 1000) def record_memory_snapshot(self, checkpoint_name: str, snapshot: MemorySnapshot): @@ -192,6 +191,7 @@ def __init__( log_stage_start_end: bool = False, perf_dump_path_provided: bool = False, capture_memory: bool = False, + record_as_step: bool = False, ): self.stage_name = stage_name self.metrics = metrics @@ -200,6 +200,10 @@ def __init__( self.log_timing = perf_dump_path_provided or 
envs.SGLANG_DIFFUSION_STAGE_LOGGING self.log_stage_start_end = log_stage_start_end self.capture_memory = capture_memory + self.record_as_step = record_as_step + + def _should_record_as_step(self) -> bool: + return self.record_as_step or self.stage_name.startswith("denoising_step_") def __enter__(self): if self.log_stage_start_end: @@ -211,7 +215,7 @@ def __enter__(self): if (self.log_timing and self.metrics) or self.log_stage_start_end: if ( os.environ.get("SGLANG_DIFFUSION_SYNC_STAGE_PROFILING", "0") == "1" - and self.stage_name.startswith("denoising_step_") + and self._should_record_as_step() and torch.get_device_module().is_available() ): torch.get_device_module().synchronize() @@ -225,7 +229,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): if ( os.environ.get("SGLANG_DIFFUSION_SYNC_STAGE_PROFILING", "0") == "1" - and self.stage_name.startswith("denoising_step_") + and self._should_record_as_step() and torch.get_device_module().is_available() ): torch.get_device_module().synchronize() @@ -247,9 +251,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) if self.log_timing and self.metrics: - if "denoising_step_" in self.stage_name: - index = int(self.stage_name[len("denoising_step_") :]) - self.metrics.record_steps(index, execution_time_s) + if self._should_record_as_step(): + self.metrics.record_step(execution_time_s) else: self.metrics.record_stage(self.stage_name, execution_time_s) diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py index f7182b9bb7f9..4178a779754e 100644 --- a/python/sglang/multimodal_gen/test/run_suite.py +++ b/python/sglang/multimodal_gen/test/run_suite.py @@ -35,6 +35,8 @@ "../unit/test_storage.py", "../unit/test_lora_format_adapter.py", "../unit/test_server_args.py", + "../unit/test_input_validation.py", + "../unit/test_resolve_prompts.py", # add new unit tests here ], "1-gpu": [ @@ -50,6 +52,9 @@ "test_server_2_gpu_b.py", # add new 2-gpu test files here ], + "1-gpu-b200": [ + 
"test_server_c.py", + ], } suites_ascend = { @@ -78,7 +83,7 @@ def parse_args(): type=str, required=True, choices=list(SUITES.keys()), - help="The test suite to run (e.g., 1-gpu, 2-gpu)", + help="The test suite to run (valid names are defined in SUITES)", ) parser.add_argument( "--partition-id", @@ -233,7 +238,9 @@ def run_pytest(files, filter_expr=None): ) is_flaky_ci_assertion = ( - "SafetensorError" in full_output or "FileNotFoundError" in full_output + "SafetensorError" in full_output + or "FileNotFoundError" in full_output + or "TimeoutError" in full_output ) is_oom_error = ( diff --git a/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py b/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py index bcc82fab9ce6..645a9cac5486 100755 --- a/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py +++ b/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py @@ -8,6 +8,7 @@ Usage: python gen_diffusion_ci_outputs.py --suite 1-gpu --partition-id 0 --total-partitions 2 --out-dir ./output python gen_diffusion_ci_outputs.py --suite 1-gpu --case-ids qwen_image_t2i flux_image_t2i --out-dir ./output + python gen_diffusion_ci_outputs.py --suite 1-gpu-b200 --out-dir ./output """ import argparse @@ -27,9 +28,9 @@ def main(): parser.add_argument( "--suite", type=str, - choices=["1-gpu", "2-gpu"], + choices=list(SUITES.keys()), required=True, - help="Test suite to run (1-gpu or 2-gpu)", + help="Test suite to run (choices: " + ", ".join(list(SUITES.keys())) + ")", ) parser.add_argument( "--partition-id", diff --git a/python/sglang/multimodal_gen/test/server/accuracy_config.py b/python/sglang/multimodal_gen/test/server/accuracy_config.py new file mode 100644 index 000000000000..d125e289d28a --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/accuracy_config.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Optional + 
+from sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase + + +class ComponentType(str, Enum): + VAE = "vae" + TRANSFORMER = "transformer" + TEXT_ENCODER = "text_encoder" + + +@dataclass(frozen=True) +class ComponentSkip: + reason: str + + +DEFAULT_TIMESTEP = 500.0 +TIMESTEP_NORMALIZATION_FACTOR = 1000.0 +I2V_IMAGE_DIM = 1280 +I2V_TEXT_ENCODER_DIM = 5120 + +DEFAULT_TEXT_ENCODER_VOCAB_SIZE = 32000 +TEXT_ENCODER_INPUT_SEED = 42 +TEXT_ENCODER_TOKEN_MIN = 100 +TEXT_ENCODER_TOKEN_MAX = 30000 +TEXT_ENCODER_TOKEN_LENGTH = 32 + +# Default thresholds by component. Override per component/case if needed. +DEFAULT_THRESHOLDS = { + ComponentType.VAE: 0.999, + ComponentType.TRANSFORMER: 0.995, + ComponentType.TEXT_ENCODER: 0.98, +} + +# Optional per-case overrides: {case_id: {ComponentType: threshold}} +CASE_THRESHOLDS: Dict[str, Dict[ComponentType, float]] = { + # Add overrides here when a specific model/component needs a different threshold. + "flux_2_image_t2i": {ComponentType.TRANSFORMER: 0.99}, + "flux_2_image_t2i_layerwise_offload": {ComponentType.TRANSFORMER: 0.99}, + "flux_2_image_t2i_2_gpus": {ComponentType.TRANSFORMER: 0.99}, + "flux_2_klein_ti2i_2_gpus": {ComponentType.TRANSFORMER: 0.975}, + "flux_2_ti2i": {ComponentType.TRANSFORMER: 0.99}, + "flux_2_t2i_customized_vae_path": {ComponentType.TRANSFORMER: 0.99}, + "fast_hunyuan_video": {ComponentType.TRANSFORMER: 0.99}, + "fsdp-inference": {ComponentType.TRANSFORMER: 0.9935}, + "wan2_2_i2v_a14b_2gpu": {ComponentType.TRANSFORMER: 0.99}, + "wan2_2_t2v_a14b_2gpu": {ComponentType.TRANSFORMER: 0.99}, + "wan2_2_t2v_a14b_teacache_2gpu": {ComponentType.TRANSFORMER: 0.99}, + "wan2_2_t2v_a14b_lora_2gpu": {ComponentType.TRANSFORMER: 0.99}, + "zimage_image_t2i_2_gpus": {ComponentType.TRANSFORMER: 0.9935}, + "zimage_image_t2i_2_gpus_non_square": {ComponentType.TRANSFORMER: 0.9935}, +} + +# Active skip policy. 
Keep this limited to cases with current, concrete evidence +# of real divergence or unsupported reference loading in the harness. +SKIP_COMPONENTS: Dict[str, Dict[ComponentType, ComponentSkip]] = { + "flux_image_t2i": { + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline despite 100% matched weights (CosSim ~0.47)" + ) + }, + "sana_image_t2i": { + ComponentType.VAE: ComponentSkip( + "HF AutoencoderDC checkpoint leaves required to_qkv_multiscale weights missing, so VAE transfer would compare against partially initialized reference weights" + ) + }, + "mova_360p_1gpu": { + ComponentType.TRANSFORMER: ComponentSkip( + "HF reference transformer cannot be materialized from the video_dit repo layout" + ) + }, + "qwen_image_t2i_cache_dit_enabled": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by qwen_image_t2i for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by qwen_image_t2i for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by qwen_image_t2i for the same source component and topology" + ), + }, + "flux_2_image_t2i_upscaling_4x": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + }, + "layerwise_offload": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by zimage_image_t2i for the same source component and 
topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by zimage_image_t2i for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by zimage_image_t2i for the same source component and topology" + ), + }, + "zimage_image_t2i_fp8": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by zimage_image_t2i for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by zimage_image_t2i for the same source component and topology" + ), + }, + "zimage_image_t2i_multi_lora": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by zimage_image_t2i for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by zimage_image_t2i for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by zimage_image_t2i for the same source component and topology" + ), + }, + "flux_2_ti2i": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + }, + "flux_2_t2i_customized_vae_path": { + ComponentType.VAE: ComponentSkip( + "Customized VAE override points to FLUX.2 Tiny AutoEncoder, but the HF reference loader does not yet materialize a trustworthy matching 
VAE baseline" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + }, + "wan2_1_t2v_1.3b_text_encoder_cpu_offload": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + }, + "wan2_1_t2v_1.3b_teacache_enabled": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + }, + "wan2_1_t2v_1.3b_frame_interp_2x": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + }, + 
"wan2_1_t2v_1.3b_upscaling_4x": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + }, + "wan2_1_t2v_1.3b_frame_interp_2x_upscaling_4x": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + }, + "wan2_1_t2v_1_3b_lora_1gpu": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by wan2_1_t2v_1.3b for the same source component and topology" + ), + }, + "flux_2_ti2i_multi_image_cache_dit": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + 
ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by flux_2_image_t2i for the same source component and topology" + ), + }, + "wan2_2_ti2v_5b": { + ComponentType.TRANSFORMER: ComponentSkip( + "SGLang transformer loader rejects new parameters in HF checkpoint" + ) + }, + "fastwan2_2_ti2v_5b": { + ComponentType.TRANSFORMER: ComponentSkip( + "SGLang transformer loader rejects new parameters in HF checkpoint" + ) + }, + "turbo_wan2_1_t2v_1.3b": { + ComponentType.TRANSFORMER: ComponentSkip( + "Weight transfer match ratio too low for reliable comparison" + ) + }, + "wan2_1_i2v_14b_480P_2gpu": { + ComponentType.TRANSFORMER: ComponentSkip( + "Transformer diverges from Diffusers baseline in 2-GPU accuracy run (CosSim ~0.71) after full weight transfer and matching output shape" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU SP-folded accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ), + }, + "wan2_1_i2v_14b_lora_2gpu": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_1_i2v_14b_720P_2gpu for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Transformer diverges from Diffusers baseline in 2-GPU accuracy run (CosSim ~0.68) after full weight transfer and matching output shape" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU SP-folded accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ), + }, + "wan2_1_i2v_14b_720P_2gpu": { + ComponentType.TRANSFORMER: ComponentSkip( + "Transformer diverges from Diffusers baseline in 2-GPU accuracy run (CosSim ~0.68) after full weight transfer and matching output shape" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU SP-folded accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ), + }, + 
"wan2_2_i2v_a14b_2gpu": { + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU SP-folded accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ) + }, + "wan2_2_t2v_a14b_2gpu": { + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU SP-folded accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ) + }, + "wan2_2_t2v_a14b_teacache_2gpu": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_2_t2v_a14b_2gpu for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by wan2_2_t2v_a14b_2gpu for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU SP-folded accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ), + }, + "wan2_2_t2v_a14b_lora_2gpu": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by wan2_2_t2v_a14b_2gpu for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by wan2_2_t2v_a14b_2gpu for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU SP-folded accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ), + }, + "wan2_1_t2v_14b_2gpu": { + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU SP-folded accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ) + }, + "wan2_1_t2v_1.3b_cfg_parallel": { + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ) + }, + "mova_360p_tp2": { + ComponentType.TRANSFORMER: ComponentSkip( + "HF 
reference transformer cannot be materialized from the MOVA video_dit repo layout" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ), + }, + "mova_360p_ring1_uly2": { + ComponentType.TRANSFORMER: ComponentSkip( + "HF reference transformer cannot be materialized from the MOVA video_dit repo layout" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ), + }, + "mova_360p_ring2_uly1": { + ComponentType.TRANSFORMER: ComponentSkip( + "HF reference transformer cannot be materialized from the MOVA video_dit repo layout" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU accuracy run (CosSim ~0.31) after 100% matched weight transfer" + ), + }, + "flux_image_t2i_2_gpus": { + ComponentType.TEXT_ENCODER: ComponentSkip( + "Text encoder diverges from HF baseline in 2-GPU accuracy run (CosSim ~0.47) after 100% matched weight transfer" + ) + }, + "zimage_image_t2i_2_gpus_non_square": { + ComponentType.VAE: ComponentSkip( + "Representative VAE accuracy is already covered by zimage_image_t2i_2_gpus for the same source component and topology" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "Representative transformer accuracy is already covered by zimage_image_t2i_2_gpus for the same source component and topology" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "Representative text encoder accuracy is already covered by zimage_image_t2i_2_gpus for the same source component and topology" + ), + }, + "flux_2_image_t2i_2_gpus": { + ComponentType.TRANSFORMER: ComponentSkip( + "2-GPU FLUX.2 transformer diverges strongly from Diffusers baseline (CosSim ~0.54) despite full weight transfer" + ) + }, + "hunyuan3d_shape_gen": { + ComponentType.VAE: ComponentSkip( + "HF config cannot be parsed as valid JSON 
for component reference loading" + ), + ComponentType.TRANSFORMER: ComponentSkip( + "HF config cannot be parsed as valid JSON for component reference loading" + ), + ComponentType.TEXT_ENCODER: ComponentSkip( + "HF config cannot be parsed as valid JSON for component reference loading" + ), + }, +} + +# TODO: If a model needs extra compatibility logic, prefer adding a skip or an +# explicit override here instead of adding more ad-hoc hacks in the engine. + + +def get_threshold(case_id: str, component: ComponentType) -> float: + overrides = CASE_THRESHOLDS.get(case_id, {}) + return overrides.get(component, DEFAULT_THRESHOLDS[component]) + + +def get_skip_reason(case: DiffusionTestCase, component: ComponentType) -> Optional[str]: + skip_entry = SKIP_COMPONENTS.get(case.id, {}).get(component) + if skip_entry is None: + return None + return skip_entry.reason + + +def should_skip_component(case: DiffusionTestCase, component: ComponentType) -> bool: + return get_skip_reason(case, component) is not None diff --git a/python/sglang/multimodal_gen/test/server/accuracy_hooks.py b/python/sglang/multimodal_gen/test/server/accuracy_hooks.py new file mode 100644 index 000000000000..0ba46437d929 --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/accuracy_hooks.py @@ -0,0 +1,579 @@ +from __future__ import annotations + +import inspect +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional + +import torch +import torch.nn as nn + +from sglang.multimodal_gen.test.server.accuracy_config import ( + DEFAULT_TIMESTEP, + I2V_IMAGE_DIM, + TIMESTEP_NORMALIZATION_FACTOR, + ComponentType, +) +from sglang.multimodal_gen.test.server.accuracy_utils import ( + extract_output_tensor, + seed_and_broadcast, +) + +Inputs = Dict[str, Any] +BuildInputsFn = Callable[[Any, nn.Module, str, Optional[nn.Module]], Inputs] +PrepareCallFn = Callable[[nn.Module, Inputs], "HookCall"] +NormalizeFn = Callable[[Any], torch.Tensor] + +# These are harness defaults for 
synthetic accuracy inputs. +# They are not checkpoint truth. We use them only when the model config or +# forward signature does not expose a more specific shape or channel count. +DEFAULT_TEXT_SEQ_LEN = 64 +DEFAULT_TOKEN_LAYOUT_SIZE = 32 +REDUCED_TOKEN_LAYOUT_SIZE = 16 +DEFAULT_VIDEO_FRAME_COUNT = 4 +DEFAULT_IMAGE_TOKEN_COUNT = 257 +ALIAS_ROTARY_TEXT_PAD_MULTIPLE = 32 +DEFAULT_TRANSFORMER_IN_CHANNELS = 16 +DEFAULT_TRANSFORMER_TEXT_CHANNELS = 4096 +DEFAULT_TRANSFORMER_POOLED_CHANNELS = 768 +DEFAULT_VAE_LATENT_CHANNELS = 16 +DEFAULT_VAE_LATENT_SPATIAL_SIZE = 32 +LARGE_CHANNEL_LAYOUT_THRESHOLD = 128 + + +@dataclass(frozen=True) +class TransformerHookCompat: + normalize_reference_timestep: bool = False + negate_reference_output: bool = False + omit_reference_guidance: bool = False + use_2d_hidden_states: bool = False + + +def _resolve_transformer_hook_compat(case: Any) -> TransformerHookCompat: + model_path = case.server_args.model_path.lower() + if "z-image" in model_path: + return TransformerHookCompat( + normalize_reference_timestep=True, + negate_reference_output=True, + ) + if "qwen" in model_path: + return TransformerHookCompat( + normalize_reference_timestep=True, + omit_reference_guidance=True, + ) + if "sana" in model_path: + return TransformerHookCompat( + omit_reference_guidance=True, + use_2d_hidden_states=True, + ) + if "flux" in model_path: + return TransformerHookCompat(normalize_reference_timestep=True) + return TransformerHookCompat() + + +@dataclass +class HookCall: + module: nn.Module + args: tuple[Any, ...] 
= () + kwargs: Dict[str, Any] = field(default_factory=dict) + negate_output: bool = False + + +@dataclass(frozen=True) +class NativeHookProfile: + build_inputs: BuildInputsFn + prepare_sglang_call: PrepareCallFn + prepare_reference_call: PrepareCallFn + normalize_sglang_output: NormalizeFn = extract_output_tensor + normalize_reference_output: NormalizeFn = extract_output_tensor + + +class _DeterministicRNG: + def __init__(self, seed: int = 42) -> None: + self._seed = seed + + def randn( + self, shape: tuple[int, ...], device: str, dtype: torch.dtype + ) -> torch.Tensor: + torch.manual_seed(self._seed) + tensor = torch.randn(shape, device="cpu", dtype=dtype).to(device) + seed_and_broadcast(self._seed, tensor) + self._seed += 1 + return tensor + + +def _resolve_nested_attr(obj: Any, path: str) -> Any: + current = obj + for name in path.split("."): + if current is None or not hasattr(current, name): + return None + current = getattr(current, name) + return current + + +def _read_config_value(model: nn.Module, keys: list[str], default: int) -> int: + config = getattr(model, "config", None) + for key in keys: + for root in (model, config): + value = _resolve_nested_attr(root, key) if root is not None else None + if isinstance(value, int) and value > 0: + return value + return default + + +def _forward_parameter_names(module: nn.Module) -> set[str]: + return set(inspect.signature(module.forward).parameters.keys()) + + +def _infer_transformer_layout(param_names: set[str]) -> str: + if "img_shapes" in param_names or "txt_seq_lens" in param_names: + return "token_shapes" + if "img_ids" in param_names or "txt_ids" in param_names: + return "token_ids" + if "x" in param_names or "cap_feats" in param_names: + return "alias" + return "video" + + +def _build_position_ids( + height: int, width: int, dims: int, device: str +) -> tuple[torch.Tensor, torch.Tensor]: + img_len = height * width + txt_len = DEFAULT_TEXT_SEQ_LEN + if dims == 4: + img_ids = torch.zeros(img_len, 4, 
device=device, dtype=torch.bfloat16) + img_ids[:, 1] = torch.arange(height).repeat_interleave(width) + img_ids[:, 2] = torch.arange(width).repeat(height) + txt_ids = torch.zeros(txt_len, 4, device=device, dtype=torch.bfloat16) + else: + img_ids = torch.zeros(img_len, 3, device=device, dtype=torch.bfloat16) + img_ids[:, 0] = torch.arange(height).repeat_interleave(width) + img_ids[:, 1] = torch.arange(width).repeat(height) + txt_ids = torch.zeros(txt_len, 3, device=device, dtype=torch.bfloat16) + return img_ids, txt_ids + + +def _build_alias_rotary_freqs( + model: nn.Module, device: str, height: int, width: int +) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + cap_len = DEFAULT_TEXT_SEQ_LEN + cap_pad_len = (-cap_len) % ALIAS_ROTARY_TEXT_PAD_MULTIPLE + cap_ids = ( + torch.stack( + torch.meshgrid( + torch.arange(cap_len + cap_pad_len), + torch.arange(1), + torch.arange(1), + indexing="ij", + ), + dim=-1, + ) + .flatten(0, 2) + .to(device) + ) + img_ids = ( + torch.stack( + torch.meshgrid( + torch.arange(1), + torch.arange(height // 2), + torch.arange(width // 2), + indexing="ij", + ), + dim=-1, + ) + .flatten(0, 2) + .to(device) + ) + cos_cap, sin_cap = model.rotary_emb(cap_ids) + cos_img, sin_img = model.rotary_emb(img_ids) + return ((cos_cap, sin_cap), (cos_img, sin_img)) + + +def _supports_image_conditioning(module: nn.Module) -> bool: + image_embedder = _resolve_nested_attr(module, "condition_embedder.image_embedder") + if image_embedder is not None: + return True + image_dim = _read_config_value( + module, ["arch_config.image_dim", "image_dim"], default=0 + ) + return image_dim > 0 + + +def _build_transformer_hook_inputs( + case: Any, model: nn.Module, device: str, ref_model: Optional[nn.Module] = None +) -> Inputs: + """Build one synthetic input bundle that both transformer variants can consume.""" + compat = _resolve_transformer_hook_compat(case) + param_names = _forward_parameter_names(model) + if ref_model is not None: + # 
The input bundle has to satisfy both call signatures. + param_names.update(_forward_parameter_names(ref_model)) + + rng = _DeterministicRNG() + layout = _infer_transformer_layout(param_names) + in_channels = _read_config_value( + model, + [ + "arch_config.in_channels", + "in_channels", + "transformer_config.in_channels", + ], + default=DEFAULT_TRANSFORMER_IN_CHANNELS, + ) + text_channels = _read_config_value( + model, + [ + "text_states_dim", + "arch_config.cap_feat_dim", + "cap_feat_dim", + "caption_channels", + "arch_config.text_dim", + "text_dim", + "arch_config.text_embed_dim", + "text_embed_dim", + "arch_config.joint_attention_dim", + "joint_attention_dim", + "cross_attention_dim", + "hidden_size", + "dim", + ], + default=DEFAULT_TRANSFORMER_TEXT_CHANNELS, + ) + pooled_channels = _read_config_value( + model, + [ + "text_states_dim_2", + "arch_config.pooled_projection_dim", + "pooled_projection_dim", + "pooled_embed_dim", + "text_embed_dim", + "projection_dim", + ], + default=DEFAULT_TRANSFORMER_POOLED_CHANNELS, + ) + image_channels = _read_config_value( + model, + ["arch_config.image_dim", "image_dim", "cross_attention_dim"], + default=I2V_IMAGE_DIM, + ) + + if layout == "token_shapes": + height, width = DEFAULT_TOKEN_LAYOUT_SIZE, DEFAULT_TOKEN_LAYOUT_SIZE + seq_len = (height // 2) * (width // 2) + hidden_states = rng.randn((1, seq_len, in_channels), device, torch.bfloat16) + elif layout == "token_ids": + height, width = REDUCED_TOKEN_LAYOUT_SIZE, REDUCED_TOKEN_LAYOUT_SIZE + seq_len = height * width + hidden_states = rng.randn((1, seq_len, in_channels), device, torch.bfloat16) + elif layout == "alias": + height, width = DEFAULT_TOKEN_LAYOUT_SIZE, DEFAULT_TOKEN_LAYOUT_SIZE + hidden_states = rng.randn( + (1, in_channels, 1, height, width), device, torch.bfloat16 + ) + elif compat.use_2d_hidden_states: + spatial_size = ( + REDUCED_TOKEN_LAYOUT_SIZE + if "encoder_attention_mask" in param_names + or "encoder_hidden_states_mask" in param_names + else 
DEFAULT_TOKEN_LAYOUT_SIZE + ) + height, width = spatial_size, spatial_size + hidden_states = rng.randn( + (1, in_channels, height, width), + device, + torch.bfloat16, + ) + else: + spatial_size = ( + REDUCED_TOKEN_LAYOUT_SIZE + if "encoder_attention_mask" in param_names + or "encoder_hidden_states_mask" in param_names + else DEFAULT_TOKEN_LAYOUT_SIZE + ) + height, width = spatial_size, spatial_size + hidden_states = rng.randn( + (1, in_channels, DEFAULT_VIDEO_FRAME_COUNT, height, width), + device, + torch.bfloat16, + ) + + inputs: Inputs = { + "hidden_states": hidden_states, + "encoder_hidden_states": rng.randn( + (1, DEFAULT_TEXT_SEQ_LEN, text_channels), device, torch.bfloat16 + ), + "timestep": torch.tensor( + [DEFAULT_TIMESTEP], device=device, dtype=torch.bfloat16 + ), + "guidance": torch.tensor([1.0], device=device, dtype=torch.bfloat16), + } + + if "pooled_projections" in param_names: + inputs["pooled_projections"] = rng.randn( + (1, pooled_channels), device, torch.bfloat16 + ) + if ( + "encoder_attention_mask" in param_names + or "encoder_hidden_states_mask" in param_names + ): + attention_mask = torch.ones( + 1, DEFAULT_TEXT_SEQ_LEN, device=device, dtype=torch.bool + ) + inputs["encoder_attention_mask"] = attention_mask + inputs["encoder_hidden_states_mask"] = attention_mask + if "encoder_hidden_states_image" in param_names and _supports_image_conditioning( + model + ): + inputs["encoder_hidden_states_image"] = rng.randn( + (1, DEFAULT_IMAGE_TOKEN_COUNT, image_channels), device, torch.bfloat16 + ) + if "additional_t_cond" in param_names: + inputs["additional_t_cond"] = torch.zeros((1,), device=device, dtype=torch.long) + if "img_shapes" in param_names: + inputs["img_shapes"] = [[(1, height // 2, width // 2)]] + if "txt_seq_lens" in param_names: + inputs["txt_seq_lens"] = [DEFAULT_TEXT_SEQ_LEN] + if "img_ids" in param_names or "txt_ids" in param_names: + id_dims = 4 if in_channels >= LARGE_CHANNEL_LAYOUT_THRESHOLD else 3 + img_ids, txt_ids = 
_build_position_ids(height, width, id_dims, device) + inputs["img_ids"] = img_ids + inputs["txt_ids"] = txt_ids + + if "freqs_cis" in param_names and hasattr(model, "rotary_emb"): + if "img_shapes" in inputs and "txt_seq_lens" in inputs: + img_freqs, txt_freqs = model.rotary_emb( + inputs["img_shapes"], + inputs["txt_seq_lens"], + device=hidden_states.device, + ) + if torch.is_complex(img_freqs) and torch.is_complex(txt_freqs): + inputs["freqs_cis"] = ( + torch.cat([img_freqs.real.float(), img_freqs.imag.float()], dim=-1), + torch.cat([txt_freqs.real.float(), txt_freqs.imag.float()], dim=-1), + ) + else: + inputs["freqs_cis"] = (img_freqs, txt_freqs) + elif "img_ids" in inputs and "txt_ids" in inputs: + ids = torch.cat([inputs["txt_ids"], inputs["img_ids"]], dim=0) + inputs["freqs_cis"] = model.rotary_emb(ids) + elif inputs["hidden_states"].ndim == 5: + inputs["freqs_cis"] = _build_alias_rotary_freqs( + model, device, height, width + ) + + inputs["hook_compat"] = compat + return inputs + + +def _get_transformer_hook_compat(inputs: Inputs) -> TransformerHookCompat: + compat = inputs.get("hook_compat") + assert isinstance(compat, TransformerHookCompat) + return compat + + +def _supports_guidance_embedding(module: nn.Module) -> bool: + time_text_embed = getattr(module, "time_text_embed", None) + if time_text_embed is None: + return True + + parameters = list(inspect.signature(time_text_embed.forward).parameters.values()) + + if any(param.kind is inspect.Parameter.VAR_POSITIONAL for param in parameters): + return True + + accepted_args = [ + param + for param in parameters + if param.name != "self" + and param.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ] + return len(accepted_args) >= 3 + + +def _prepare_transformer_hook_call( + module: nn.Module, inputs: Inputs, side: str +) -> HookCall: + param_names = _forward_parameter_names(module) + signature = 
inspect.signature(module.forward) + compat = _get_transformer_hook_compat(inputs) + kwargs: Dict[str, Any] = {} + negate_output = side == "reference" and compat.negate_reference_output + + if "hidden_states" in param_names: + kwargs["hidden_states"] = inputs["hidden_states"] + if "x" in param_names: + kwargs["x"] = [inputs["hidden_states"].squeeze(0)] + if "encoder_hidden_states" in param_names: + encoder_value: Any = inputs["encoder_hidden_states"] + if ( + side == "sglang" + and "pooled_projections" in inputs + and "pooled_projections" not in param_names + and "encoder_attention_mask" not in param_names + ): + encoder_value = [ + inputs["encoder_hidden_states"], + inputs["pooled_projections"], + ] + kwargs["encoder_hidden_states"] = encoder_value + if "cap_feats" in param_names: + kwargs["cap_feats"] = [inputs["encoder_hidden_states"].squeeze(0)] + + if "timestep" in param_names: + timestep = inputs["timestep"] + if side == "reference" and compat.normalize_reference_timestep: + timestep = timestep / TIMESTEP_NORMALIZATION_FACTOR + kwargs["timestep"] = timestep + if "t" in param_names: + timestep = inputs["timestep"] + if side == "reference" and compat.normalize_reference_timestep: + timestep = timestep / TIMESTEP_NORMALIZATION_FACTOR + kwargs["t"] = timestep + + if "guidance" in param_names and "guidance" in inputs: + if side == "reference" and compat.omit_reference_guidance: + pass + else: + skip_guidance_for_image_context = ( + "encoder_hidden_states_image" in param_names + and "img_ids" not in param_names + and "img_shapes" not in param_names + ) + supports_guidance_embedding = _supports_guidance_embedding(module) + requires_guidance_arg = ( + signature.parameters["guidance"].default is inspect._empty + ) + should_include_guidance = ( + not skip_guidance_for_image_context and supports_guidance_embedding + ) + if should_include_guidance or requires_guidance_arg: + guidance_value = inputs["guidance"] + if side == "sglang": + guidance_value = guidance_value * 
TIMESTEP_NORMALIZATION_FACTOR + kwargs["guidance"] = guidance_value + + if ( + "encoder_hidden_states_image" in param_names + and "encoder_hidden_states_image" in inputs + ): + value = inputs["encoder_hidden_states_image"] + kwargs["encoder_hidden_states_image"] = [value] if side == "sglang" else value + + for key in ( + "pooled_projections", + "img_ids", + "txt_ids", + "img_shapes", + "txt_seq_lens", + "freqs_cis", + "additional_t_cond", + "encoder_attention_mask", + "encoder_hidden_states_mask", + ): + if key in param_names and key in inputs: + kwargs[key] = inputs[key] + + if "return_dict" in param_names: + kwargs["return_dict"] = True + + return HookCall(module=module, kwargs=kwargs, negate_output=negate_output) + + +def _prepare_transformer_sglang_call(module: nn.Module, inputs: Inputs) -> HookCall: + return _prepare_transformer_hook_call(module, inputs, side="sglang") + + +def _prepare_transformer_reference_call(module: nn.Module, inputs: Inputs) -> HookCall: + return _prepare_transformer_hook_call(module, inputs, side="reference") + + +class _VAEDecodeModule(nn.Module): + def __init__(self, vae: nn.Module): + super().__init__() + self.vae = vae + + def forward(self, z: torch.Tensor) -> torch.Tensor: + if ( + any( + isinstance(module, (nn.Conv3d, nn.ConvTranspose3d)) + for module in self.vae.modules() + ) + and z.ndim == 4 + ): + z = z.unsqueeze(2) + output = self.vae.decode(z) + tensor = output.sample if hasattr(output, "sample") else output + if isinstance(tensor, (list, tuple)): + tensor = tensor[0] + return tensor.squeeze(2) if tensor.ndim == 5 else tensor + + +def _infer_vae_latent_channels(model: nn.Module) -> int: + for path in ("post_quant_conv.in_channels", "post_quant_conv.conv.in_channels"): + value = _resolve_nested_attr(model, path) + if isinstance(value, int) and value > 0: + return value + return _read_config_value( + model, + [ + "z_dim", + "arch_config.z_dim", + "latent_channels", + "arch_config.latent_channels", + "num_channels_latents", + 
"arch_config.num_channels_latents", + "latent_dim", + "z_channels", + "arch_config.z_channels", + ], + default=DEFAULT_VAE_LATENT_CHANNELS, + ) + + +def _build_vae_hook_inputs( + case: Any, model: nn.Module, device: str, ref_model: Optional[nn.Module] = None +) -> Inputs: + del case, ref_model + latent_channels = _infer_vae_latent_channels(model) + rng = _DeterministicRNG() + return { + "z": rng.randn( + ( + 1, + latent_channels, + DEFAULT_VAE_LATENT_SPATIAL_SIZE, + DEFAULT_VAE_LATENT_SPATIAL_SIZE, + ), + device, + torch.bfloat16, + ) + } + + +def _prepare_vae_decode_call(module: nn.Module, inputs: Inputs) -> HookCall: + return HookCall(module=_VAEDecodeModule(module), args=(inputs["z"],)) + + +TRANSFORMER_NATIVE_PROFILE = NativeHookProfile( + build_inputs=_build_transformer_hook_inputs, + prepare_sglang_call=_prepare_transformer_sglang_call, + prepare_reference_call=_prepare_transformer_reference_call, +) + +VAE_NATIVE_PROFILE = NativeHookProfile( + build_inputs=_build_vae_hook_inputs, + prepare_sglang_call=_prepare_vae_decode_call, + prepare_reference_call=_prepare_vae_decode_call, +) + + +def resolve_component_native_profile(component: ComponentType) -> NativeHookProfile: + if component == ComponentType.TRANSFORMER: + return TRANSFORMER_NATIVE_PROFILE + if component == ComponentType.VAE: + return VAE_NATIVE_PROFILE + raise KeyError(f"Unsupported native accuracy component: {component.value}") diff --git a/python/sglang/multimodal_gen/test/server/accuracy_utils.py b/python/sglang/multimodal_gen/test/server/accuracy_utils.py new file mode 100644 index 000000000000..48ecace6f7aa --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/accuracy_utils.py @@ -0,0 +1,880 @@ +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from safetensors.torch import load_file as safetensors_load_file +from torch.distributed.tensor import 
distribute_tensor + +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + destroy_model_parallel, + get_data_parallel_world_size, + get_sequence_parallel_world_size, + get_tensor_model_parallel_world_size, + maybe_init_distributed_environment_and_model_parallel, + model_parallel_is_initialized, +) +from sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model +from sglang.multimodal_gen.test.server.accuracy_config import ( + DEFAULT_TEXT_ENCODER_VOCAB_SIZE, + I2V_TEXT_ENCODER_DIM, + TEXT_ENCODER_INPUT_SEED, + TEXT_ENCODER_TOKEN_LENGTH, + TEXT_ENCODER_TOKEN_MAX, + TEXT_ENCODER_TOKEN_MIN, + ComponentType, + get_threshold, +) + +STAGED_1GPU_NATIVE_CASE_IDS = { + "flux_2_image_t2i", + "qwen_image_layered_i2i", + "flux_2_image_t2i_upscaling_4x", + "flux_2_ti2i", + "flux_2_t2i_customized_vae_path", + "flux_2_ti2i_multi_image_cache_dit", +} + +# These case allowlists are accuracy-runner policy. They select the few 1-GPU +# cases that need sequential SGLang/reference execution to stay within memory +# limits during CI and local correctness runs. 
+STAGED_1GPU_TEXT_ENCODER_CASE_IDS = {
+    "flux_2_image_t2i",
+    "flux_2_image_t2i_upscaling_4x",
+    "mova_360p_1gpu",
+    "flux_2_ti2i",
+    "flux_2_t2i_customized_vae_path",
+    "flux_2_ti2i_multi_image_cache_dit",
+}
+
+# Prefixes stripped from SOURCE (checkpoint) state-dict keys when indexing them
+# in build_state_lookup, so prefixed and bare names both resolve.
+SOURCE_PREFIXES = (
+    "module.",
+    "model.",
+    "transformer.",
+    "text_encoder.",
+    "image_encoder.",
+    "encoder.",
+    "decoder.",
+    "model.language_model.",
+    "model.visual.",
+)
+
+# Prefixes stripped from TARGET (module) state-dict keys when generating lookup
+# candidates in generate_name_candidates.
+TARGET_PREFIXES = (
+    "module.",
+    "model.",
+    "transformer.",
+    "text_encoder.",
+    "image_encoder.",
+    "encoder.",
+    "decoder.",
+)
+
+
+@dataclass(frozen=True)
+class ComponentSelection:
+    """Resolved on-disk location of one pipeline component (see select_component_source)."""
+
+    base_model_id: str
+    base_model_root: str
+    # Per-component path overrides parsed from `--<component>-path` CLI args.
+    component_paths: Dict[str, str]
+    source_root: str
+    source_path: str
+    source_subfolder: str
+
+
+@dataclass(frozen=True)
+class ParameterShardContext:
+    """TP shard coordinates (group world size and rank) recorded per parameter."""
+
+    world_size: int
+    rank: int
+
+
+def seed_and_broadcast(seed: int, tensor: torch.Tensor) -> torch.Tensor:
+    """Seed and broadcast tensor across ranks for determinism."""
+    torch.manual_seed(seed)
+    # Rank 0's tensor wins so every rank sees identical data.
+    if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
+        torch.distributed.broadcast(tensor, src=0)
+    return tensor
+
+
+def read_json_file(path: str) -> Dict[str, Any]:
+    """Load a JSON file, returning an empty dict when the file does not exist."""
+    if not os.path.exists(path):
+        return {}
+    with open(path) as f:
+        return json.load(f)
+
+
+def has_component_files(path: str) -> bool:
+    """Return True when *path* is a directory holding a config.json or any weight file."""
+    if not os.path.isdir(path):
+        return False
+    if os.path.exists(os.path.join(path, "config.json")):
+        return True
+    # Accept common checkpoint extensions even without a config.json.
+    for ext in (".safetensors", ".bin", ".pth"):
+        if any(name.endswith(ext) for name in os.listdir(path)):
+            return True
+    return False
+
+
+def list_safetensor_files(path: str) -> List[str]:
+    """Return sorted absolute paths of .safetensors files directly under *path* ([] otherwise)."""
+    if not os.path.isdir(path):
+        return []
+    return sorted(
+        os.path.join(path, name)
+        for name in os.listdir(path)
+        if name.endswith(".safetensors")
+    )
+
+
+def is_text_encoder_config(path: str) -> bool:
+    """Return True when config.json under *path* plausibly describes a text encoder
+    (i2v/image-encoder configs are rejected)."""
+    cfg_path = os.path.join(path, "config.json")
+    if not os.path.exists(cfg_path):
+        return False
+    cfg = read_json_file(cfg_path)
+    if cfg.get("model_type") == "i2v" 
or cfg.get("dim") == I2V_TEXT_ENCODER_DIM: + return False + return True + + +def _resolve_component_subfolder( + model_index: Dict[str, Any], key: str +) -> Optional[str]: + entry = model_index.get(key) + if isinstance(entry, dict): + return entry.get("path") or entry.get("subfolder") + if isinstance(entry, str): + return entry + if entry is not None: + return key + return None + + +def resolve_component_path( + local_root: str, component: ComponentType, model_index_keys: Tuple[str, ...] +) -> Tuple[str, str]: + model_index_path = os.path.join(local_root, "model_index.json") + model_index = read_json_file(model_index_path) + + if model_index: + for key in model_index_keys: + subfolder = _resolve_component_subfolder(model_index, key) + if not subfolder: + continue + candidate = os.path.join(local_root, subfolder) + if not has_component_files(candidate): + continue + if component == ComponentType.TEXT_ENCODER and not is_text_encoder_config( + candidate + ): + continue + return candidate, subfolder + + if has_component_files(local_root): + if component != ComponentType.TEXT_ENCODER or is_text_encoder_config( + local_root + ): + return local_root, "" + + raise FileNotFoundError( + f"Could not resolve {component.value} from model_index.json under {local_root}" + ) + + +def extract_component_path_overrides(extra_args: List[str]) -> Dict[str, str]: + component_paths: Dict[str, str] = {} + index = 0 + while index < len(extra_args): + arg = extra_args[index] + key_part = arg.split("=", 1)[0] if "=" in arg else arg + if key_part.startswith("--") and key_part.endswith("-path"): + component = key_part[2:-5].replace("-", "_") + if "=" in arg: + component_paths[component] = arg.split("=", 1)[1] + elif index + 1 < len(extra_args) and not extra_args[index + 1].startswith( + "-" + ): + index += 1 + component_paths[component] = extra_args[index] + index += 1 + + for component, path in component_paths.items(): + component_paths[component] = os.path.expanduser(path) + return 
component_paths + + +def load_checkpoint_weights( + module: nn.Module, model_path: str +) -> tuple[list[str], list[str]]: + safetensors_files = list_safetensor_files(model_path) + assert safetensors_files, f"Found no safetensors files in {model_path}" + + loaded_state: Dict[str, torch.Tensor] = {} + for safetensor_path in safetensors_files: + loaded_state.update(safetensors_load_file(safetensor_path)) + + module.load_state_dict(loaded_state, strict=False) + + state_keys = set(module.state_dict().keys()) + loaded_keys = set(loaded_state.keys()) + missing_keys = sorted(state_keys - loaded_keys) + unexpected_keys = sorted(loaded_keys - state_keys) + return missing_keys, unexpected_keys + + +def select_component_source( + model_id: str, + extra_args: List[str], + component: ComponentType, + model_index_keys: Tuple[str, ...], +) -> ComponentSelection: + component_paths = extract_component_path_overrides(extra_args) + base_model_root = maybe_download_model(model_id) + search_keys = [component.value] + for key in model_index_keys: + if key not in search_keys: + search_keys.append(key) + + source_root = base_model_root + component_key = component.value + for key in search_keys: + override_path = component_paths.get(key) + if override_path: + source_root = maybe_download_model(override_path) + component_key = key + break + + ordered_keys = [component_key] + for key in search_keys: + if key not in ordered_keys: + ordered_keys.append(key) + source_path, source_subfolder = resolve_component_path( + source_root, + component, + tuple(ordered_keys), + ) + return ComponentSelection( + base_model_id=model_id, + base_model_root=base_model_root, + component_paths=component_paths, + source_root=source_root, + source_path=source_path, + source_subfolder=source_subfolder, + ) + + +def ensure_distributed_env_defaults() -> None: + if "WORLD_SIZE" in os.environ: + return + os.environ.update( + { + "MASTER_ADDR": os.getenv("MASTER_ADDR", "127.0.0.1"), + "MASTER_PORT": 
os.getenv("MASTER_PORT", "29505"), + "RANK": "0", + "LOCAL_RANK": "0", + "WORLD_SIZE": "1", + } + ) + + +def initialize_parallel_runtime(sgl_args: ServerArgs) -> None: + tp_size = sgl_args.tp_size + sp_degree = sgl_args.sp_degree + ulysses_degree = sgl_args.ulysses_degree + ring_degree = sgl_args.ring_degree + dp_size = sgl_args.dp_size + enable_cfg_parallel = bool(sgl_args.enable_cfg_parallel) + + if ( + tp_size is None + or sp_degree is None + or ulysses_degree is None + or ring_degree is None + ): + raise RuntimeError( + "ServerArgs must have tp_size, sp_degree, ulysses_degree, and ring_degree before init" + ) + + if not model_parallel_is_initialized() and torch.distributed.is_initialized(): + # A prior case may have failed while distributed groups were only partially + # initialized. Clear any stale group objects before re-initializing. + destroy_model_parallel() + + if model_parallel_is_initialized(): + current_tp = get_tensor_model_parallel_world_size() + current_sp = get_sequence_parallel_world_size() + current_dp = get_data_parallel_world_size() + if current_tp == tp_size and current_sp == sp_degree and current_dp == dp_size: + return + if torch.distributed.is_initialized(): + torch.distributed.barrier() + destroy_model_parallel() + + ensure_distributed_env_defaults() + + maybe_init_distributed_environment_and_model_parallel( + tp_size=tp_size, + sp_size=sp_degree, + enable_cfg_parallel=enable_cfg_parallel, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + dp_size=dp_size, + ) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + +def build_accuracy_server_args( + base_model_id: str, + base_model_root: str, + case: Any, + component: ComponentType, + num_gpus: int, + component_paths: Dict[str, str], +) -> ServerArgs: + cfg_parallel = bool(case.server_args.cfg_parallel) + kwargs = { + "model_path": base_model_root, + "model_id": base_model_id, + "num_gpus": num_gpus, + "trust_remote_code": True, + "component_paths": 
component_paths, + "enable_cfg_parallel": cfg_parallel, + } + + if case.server_args.tp_size is not None: + kwargs["tp_size"] = case.server_args.tp_size + if case.server_args.ulysses_degree is not None: + kwargs["ulysses_degree"] = case.server_args.ulysses_degree + if case.server_args.ring_degree is not None: + kwargs["ring_degree"] = case.server_args.ring_degree + + if component == ComponentType.TEXT_ENCODER: + kwargs["enable_cfg_parallel"] = False + + sgl_args = ServerArgs.from_kwargs(**kwargs) + sgl_args.text_encoder_cpu_offload = False + sgl_args.dit_cpu_offload = False + sgl_args.vae_cpu_offload = False + sgl_args.image_encoder_cpu_offload = False + sgl_args.enable_cache_dit = case.server_args.enable_cache_dit + sgl_args.dit_layerwise_offload = case.server_args.dit_layerwise_offload + sgl_args.dit_offload_prefetch_size = case.server_args.dit_offload_prefetch_size + return sgl_args + + +def set_module_attr(module: nn.Module, name: str, value: Any) -> None: + """Assign to a nested parameter/buffer path such as `blocks.0.attn.to_q.weight`.""" + attrs = name.split(".") + parent = module + for attr in attrs[:-1]: + if hasattr(parent, attr): + parent = getattr(parent, attr) + elif isinstance(parent, (nn.ModuleList, nn.Sequential)): + parent = parent[int(attr)] + elif isinstance(parent, nn.ModuleDict): + parent = parent[attr] + else: + raise AttributeError( + f"Cannot resolve {name} on {module.__class__.__name__}" + ) + setattr(parent, attrs[-1], value) + + +def materialize_module( + module: nn.Module, device: torch.device, dtype: torch.dtype +) -> None: + """Materialize meta tensors and cast floating tensors onto one target device/dtype.""" + for name, param in module.named_parameters(): + if param.device.type == "meta": + new_data = torch.zeros(param.shape, device=device, dtype=dtype) + if hasattr(param, "device_mesh") and param.device_mesh is not None: + new_data = distribute_tensor( + new_data, param.device_mesh, param.placements + ) + set_module_attr( + module, 
name, nn.Parameter(new_data, requires_grad=param.requires_grad) + ) + elif torch.is_floating_point(param): + param.data = param.data.to(device=device, dtype=dtype) + + for name, buf in module.named_buffers(): + if buf.device.type == "meta": + new_buf = torch.zeros(buf.shape, device=device, dtype=buf.dtype) + if hasattr(buf, "device_mesh") and buf.device_mesh is not None: + new_buf = distribute_tensor(new_buf, buf.device_mesh, buf.placements) + set_module_attr(module, name, new_buf) + elif torch.is_floating_point(buf): + buf.data = buf.data.to(device=device, dtype=dtype) + + +def build_parameter_shard_contexts( + module: nn.Module, +) -> Dict[str, ParameterShardContext]: + """Record TP shard world/rank for each parameter owned by a TP-aware submodule.""" + shard_contexts: Dict[str, ParameterShardContext] = {} + for module_name, submodule in module.named_modules(): + tp_group = getattr(submodule, "tp_group", None) + if tp_group is None: + continue + + context = ParameterShardContext( + world_size=get_group_size(tp_group), + rank=get_group_rank(tp_group), + ) + if context.world_size <= 1: + continue + + for name, _ in submodule.named_parameters(recurse=False): + qualified_name = f"{module_name}.{name}" if module_name else name + shard_contexts[qualified_name] = context + for name, _ in submodule.named_buffers(recurse=False): + qualified_name = f"{module_name}.{name}" if module_name else name + shard_contexts[qualified_name] = context + + return shard_contexts + + +def build_state_lookup(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Index a source state dict under both original and prefix-stripped names.""" + lookup: Dict[str, torch.Tensor] = {} + for key, val in state.items(): + lookup[key] = val + for prefix in SOURCE_PREFIXES: + if key.startswith(prefix): + lookup[key[len(prefix) :]] = val + return lookup + + +def normalize_state_key(name: str) -> str: + """Normalize common naming differences between source and target state dicts.""" + return ( + 
name.replace("_fsdp_wrapped_module.", "")
+        .replace("_orig_mod.", "")
+        .replace("gamma", "weight")
+        .replace("beta", "bias")
+        # NOTE(review): bare substring swaps — these also rewrite keys that merely
+        # contain "scale"/"shift" (e.g. "logit_scale"); confirm this is intended.
+        .replace("scale", "weight")
+        .replace("shift", "bias")
+    )
+
+
+def fuse_qkv(lookup: Dict[str, torch.Tensor], name: str) -> Optional[torch.Tensor]:
+    """Rebuild a fused ``qkv_proj`` tensor by concatenating separate q/k/v source
+    tensors from *lookup*; return None when *name* is not a qkv key or no
+    complete q/k/v triple exists."""
+    if "qkv_proj" not in name:
+        return None
+    # Try both naming schemes: "q_proj/k_proj/v_proj" and bare ".q/.k/.v".
+    variants = ["q_proj", "q"]
+    for repl in variants:
+        q_name = name.replace("qkv_proj", repl)
+        # NOTE(review): the trailing .replace(".q", ".k") is substring-based and
+        # could touch unrelated ".q…" path segments — verify against real key sets.
+        k_name = q_name.replace(".q_proj", ".k_proj").replace(".q", ".k")
+        v_name = q_name.replace(".q_proj", ".v_proj").replace(".q", ".v")
+        if q_name in lookup and k_name in lookup and v_name in lookup:
+            return torch.cat([lookup[q_name], lookup[k_name], lookup[v_name]], dim=0)
+    return None
+
+
+def fuse_gate_up_proj(
+    lookup: Dict[str, torch.Tensor], name: str
+) -> Optional[torch.Tensor]:
+    """Rebuild a fused ``gate_up_proj`` tensor from gate/up pairs in *lookup*
+    (supports "gate_proj"/"up_proj" and T5-style "wi_0"/"wi_1" naming);
+    return None when not applicable."""
+    if "gate_up_proj" not in name:
+        return None
+
+    for gate_token, up_token in (("gate_proj", "up_proj"), ("wi_0", "wi_1")):
+        gate_name = name.replace("gate_up_proj", gate_token)
+        up_name = name.replace("gate_up_proj", up_token)
+        if gate_name in lookup and up_name in lookup:
+            return torch.cat([lookup[gate_name], lookup[up_name]], dim=0)
+    return None
+
+
+def generate_name_candidates(
+    name: str, reverse_mapping: Optional[Dict[str, Tuple[str, Any, Any]]]
+) -> List[str]:
+    """Produce an ordered, de-duplicated list of source-state-dict keys that may
+    correspond to target key *name*: the raw and normalized forms, any
+    reverse-mapping aliases, TARGET_PREFIXES-stripped variants, and every
+    dot-suffix truncation (most specific first)."""
+    candidates: List[str] = []
+    clean = normalize_state_key(name)
+
+    for cand in (name, clean):
+        if cand not in candidates:
+            candidates.append(cand)
+
+    if reverse_mapping:
+        for key in (name, clean):
+            entry = reverse_mapping.get(key)
+            if entry and entry[0] not in candidates:
+                candidates.append(entry[0])
+
+    for prefix in TARGET_PREFIXES:
+        if clean.startswith(prefix):
+            stripped = clean[len(prefix) :]
+            if stripped and stripped not in candidates:
+                candidates.append(stripped)
+
+    # Fall back to progressively shorter dotted suffixes of the key.
+    parts = clean.split(".")
+    for i in range(1, len(parts)):
+        cand = ".".join(parts[i:])
+        if cand not in candidates:
+            candidates.append(cand)
+
+    return candidates
+
+
+# Copies/shards *src* into *dest*: handles DTensor destinations, TP row/column
+# sharding, and conv 4D→5D inflation (body continues below); returns False when
+# the shapes cannot be reconciled.
+def copy_tensor(
+    dest: torch.Tensor,
+    src: 
torch.Tensor, + tp_world: int, + rank: int, +) -> bool: + if src.numel() == 0: + return False + src = src.to(device=dest.device, dtype=dest.dtype) + + if hasattr(dest, "device_mesh") and dest.device_mesh is not None: + if src.numel() == dest.numel(): + with torch.no_grad(): + dt = distribute_tensor( + src.view(dest.shape), dest.device_mesh, dest.placements + ) + dest.copy_(dt) + return True + + if src.numel() == dest.numel(): + with torch.no_grad(): + dest.copy_(src.view(dest.shape)) + return True + + if tp_world > 1 and src.numel() == dest.numel() * tp_world: + if ( + src.ndim == dest.ndim + and src.shape[0] == dest.shape[0] * tp_world + and src.shape[1:] == dest.shape[1:] + ): + with torch.no_grad(): + dest.copy_(src[rank * dest.shape[0] : (rank + 1) * dest.shape[0], ...]) + return True + if ( + src.ndim >= 2 + and dest.ndim >= 2 + and src.ndim == dest.ndim + and src.shape[0] == dest.shape[0] + and src.shape[1] == dest.shape[1] * tp_world + and src.shape[2:] == dest.shape[2:] + ): + with torch.no_grad(): + dest.copy_(src[:, rank * dest.shape[1] : (rank + 1) * dest.shape[1]]) + return True + + if src.ndim == 4 and dest.ndim == 5 and dest.numel() == src.numel() * dest.shape[2]: + with torch.no_grad(): + dest.copy_(src.unsqueeze(2).repeat(1, 1, dest.shape[2], 1, 1)) + return True + + return False + + +def _config_to_dict(config: Any) -> Dict[str, Any]: + to_dict = getattr(config, "to_dict", None) + if not callable(to_dict): + return {} + config_dict = to_dict() + return config_dict if isinstance(config_dict, dict) else {} + + +def resolve_text_encoder_vocab_size(config: Any) -> int: + config_dict = _config_to_dict(config) + vocab_size = config_dict.get("vocab_size") + if isinstance(vocab_size, int) and vocab_size > 0: + return vocab_size + + text_config = config_dict.get("text_config") + if isinstance(text_config, dict): + nested_vocab_size = text_config.get("vocab_size") + if isinstance(nested_vocab_size, int) and nested_vocab_size > 0: + return nested_vocab_size + 
+ return DEFAULT_TEXT_ENCODER_VOCAB_SIZE + + +def build_deterministic_text_encoder_inputs( + config: Any, device: str +) -> tuple[torch.Tensor, torch.Tensor]: + """Build one stable token batch that works across text-encoder implementations.""" + vocab_size = resolve_text_encoder_vocab_size(config) + max_token_id = max( + TEXT_ENCODER_TOKEN_MIN + 1, min(vocab_size, TEXT_ENCODER_TOKEN_MAX) + ) + + torch.manual_seed(TEXT_ENCODER_INPUT_SEED) + input_ids = torch.randint( + TEXT_ENCODER_TOKEN_MIN, + max_token_id, + (1, TEXT_ENCODER_TOKEN_LENGTH), + device="cpu", + dtype=torch.long, + ).to(device) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def resolve_text_encoder_forward_module(model: nn.Module) -> nn.Module: + get_encoder = getattr(model, "get_encoder", None) + return get_encoder() if callable(get_encoder) else model + + +def _module_device(module: nn.Module) -> torch.device: + param = next(module.parameters(), None) + if param is not None: + return param.device + + buf = next(module.buffers(), None) + if buf is not None: + return buf.device + + return torch.device("cpu") + + +def extract_output_tensor(output: Any) -> torch.Tensor: + """Best-effort extraction of a tensor from model outputs.""" + if isinstance(output, torch.Tensor): + return output + + sample = getattr(output, "sample", None) + if sample is not None: + if isinstance(sample, (list, tuple)): + sample = sample[0] + if isinstance(sample, torch.Tensor): + return sample + + last_hidden_state = getattr(output, "last_hidden_state", None) + if last_hidden_state is not None: + return last_hidden_state + + hidden_states = getattr(output, "hidden_states", None) + if hidden_states: + return hidden_states[-1] + + pooler_output = getattr(output, "pooler_output", None) + if pooler_output is not None: + return pooler_output + + logits = getattr(output, "logits", None) + if logits is not None: + return logits + + if ( + isinstance(output, (list, tuple)) + and output + and 
isinstance(output[0], torch.Tensor) + ): + return output[0] + raise ValueError(f"Could not extract tensor from output of type {type(output)}") + + +def run_text_encoder_accuracy_pair( + sgl: nn.Module, ref: nn.Module +) -> tuple[torch.Tensor, torch.Tensor]: + input_ids, attention_mask = build_deterministic_text_encoder_inputs( + ref.config, "cpu" + ) + return ( + _run_single_text_encoder_forward(sgl, input_ids, attention_mask), + _run_single_text_encoder_forward(ref, input_ids, attention_mask), + ) + + +def _should_stage_case(case: Any, component: ComponentType, num_gpus: int) -> bool: + if num_gpus == 2: + return True + if num_gpus != 1: + return False + if component == ComponentType.TEXT_ENCODER: + return case.id in STAGED_1GPU_TEXT_ENCODER_CASE_IDS + return case.id in STAGED_1GPU_NATIVE_CASE_IDS + + +def _run_single_text_encoder_forward( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor +) -> torch.Tensor: + """Run one encoder forward and normalize its output into a tensor.""" + with torch.no_grad(): + forward_model = resolve_text_encoder_forward_module(model) + model_device = _module_device(forward_model) + output = forward_model( + input_ids.to(device=model_device), + attention_mask=attention_mask.to(device=model_device), + output_hidden_states=True, + ) + return extract_output_tensor(output) + + +def _run_staged_native_component_accuracy_case( + engine_cls: Any, + case: Any, + component: ComponentType, + library: str, + num_gpus: int, +) -> None: + from sglang.multimodal_gen.test.server.accuracy_hooks import ( + resolve_component_native_profile, + ) + + sgl = None + ref = None + try: + sgl, ref, device = engine_cls.load_component_pair( + case, + component, + library, + num_gpus, + materialize_ref_on_device=False, + ) + profile = resolve_component_native_profile(component) + inputs = profile.build_inputs(case, sgl, device, ref) + + sgl_call = profile.prepare_sglang_call(sgl, inputs) + with torch.no_grad(): + sgl_raw = 
engine_cls._execute_with_native_hook(sgl_call) + sgl_out = profile.normalize_sglang_output(sgl_raw) + sgl_out = engine_cls._apply_output_transforms(sgl_out, sgl_call).detach().cpu() + + del sgl_call + del sgl_raw + del sgl + sgl = None + engine_cls.clear_memory() + + ref = ref.to(device=device, dtype=torch.bfloat16).eval() + ref_call = profile.prepare_reference_call(ref, inputs) + with torch.no_grad(): + ref_raw = engine_cls._execute_with_native_hook(ref_call) + ref_out = profile.normalize_reference_output(ref_raw) + ref_out = engine_cls._apply_output_transforms(ref_out, ref_call).detach().cpu() + del ref_call + del ref_raw + + engine_cls.check_accuracy( + sgl_out, + ref_out, + f"{case.id}_{component.value}", + get_threshold(case.id, component), + ) + finally: + if sgl is not None: + del sgl + if ref is not None: + del ref + engine_cls.reset_parallel_runtime() + engine_cls.clear_memory() + + +def _run_staged_text_encoder_accuracy_case( + engine_cls: Any, case: Any, num_gpus: int +) -> None: + sgl = None + ref = None + try: + sgl, ref, device = engine_cls.load_component_pair( + case, + ComponentType.TEXT_ENCODER, + "transformers", + num_gpus, + materialize_sgl_on_device=False, + materialize_ref_on_device=False, + ) + input_ids, attention_mask = build_deterministic_text_encoder_inputs( + ref.config, "cpu" + ) + + sgl = sgl.to(device=device, dtype=torch.bfloat16).eval() + sgl_out = ( + _run_single_text_encoder_forward(sgl, input_ids, attention_mask) + .detach() + .cpu() + ) + + del sgl + sgl = None + engine_cls.clear_memory() + + ref = ref.to(device=device, dtype=torch.bfloat16).eval() + ref_out = ( + _run_single_text_encoder_forward(ref, input_ids, attention_mask) + .detach() + .cpu() + ) + + engine_cls.check_accuracy( + sgl_out, + ref_out, + f"{case.id}_encoder", + get_threshold(case.id, ComponentType.TEXT_ENCODER), + ) + finally: + if sgl is not None: + del sgl + if ref is not None: + del ref + engine_cls.reset_parallel_runtime() + engine_cls.clear_memory() + + 
+def run_native_component_accuracy_case( + engine_cls: Any, + case: Any, + component: ComponentType, + library: str, + num_gpus: int, +) -> None: + if _should_stage_case(case, component, num_gpus): + _run_staged_native_component_accuracy_case( + engine_cls, case, component, library, num_gpus + ) + return + engine_cls.clear_memory() + sgl = None + ref = None + try: + sgl, ref, device = engine_cls.load_component_pair( + case, component, library, num_gpus + ) + sgl_out, ref_out = engine_cls.run_component_pair_native( + case, component, sgl, ref, device + ) + engine_cls.check_accuracy( + sgl_out, + ref_out, + f"{case.id}_{component.value}", + get_threshold(case.id, component), + ) + finally: + if sgl is not None: + del sgl + if ref is not None: + del ref + engine_cls.reset_parallel_runtime() + engine_cls.clear_memory() + + +def run_text_encoder_accuracy_case(engine_cls: Any, case: Any, num_gpus: int) -> None: + if _should_stage_case(case, ComponentType.TEXT_ENCODER, num_gpus): + _run_staged_text_encoder_accuracy_case(engine_cls, case, num_gpus) + return + engine_cls.clear_memory() + sgl = None + ref = None + try: + sgl, ref, _device = engine_cls.load_component_pair( + case, ComponentType.TEXT_ENCODER, "transformers", num_gpus + ) + sgl_out, ref_out = run_text_encoder_accuracy_pair(sgl, ref) + engine_cls.check_accuracy( + sgl_out, + ref_out, + f"{case.id}_encoder", + get_threshold(case.id, ComponentType.TEXT_ENCODER), + ) + finally: + if sgl is not None: + del sgl + if ref is not None: + del ref + engine_cls.reset_parallel_runtime() + engine_cls.clear_memory() diff --git a/python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json b/python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json index e1e721de52c0..b4703c95aabf 100644 --- a/python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json +++ b/python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json @@ -257,6 +257,71 @@ "expected_e2e_ms": 91733.92, 
"expected_avg_denoise_ms": 2091.33, "expected_median_denoise_ms": 2090.72 + }, + "qwen_image_t2i_2npu": { + "stages_ms": { + "InputValidationStage": 0.07, + "TextEncodingStage": 629.24, + "LatentPreparationStage": 0.69, + "TimestepPreparationStage": 35.29, + "DenoisingStage": 30529.83, + "DecodingStage": 74.25 + }, + "denoise_step_ms": { + "0": 477.43, + "1": 511.96, + "2": 607.78, + "3": 615.12, + "4": 616.29, + "5": 614.61, + "6": 623.04, + "7": 607.12, + "8": 615.32, + "9": 615.47, + "10": 616.93, + "11": 623.26, + "12": 607.12, + "13": 615.48, + "14": 615.07, + "15": 614.83, + "16": 623.18, + "17": 609.0, + "18": 614.8, + "19": 623.08, + "20": 607.64, + "21": 614.2, + "22": 615.58, + "23": 615.43, + "24": 623.59, + "25": 606.57, + "26": 616.02, + "27": 615.48, + "28": 615.76, + "29": 623.13, + "30": 608.73, + "31": 615.04, + "32": 616.08, + "33": 616.59, + "34": 623.77, + "35": 608.0, + "36": 616.1, + "37": 615.79, + "38": 615.34, + "39": 617.43, + "40": 610.99, + "41": 614.22, + "42": 623.27, + "43": 606.98, + "44": 615.87, + "45": 615.99, + "46": 614.66, + "47": 622.93, + "48": 607.97, + "49": 614.69 + }, + "expected_e2e_ms": 34362.34, + "expected_avg_denoise_ms": 610.41, + "expected_median_denoise_ms": 615.39 } } } diff --git a/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py b/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py index 3c65bf894d9f..0a78eb35fa15 100644 --- a/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py +++ b/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py @@ -42,6 +42,18 @@ ), T2I_sampling_params, ), + DiffusionTestCase( + "qwen_image_t2i_2npu", + DiffusionServerArgs( + model_path="/root/.cache/modelscope/hub/models/Qwen/Qwen-Image", + modality="image", + num_gpus=2, + # test ring attn + ulysses_degree=1, + ring_degree=2, + ), + T2I_sampling_params, + ), ] EIGHT_NPU_CASES: list[DiffusionTestCase] = [ diff --git 
a/python/sglang/multimodal_gen/test/server/component_accuracy.py b/python/sglang/multimodal_gen/test/server/component_accuracy.py new file mode 100644 index 000000000000..ba59063c081e --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/component_accuracy.py @@ -0,0 +1,534 @@ +from __future__ import annotations + +import gc +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import diffusers +import torch +import torch.nn as nn +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + T5EncoderModel, + UMT5EncoderModel, +) + +try: + from transformers import AutoModelForImageTextToText as AutoVisionTextModel +except ImportError: + try: + from transformers import AutoModelForVision2Seq as AutoVisionTextModel + except ImportError: + AutoVisionTextModel = None + +import sglang.multimodal_gen.runtime.managers.forward_context as fc_mod +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + destroy_model_parallel, + get_local_torch_device, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + model_parallel_is_initialized, +) +from sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import ( + ComponentLoader, +) +from sglang.multimodal_gen.runtime.loader.utils import ( + get_param_names_mapping, + hf_to_custom_state_dict, +) +from sglang.multimodal_gen.runtime.managers.forward_context import ForwardContext +from sglang.multimodal_gen.runtime.models.vaes.wanvae import AutoencoderKLWan +from sglang.multimodal_gen.runtime.server_args import ServerArgs, set_global_server_args +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + get_diffusers_component_config, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.accuracy_config import ( + DEFAULT_TIMESTEP, + ComponentType, +) +from sglang.multimodal_gen.test.server.accuracy_hooks import ( + 
resolve_component_native_profile, +) +from sglang.multimodal_gen.test.server.accuracy_utils import ( + build_accuracy_server_args, + build_parameter_shard_contexts, + build_state_lookup, + copy_tensor, + fuse_gate_up_proj, + fuse_qkv, + generate_name_candidates, + initialize_parallel_runtime, + load_checkpoint_weights, + materialize_module, + read_json_file, + resolve_text_encoder_forward_module, + select_component_source, +) +from sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase + +logger = init_logger(__name__) + +MIN_MATCH_RATIO = float(os.getenv("SGLANG_DIFFUSION_WEIGHT_MATCH_RATIO", "0.98")) + + +@dataclass(frozen=True) +class ComponentSpec: + model_index_keys: Tuple[str, ...] + reference_library: str + + +COMPONENT_SPECS: Dict[ComponentType, ComponentSpec] = { + ComponentType.VAE: ComponentSpec( + model_index_keys=( + "vae", + "vae_model", + "autoencoder", + "autoencoder_kl", + "video_vae", + "audio_vae", + ), + reference_library="diffusers", + ), + ComponentType.TRANSFORMER: ComponentSpec( + model_index_keys=("transformer", "unet", "dit", "video_dit", "audio_dit"), + reference_library="diffusers", + ), + ComponentType.TEXT_ENCODER: ComponentSpec( + model_index_keys=( + "text_encoder", + "text_encoder_2", + "text_encoder_3", + "image_encoder", + ), + reference_library="transformers", + ), +} + + +# Component loading helpers +def _load_sglang_component( + comp_path: str, + sgl_args: ServerArgs, + component: ComponentType, + library: str, + text_encoder_cpu_offload: bool | None = None, +) -> nn.Module: + loader = ComponentLoader.for_component_type(component.value, library) + if component == ComponentType.TEXT_ENCODER: + component_model = loader.load_customized( + comp_path, + sgl_args, + component.value, + cpu_offload_flag=text_encoder_cpu_offload, + ) + else: + component_model = loader.load_customized(comp_path, sgl_args, component.value) + if component_model is None: + raise RuntimeError(f"Failed to load customized 
{component.value}") + return component_model + + +def _load_wan_reference_vae(comp_path: str, pipeline_config) -> nn.Module: + vae_config = pipeline_config.vae_config + vae_config.update_model_arch( + get_diffusers_component_config(component_path=comp_path) + ) + if hasattr(vae_config, "post_init"): + vae_config.post_init() + + vae = AutoencoderKLWan(vae_config) + missing_keys, unexpected_keys = load_checkpoint_weights(vae, comp_path) + if missing_keys: + logger.warning("WAN VAE missing keys: %s", missing_keys) + if unexpected_keys: + logger.warning("WAN VAE unexpected keys: %s", unexpected_keys) + return vae + + +def _load_reference_component( + comp_path: str, + source_root: str, + component: ComponentType, + hub_id: str, + pipeline_config, + subfolder: str, +) -> nn.Module: + # WAN VAE does not have a clean generic diffusers auto-load path here, and we + # explicitly need checkpoint-loaded weights for reference-side transfer/parity. + if component == ComponentType.VAE and "wan" in hub_id.lower(): + return _load_wan_reference_vae(comp_path, pipeline_config) + + if component == ComponentType.VAE: + cfg = read_json_file(os.path.join(comp_path, "config.json")) + class_name = cfg.get("_class_name") if cfg else None + cls = getattr(diffusers, str(class_name), None) if class_name else None + if cls is None: + cls = diffusers.AutoencoderKL + return cls.from_pretrained( + source_root, + subfolder=subfolder, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + ) + + if component == ComponentType.TRANSFORMER: + cfg = read_json_file(os.path.join(comp_path, "config.json")) + class_name = cfg.get("_class_name") if cfg else None + load_kwargs: Dict[str, Any] = { + "torch_dtype": torch.bfloat16, + "trust_remote_code": True, + } + if cfg: + for k, out_k in [ + ("in_dim", "in_channels"), + ("dim", "hidden_size"), + ("num_heads", "num_attention_heads"), + ("out_dim", "out_channels"), + ]: + if k in cfg: + load_kwargs[out_k] = cfg[k] + candidates = [diffusers.AutoModel] + if 
class_name: + maybe_cls = getattr(diffusers, str(class_name), None) + if maybe_cls is not None: + candidates.insert(0, maybe_cls) + last_error: Optional[Exception] = None + for cls in candidates: + try: + return cls.from_pretrained(comp_path, **load_kwargs) + except Exception as exc: + last_error = exc + raise RuntimeError(f"Failed to load transformer from {comp_path}: {last_error}") + + if component == ComponentType.TEXT_ENCODER: + config = AutoConfig.from_pretrained(comp_path, trust_remote_code=True) + kwargs = { + "torch_dtype": torch.bfloat16, + "trust_remote_code": True, + "config": config, + } + architectures = tuple(getattr(config, "architectures", ()) or ()) + if ( + "UMT5EncoderModel" in architectures + or getattr(config, "model_type", None) == "umt5" + ): + class_order = [UMT5EncoderModel, AutoModel, AutoModelForCausalLM] + else: + class_order = [ + AutoModel, + AutoModelForCausalLM, + UMT5EncoderModel, + T5EncoderModel, + ] + if AutoVisionTextModel is not None: + class_order.append(AutoVisionTextModel) + last_error: Optional[Exception] = None + for cls in class_order: + try: + return cls.from_pretrained(comp_path, **kwargs) + except Exception as exc: + last_error = exc + raise RuntimeError( + f"Failed to load text encoder from {comp_path}: {last_error}" + ) + + raise RuntimeError(f"Unsupported component {component.value}") + + +# Public accuracy engine +class AccuracyEngine: + @staticmethod + def reset_parallel_runtime() -> None: + if torch.distributed.is_initialized(): + torch.distributed.barrier() + if model_parallel_is_initialized(): + destroy_model_parallel() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + @staticmethod + def clear_memory() -> None: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + @staticmethod + def _execute_with_native_hook(call) -> Any: + output: Any = None + + def _hook(_: nn.Module, __: tuple[Any, ...], captured: Any) 
-> None: + nonlocal output + output = captured + + handle = call.module.register_forward_hook(_hook) + try: + call.module(*call.args, **call.kwargs) + finally: + handle.remove() + assert output is not None + return output + + @staticmethod + def _apply_output_transforms(tensor: torch.Tensor, call) -> torch.Tensor: + if call.negate_output: + return -tensor + return tensor + + @staticmethod + def check_accuracy( + target: torch.Tensor, reference: torch.Tensor, name: str, threshold: float + ) -> None: + full_tensor = getattr(target, "full_tensor", None) + if callable(full_tensor): + target = full_tensor() + t, r = target.detach().cpu().float(), reference.detach().cpu().float() + + logger.info( + "[%s] Shape: SGL=%s, REF=%s | NaNs: SGL=%s, REF=%s", + name, + list(t.shape), + list(r.shape), + torch.isnan(t).sum(), + torch.isnan(r).sum(), + ) + + if t.shape != r.shape: + if t.ndim == 5 and t.shape[2] == 1: + t = t.squeeze(2) + if r.ndim == 5 and r.shape[2] == 1: + r = r.squeeze(2) + if t.shape != r.shape: + raise RuntimeError( + f"Accuracy shape mismatch for {name}: {list(t.shape)} vs {list(r.shape)}" + ) + + cos_sim = torch.nn.functional.cosine_similarity( + t.reshape(-1), r.reshape(-1), dim=0 + ).item() + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + logger.info("[%s] Rank %s CosSim=%.6f", name, rank, cos_sim) + assert ( + cos_sim > threshold + ), f"Accuracy failure in {name}: CosSim {cos_sim:.4f} < {threshold}" + + @staticmethod + def transfer_weights( + source: nn.Module, + target: nn.Module, + min_match_ratio: float = MIN_MATCH_RATIO, + target_device: Optional[torch.device] = None, + ) -> None: + device = target_device or get_local_torch_device() + dtype = torch.bfloat16 + materialize_module(target, device, dtype) + + source_state = source.state_dict() + mapping = getattr(target, "param_names_mapping", None) or getattr( + getattr(target, "module", None), "param_names_mapping", None + ) + if mapping: + source_state, _ = 
hf_to_custom_state_dict( + source_state, get_param_names_mapping(mapping) + ) + + lookup = build_state_lookup(source_state) + reverse_mapping = getattr( + target, "reverse_param_names_mapping", None + ) or getattr( + getattr(target, "module", None), "reverse_param_names_mapping", None + ) + tp_world = ( + get_tensor_model_parallel_world_size() + if model_parallel_is_initialized() + else 1 + ) + rank = ( + get_tensor_model_parallel_rank() if model_parallel_is_initialized() else 0 + ) + shard_contexts = build_parameter_shard_contexts(target) + + matched = 0 + total = 0 + unmatched_details: List[str] = [] + for name, tensor in target.named_parameters(): + total += 1 + src_tensor = None + for cand in generate_name_candidates(name, reverse_mapping): + if cand in lookup: + src_tensor = lookup[cand] + break + if src_tensor is None: + for cand in generate_name_candidates(name, reverse_mapping): + src_tensor = fuse_qkv(lookup, cand) + if src_tensor is not None: + break + if src_tensor is None: + for cand in generate_name_candidates(name, reverse_mapping): + src_tensor = fuse_gate_up_proj(lookup, cand) + if src_tensor is not None: + break + if src_tensor is None: + unmatched_details.append(f"{name}: no matching source tensor") + continue + shard_context = shard_contexts.get(name) + shard_world_size = ( + shard_context.world_size if shard_context is not None else tp_world + ) + shard_rank = shard_context.rank if shard_context is not None else rank + if copy_tensor(tensor, src_tensor, shard_world_size, shard_rank): + matched += 1 + else: + unmatched_details.append( + f"{name}: source {list(src_tensor.shape)} -> target {list(tensor.shape)} unsupported for shard_world_size={shard_world_size}" + ) + + for name, tensor in target.named_buffers(): + src_tensor = None + for cand in generate_name_candidates(name, reverse_mapping): + if cand in lookup: + src_tensor = lookup[cand] + break + if src_tensor is None: + continue + shard_context = shard_contexts.get(name) + shard_world_size = 
( + shard_context.world_size if shard_context is not None else tp_world + ) + shard_rank = shard_context.rank if shard_context is not None else rank + copy_tensor(tensor, src_tensor, shard_world_size, shard_rank) + + ratio = matched / max(total, 1) + logger.info( + "Weight transfer: %s/%s matched (%.2f%%).", matched, total, ratio * 100 + ) + if ratio < min_match_ratio: + if rank == 0 and unmatched_details: + logger.error( + "Unmatched parameter details:\n%s", "\n".join(unmatched_details) + ) + raise RuntimeError( + f"Weight transfer matched {matched}/{total} ({ratio:.2%}); below threshold {min_match_ratio:.2%}." + ) + + @staticmethod + def run_component_pair_native( + case: DiffusionTestCase, + component: ComponentType, + sgl_model: nn.Module, + ref_model: nn.Module, + device: str, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if component == ComponentType.TEXT_ENCODER: + raise ValueError("Text encoder path is not migrated to native hooks yet") + profile = resolve_component_native_profile(component) + + inputs = profile.build_inputs(case, sgl_model, device, ref_model) + sgl_call = profile.prepare_sglang_call(sgl_model, inputs) + ref_call = profile.prepare_reference_call(ref_model, inputs) + + with torch.no_grad(): + sgl_raw = AccuracyEngine._execute_with_native_hook(sgl_call) + ref_raw = AccuracyEngine._execute_with_native_hook(ref_call) + + sgl_out = profile.normalize_sglang_output(sgl_raw) + ref_out = profile.normalize_reference_output(ref_raw) + sgl_out = AccuracyEngine._apply_output_transforms(sgl_out, sgl_call) + ref_out = AccuracyEngine._apply_output_transforms(ref_out, ref_call) + return sgl_out, ref_out + + @staticmethod + def load_component_pair( + case: DiffusionTestCase, + component: ComponentType, + library: str, + num_gpus: int, + materialize_sgl_on_device: bool = True, + materialize_ref_on_device: bool = True, + ) -> Tuple[nn.Module, nn.Module, str]: + spec = COMPONENT_SPECS[component] + if library != spec.reference_library: + logger.warning( + 
"Overriding library '%s' with '%s' for component '%s'.", + library, + spec.reference_library, + component.value, + ) + library = spec.reference_library + hub_id = case.server_args.model_path + component_selection = select_component_source( + hub_id, + case.server_args.extras, + component, + spec.model_index_keys, + ) + sgl_args = build_accuracy_server_args( + component_selection.base_model_id, + component_selection.base_model_root, + case, + component, + num_gpus, + component_selection.component_paths, + ) + initialize_parallel_runtime(sgl_args) + set_global_server_args(sgl_args) + + device = get_local_torch_device() + + sgl_component = _load_sglang_component( + component_selection.source_path, + sgl_args, + component, + library, + text_encoder_cpu_offload=( + False + if component != ComponentType.TEXT_ENCODER or materialize_sgl_on_device + else True + ), + ) + if materialize_sgl_on_device: + sgl_component = sgl_component.to(device=device, dtype=torch.bfloat16) + + ref_component = _load_reference_component( + component_selection.source_path, + component_selection.source_root, + component, + hub_id, + sgl_args.pipeline_config, + component_selection.source_subfolder, + ) + if materialize_ref_on_device: + ref_component = ref_component.to(device=device, dtype=torch.bfloat16) + + if component == ComponentType.TRANSFORMER and "wan" in hub_id.lower(): + fc_mod._forward_context = ForwardContext( + current_timestep=0, attn_metadata=None + ) + + ref_for_transfer = ref_component + if ( + component == ComponentType.TEXT_ENCODER + and getattr(ref_component, "shared", None) is None + ): + ref_for_transfer = resolve_text_encoder_forward_module(ref_component) + AccuracyEngine.transfer_weights( + ref_for_transfer, + sgl_component, + target_device=( + device if materialize_sgl_on_device else torch.device("cpu") + ), + ) + + if component != ComponentType.VAE: + if not hasattr(fc_mod._forward_context, "attn_metadata"): + fc_mod._forward_context = ForwardContext( + 
current_timestep=int(DEFAULT_TIMESTEP), attn_metadata=None + ) + + return sgl_component.eval(), ref_component.eval(), str(device) diff --git a/python/sglang/multimodal_gen/test/server/perf_baselines.json b/python/sglang/multimodal_gen/test/server/perf_baselines.json index 5758575e03f5..0be6968646ea 100644 --- a/python/sglang/multimodal_gen/test/server/perf_baselines.json +++ b/python/sglang/multimodal_gen/test/server/perf_baselines.json @@ -13,9 +13,9 @@ "denoise_agg": 0.1 }, "pr_test": { - "e2e": 0.15, - "denoise_stage": 0.1, - "non_denoise_stage": 0.6, + "e2e": 0.2, + "denoise_stage": 0.2, + "non_denoise_stage": 0.8, "denoise_step": 0.25, "denoise_agg": 0.15 } @@ -607,7 +607,7 @@ "InputValidationStage": 0.04, "TextEncodingStage": 428.59, "LatentPreparationStage": 0.14, - "TimestepPreparationStage": 47.26, + "TimestepPreparationStage": 422.21, "DenoisingStage": 778.56, "DecodingStage": 10.39 }, @@ -1317,60 +1317,60 @@ }, "wan2_2_i2v_a14b_2gpu": { "stages_ms": { - "InputValidationStage": 18.45, - "TextEncodingStage": 3337.77, - "TimestepPreparationStage": 2.9, - "LatentPreparationStage": 1.25, - "ImageVAEEncodingStage": 1655.89, - "DenoisingStage": 106972.82, - "DecodingStage": 1355.52, + "InputValidationStage": 27.74, + "TextEncodingStage": 1121.93, + "ImageVAEEncodingStage": 1889.26, + "LatentPreparationStage": 0.44, + "TimestepPreparationStage": 6.39, + "DenoisingStage": 137454.52, + "DecodingStage": 2287.25, "per_frame_generation": null }, "denoise_step_ms": { - "0": 1525.6, - "1": 1582.6, - "2": 1597.84, - "3": 1601.34, - "4": 1600.86, - "5": 1598.32, - "6": 1600.93, - "7": 1599.88, - "8": 1600.0, - "9": 1600.55, - "10": 1599.27, - "11": 1600.59, - "12": 1600.17, - "13": 1599.72, - "14": 1599.76, - "15": 24098.85, - "16": 1601.29, - "17": 1598.89, - "18": 1600.12, - "19": 1600.52, - "20": 1599.59, - "21": 1600.37, - "22": 1600.35, - "23": 1599.7, - "24": 1599.92, - "25": 1599.75, - "26": 1600.2, - "27": 1600.06, - "28": 1600.41, - "29": 1599.35, - "30": 
1600.69, - "31": 1600.15, - "32": 1599.33, - "33": 1599.86, - "34": 1600.52, - "35": 1599.84, - "36": 1600.38, - "37": 1599.23, - "38": 1600.27, - "39": 1599.78 - }, - "expected_e2e_ms": 123182.9887, - "expected_avg_denoise_ms": 2831.0, - "expected_median_denoise_ms": 1600.09 + "0": 2231.66, + "1": 3489.95, + "2": 3436.66, + "3": 3407.31, + "4": 3422.63, + "5": 3417.48, + "6": 3425.34, + "7": 3423.93, + "8": 3429.36, + "9": 3431.95, + "10": 3435.35, + "11": 3430.29, + "12": 3435.09, + "13": 3436.59, + "14": 3436.94, + "15": 4835.04, + "16": 3416.6, + "17": 3427.03, + "18": 3421.59, + "19": 3427.95, + "20": 3427.21, + "21": 3428.96, + "22": 3430.96, + "23": 3431.29, + "24": 3430.44, + "25": 3430.09, + "26": 3432.23, + "27": 3430.61, + "28": 3430.51, + "29": 3427.92, + "30": 3429.01, + "31": 3430.05, + "32": 3429.63, + "33": 3426.97, + "34": 3426.71, + "35": 3428.44, + "36": 3427.1, + "37": 3425.52, + "38": 3422.81, + "39": 3403.77 + }, + "expected_e2e_ms": 144621.32, + "expected_avg_denoise_ms": 3434.22, + "expected_median_denoise_ms": 3428.99 }, "turbo_wan2_2_i2v_a14b_2gpu": { "stages_ms": { @@ -1983,7 +1983,7 @@ "8": 261.6 }, "expected_e2e_ms": 3541.48, - "expected_avg_denoise_ms": 241.68, + "expected_avg_denoise_ms": 288.82, "expected_median_denoise_ms": 262.05 }, "hunyuan3d_shape_gen": { @@ -2375,220 +2375,71 @@ "expected_avg_denoise_ms": 0.0, "expected_median_denoise_ms": 0.0 }, - "helios_base_t2v": { - "stages_ms": { - "InputValidationStage": 0.04, - "TextEncodingStage": 1102.45, - "LatentPreparationStage": 0.14, - "HeliosChunkedDenoisingStage": 116964.69, - "HeliosDecodingStage": 664.76, - "per_frame_generation": null - }, - "denoise_step_ms": { - "0": 1893.3, - "1": 1900.93, - "2": 1934.08, - "3": 1897.65, - "4": 1907.59, - "5": 1909.1, - "6": 1911.51, - "7": 1909.25, - "8": 1911.69, - "9": 1911.77, - "10": 1913.35, - "11": 1915.44, - "12": 1912.11, - "13": 1910.08, - "14": 1911.77, - "15": 1908.22, - "16": 1908.83, - "17": 1910.11, - "18": 1908.19, - "19": 
1911.99, - "20": 1909.96, - "21": 1910.32, - "22": 1911.76, - "23": 1911.87, - "24": 1908.91, - "25": 1912.41, - "26": 1913.15, - "27": 1908.34, - "28": 1913.21, - "29": 1911.98, - "30": 1912.16, - "31": 1914.17, - "32": 1911.45, - "33": 1912.5, - "34": 1914.48, - "35": 1912.64, - "36": 1912.24, - "37": 1914.48, - "38": 1911.06, - "39": 1915.45, - "40": 1914.0, - "41": 1912.99, - "42": 1913.68, - "43": 1914.09, - "44": 1915.83, - "45": 1913.36, - "46": 1914.84, - "47": 1915.31, - "48": 1915.58, - "49": 1912.63 - }, - "expected_e2e_ms": 118821.41, - "expected_avg_denoise_ms": 1911.64, - "expected_median_denoise_ms": 1912.05 - }, - "helios_mid_t2v": { + "flux_2_nvfp4_t2i": { "stages_ms": { "InputValidationStage": 0.09, - "TextEncodingStage": 1102.28, - "LatentPreparationStage": 0.23, - "HeliosChunkedDenoisingStage": 77947.9, - "HeliosDecodingStage": 664.96, - "per_frame_generation": null - }, - "denoise_step_ms": { - "0": 404.46, - "1": 404.88, - "2": 405.35, - "3": 406.01, - "4": 404.97, - "5": 405.07, - "6": 405.06, - "7": 404.98, - "8": 405.39, - "9": 405.52, - "10": 405.76, - "11": 405.53, - "12": 405.16, - "13": 405.46, - "14": 405.75, - "15": 405.69, - "16": 405.26, - "17": 405.23, - "18": 405.42, - "19": 405.99, - "20": 663.39, - "21": 666.6, - "22": 665.73, - "23": 666.37, - "24": 667.43, - "25": 668.28, - "26": 667.96, - "27": 668.93, - "28": 667.78, - "29": 668.15, - "30": 668.91, - "31": 667.22, - "32": 669.31, - "33": 666.57, - "34": 669.78, - "35": 668.38, - "36": 669.95, - "37": 668.76, - "38": 667.82, - "39": 668.98, - "40": 1891.05, - "41": 1893.52, - "42": 1893.48, - "43": 1892.79, - "44": 1892.03, - "45": 1892.87, - "46": 1895.55, - "47": 1892.19, - "48": 1892.89, - "49": 1892.32, - "50": 1890.25, - "51": 1894.1, - "52": 1890.67, - "53": 1892.09, - "54": 1892.64, - "55": 1891.91, - "56": 1894.27, - "57": 1893.62, - "58": 1892.65, - "59": 1891.9 - }, - "expected_e2e_ms": 79824.32, - "expected_avg_denoise_ms": 988.6, - "expected_median_denoise_ms": 
668.05 - }, - "helios_distilled_t2v": { - "stages_ms": { - "InputValidationStage": 0.05, - "TextEncodingStage": 552.02, - "LatentPreparationStage": 0.13, - "HeliosChunkedDenoisingStage": 57879.88, - "HeliosDecodingStage": 663.31, - "per_frame_generation": null + "TextEncodingStage": 458.68, + "ImageVAEEncodingStage": 0.01, + "LatentPreparationStage": 0.54, + "TimestepPreparationStage": 20.88, + "DenoisingStage": 7189.58, + "DecodingStage": 13.55 }, "denoise_step_ms": { - "0": 207.03, - "1": 204.36, - "2": 203.87, - "3": 204.51, - "4": 206.21, - "5": 205.54, - "6": 205.06, - "7": 205.45, - "8": 205.96, - "9": 205.95, - "10": 205.22, - "11": 204.43, - "12": 205.14, - "13": 205.06, - "14": 205.11, - "15": 206.09, - "16": 205.1, - "17": 204.99, - "18": 204.55, - "19": 205.14, - "20": 337.47, - "21": 337.06, - "22": 337.68, - "23": 336.58, - "24": 335.98, - "25": 335.84, - "26": 336.01, - "27": 335.61, - "28": 335.79, - "29": 335.62, - "30": 336.69, - "31": 335.98, - "32": 336.15, - "33": 336.55, - "34": 336.98, - "35": 337.33, - "36": 336.34, - "37": 335.94, - "38": 336.69, - "39": 336.14, - "40": 954.88, - "41": 956.2, - "42": 953.9, - "43": 953.49, - "44": 957.1, - "45": 956.95, - "46": 955.02, - "47": 954.98, - "48": 956.0, - "49": 956.63, - "50": 958.66, - "51": 957.26, - "52": 956.73, - "53": 955.06, - "54": 957.04, - "55": 958.07, - "56": 958.28, - "57": 957.99, - "58": 957.61, - "59": 956.98 - }, - "expected_e2e_ms": 59168.9, - "expected_avg_denoise_ms": 499.37, - "expected_median_denoise_ms": 336.25 + "0": 94.51, + "1": 88.17, + "2": 132.17, + "3": 141.26, + "4": 142.49, + "5": 141.41, + "6": 142.62, + "7": 140.92, + "8": 141.08, + "9": 142.68, + "10": 139.88, + "11": 144.98, + "12": 144.37, + "13": 142.0, + "14": 142.49, + "15": 141.24, + "16": 141.05, + "17": 140.69, + "18": 141.48, + "19": 141.92, + "20": 146.34, + "21": 147.32, + "22": 140.68, + "23": 141.09, + "24": 142.51, + "25": 140.83, + "26": 145.73, + "27": 148.47, + "28": 144.86, + "29": 140.83, + 
"30": 144.76, + "31": 145.36, + "32": 140.58, + "33": 144.49, + "34": 142.65, + "35": 141.86, + "36": 148.19, + "37": 145.5, + "38": 145.68, + "39": 143.64, + "40": 143.7, + "41": 153.8, + "42": 148.57, + "43": 143.2, + "44": 144.15, + "45": 142.11, + "46": 146.38, + "47": 146.97, + "48": 144.62, + "49": 146.38 + }, + "expected_e2e_ms": 8091.46, + "expected_avg_denoise_ms": 141.37, + "expected_median_denoise_ms": 142.63 } } } diff --git a/python/sglang/multimodal_gen/test/server/test_accuracy_1_gpu_a.py b/python/sglang/multimodal_gen/test/server/test_accuracy_1_gpu_a.py new file mode 100644 index 000000000000..409acab1ee7f --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/test_accuracy_1_gpu_a.py @@ -0,0 +1,37 @@ +import pytest + +from sglang.multimodal_gen.test.server.accuracy_config import ( + ComponentType, + get_skip_reason, + should_skip_component, +) +from sglang.multimodal_gen.test.server.accuracy_utils import ( + run_native_component_accuracy_case, + run_text_encoder_accuracy_case, +) +from sglang.multimodal_gen.test.server.component_accuracy import AccuracyEngine +from sglang.multimodal_gen.test.server.testcase_configs import ONE_GPU_CASES_A + + +@pytest.mark.parametrize("case", ONE_GPU_CASES_A, ids=lambda x: x.id) +class TestAccuracy1GPU_A: + """1-GPU Component Accuracy Suite (Set A).""" + + def test_vae_accuracy(self, case): + if should_skip_component(case, ComponentType.VAE): + pytest.skip(get_skip_reason(case, ComponentType.VAE)) + run_native_component_accuracy_case( + AccuracyEngine, case, ComponentType.VAE, "diffusers", 1 + ) + + def test_transformer_accuracy(self, case): + if should_skip_component(case, ComponentType.TRANSFORMER): + pytest.skip(get_skip_reason(case, ComponentType.TRANSFORMER)) + run_native_component_accuracy_case( + AccuracyEngine, case, ComponentType.TRANSFORMER, "diffusers", 1 + ) + + def test_encoder_accuracy(self, case): + if should_skip_component(case, ComponentType.TEXT_ENCODER): + pytest.skip(get_skip_reason(case, 
ComponentType.TEXT_ENCODER)) + run_text_encoder_accuracy_case(AccuracyEngine, case, 1) diff --git a/python/sglang/multimodal_gen/test/server/test_accuracy_1_gpu_b.py b/python/sglang/multimodal_gen/test/server/test_accuracy_1_gpu_b.py new file mode 100644 index 000000000000..0c0c7cb51f94 --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/test_accuracy_1_gpu_b.py @@ -0,0 +1,37 @@ +import pytest + +from sglang.multimodal_gen.test.server.accuracy_config import ( + ComponentType, + get_skip_reason, + should_skip_component, +) +from sglang.multimodal_gen.test.server.accuracy_utils import ( + run_native_component_accuracy_case, + run_text_encoder_accuracy_case, +) +from sglang.multimodal_gen.test.server.component_accuracy import AccuracyEngine +from sglang.multimodal_gen.test.server.testcase_configs import ONE_GPU_CASES_B + + +@pytest.mark.parametrize("case", ONE_GPU_CASES_B, ids=lambda x: x.id) +class TestAccuracy1GPU_B: + """1-GPU Component Accuracy Suite (Set B).""" + + def test_vae_accuracy(self, case): + if should_skip_component(case, ComponentType.VAE): + pytest.skip(get_skip_reason(case, ComponentType.VAE)) + run_native_component_accuracy_case( + AccuracyEngine, case, ComponentType.VAE, "diffusers", 1 + ) + + def test_transformer_accuracy(self, case): + if should_skip_component(case, ComponentType.TRANSFORMER): + pytest.skip(get_skip_reason(case, ComponentType.TRANSFORMER)) + run_native_component_accuracy_case( + AccuracyEngine, case, ComponentType.TRANSFORMER, "diffusers", 1 + ) + + def test_encoder_accuracy(self, case): + if should_skip_component(case, ComponentType.TEXT_ENCODER): + pytest.skip(get_skip_reason(case, ComponentType.TEXT_ENCODER)) + run_text_encoder_accuracy_case(AccuracyEngine, case, 1) diff --git a/python/sglang/multimodal_gen/test/server/test_accuracy_2_gpu_a.py b/python/sglang/multimodal_gen/test/server/test_accuracy_2_gpu_a.py new file mode 100644 index 000000000000..c40929bf45ee --- /dev/null +++ 
b/python/sglang/multimodal_gen/test/server/test_accuracy_2_gpu_a.py @@ -0,0 +1,37 @@ +import pytest + +from sglang.multimodal_gen.test.server.accuracy_config import ( + ComponentType, + get_skip_reason, + should_skip_component, +) +from sglang.multimodal_gen.test.server.accuracy_utils import ( + run_native_component_accuracy_case, + run_text_encoder_accuracy_case, +) +from sglang.multimodal_gen.test.server.component_accuracy import AccuracyEngine +from sglang.multimodal_gen.test.server.testcase_configs import TWO_GPU_CASES_A + + +@pytest.mark.parametrize("case", TWO_GPU_CASES_A, ids=lambda x: x.id) +class TestAccuracy2GPU_A: + """2-GPU Component Accuracy Suite (Set A).""" + + def test_vae_accuracy(self, case): + if should_skip_component(case, ComponentType.VAE): + pytest.skip(get_skip_reason(case, ComponentType.VAE)) + run_native_component_accuracy_case( + AccuracyEngine, case, ComponentType.VAE, "diffusers", 2 + ) + + def test_transformer_accuracy(self, case): + if should_skip_component(case, ComponentType.TRANSFORMER): + pytest.skip(get_skip_reason(case, ComponentType.TRANSFORMER)) + run_native_component_accuracy_case( + AccuracyEngine, case, ComponentType.TRANSFORMER, "diffusers", 2 + ) + + def test_encoder_accuracy(self, case): + if should_skip_component(case, ComponentType.TEXT_ENCODER): + pytest.skip(get_skip_reason(case, ComponentType.TEXT_ENCODER)) + run_text_encoder_accuracy_case(AccuracyEngine, case, 2) diff --git a/python/sglang/multimodal_gen/test/server/test_accuracy_2_gpu_b.py b/python/sglang/multimodal_gen/test/server/test_accuracy_2_gpu_b.py new file mode 100644 index 000000000000..9b3cc1190473 --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/test_accuracy_2_gpu_b.py @@ -0,0 +1,37 @@ +import pytest + +from sglang.multimodal_gen.test.server.accuracy_config import ( + ComponentType, + get_skip_reason, + should_skip_component, +) +from sglang.multimodal_gen.test.server.accuracy_utils import ( + run_native_component_accuracy_case, + 
run_text_encoder_accuracy_case, +) +from sglang.multimodal_gen.test.server.component_accuracy import AccuracyEngine +from sglang.multimodal_gen.test.server.testcase_configs import TWO_GPU_CASES_B + + +@pytest.mark.parametrize("case", TWO_GPU_CASES_B, ids=lambda x: x.id) +class TestAccuracy2GPU_B: + """2-GPU Component Accuracy Suite (Set B).""" + + def test_vae_accuracy(self, case): + if should_skip_component(case, ComponentType.VAE): + pytest.skip(get_skip_reason(case, ComponentType.VAE)) + run_native_component_accuracy_case( + AccuracyEngine, case, ComponentType.VAE, "diffusers", 2 + ) + + def test_transformer_accuracy(self, case): + if should_skip_component(case, ComponentType.TRANSFORMER): + pytest.skip(get_skip_reason(case, ComponentType.TRANSFORMER)) + run_native_component_accuracy_case( + AccuracyEngine, case, ComponentType.TRANSFORMER, "diffusers", 2 + ) + + def test_encoder_accuracy(self, case): + if should_skip_component(case, ComponentType.TEXT_ENCODER): + pytest.skip(get_skip_reason(case, ComponentType.TEXT_ENCODER)) + run_text_encoder_accuracy_case(AccuracyEngine, case, 2) diff --git a/python/sglang/multimodal_gen/test/server/test_server_c.py b/python/sglang/multimodal_gen/test/server/test_server_c.py new file mode 100644 index 000000000000..b5b4ecc81fa5 --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/test_server_c.py @@ -0,0 +1,28 @@ +""" +Config-driven diffusion performance test with pytest parametrization. 
+""" + +from __future__ import annotations + +import pytest + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.test_server_common import ( # noqa: F401 + DiffusionServerBase, + diffusion_server, +) +from sglang.multimodal_gen.test.server.testcase_configs import ( + ONE_GPU_CASES_C, + DiffusionTestCase, +) + +logger = init_logger(__name__) + + +class TestDiffusionServerOneGpuB200(DiffusionServerBase): + """B200-targeted smoke tests for 1-GPU diffusion cases.""" + + @pytest.fixture(params=ONE_GPU_CASES_C, ids=lambda c: c.id) + def case(self, request) -> DiffusionTestCase: + """Provide a DiffusionTestCase for each 1-GPU B200 test.""" + return request.param diff --git a/python/sglang/multimodal_gen/test/server/test_server_utils.py b/python/sglang/multimodal_gen/test/server/test_server_utils.py index daf908c128e2..ec8340327ed7 100644 --- a/python/sglang/multimodal_gen/test/server/test_server_utils.py +++ b/python/sglang/multimodal_gen/test/server/test_server_utils.py @@ -55,6 +55,28 @@ MESH_OUTPUT_PATHS: dict[str, str] = {} +def _urlopen_with_retry(url: str, timeout: int = 30, max_retries: int = 3) -> bytes: + """Download content from a URL with retry on transient failures.""" + for attempt in range(max_retries + 1): + try: + with urlopen(url, timeout=timeout) as response: + return response.read() + except (TimeoutError, OSError) as e: + if attempt < max_retries: + wait = 2**attempt + logger.warning( + f"Download attempt {attempt + 1}/{max_retries + 1} failed " + f"for {url}: {e}. Retrying in {wait}s..." + ) + time.sleep(wait) + else: + logger.error( + f"Failed to download from {url} after " + f"{max_retries + 1} attempts: {e}" + ) + raise + + def download_image_from_url(url: str) -> Path: """Download an image from a URL to a temporary file. 
@@ -76,14 +98,10 @@ def download_image_from_url(url: str) -> Path: Path(tempfile.gettempdir()) / f"diffusion_test_image_{int(time.time())}{ext}" ) - try: - with urlopen(url, timeout=30) as response: - temp_file.write_bytes(response.read()) - logger.info(f"Downloaded image to: {temp_file}") - return temp_file - except Exception as e: - logger.error(f"Failed to download image from {url}: {e}") - raise + data = _urlopen_with_retry(url) + temp_file.write_bytes(data) + logger.info(f"Downloaded image to: {temp_file}") + return temp_file def parse_dimensions(size_string: str | None) -> tuple[int | None, int | None]: @@ -664,8 +682,7 @@ def _download_reference_mesh(url: str) -> Path: return cache_path logger.info(f"Downloading reference mesh from: {url}") - with urlopen(url, timeout=60) as resp: - cache_path.write_bytes(resp.read()) + cache_path.write_bytes(_urlopen_with_retry(url, timeout=60)) logger.info(f"Reference mesh cached at: {cache_path}") return cache_path diff --git a/python/sglang/multimodal_gen/test/server/testcase_configs.py b/python/sglang/multimodal_gen/test/server/testcase_configs.py index 8a63fb68f82b..638fc78f6fb9 100644 --- a/python/sglang/multimodal_gen/test/server/testcase_configs.py +++ b/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -798,7 +798,7 @@ def from_req_perf_record( ) ) -# TODO: enable on 4090/5090/b200 +# TODO: enable on 4090/5090 ONE_GPU_CASES_C = [ DiffusionTestCase( "flux_2_nvfp4_t2i", diff --git a/python/sglang/multimodal_gen/test/slack_utils.py b/python/sglang/multimodal_gen/test/slack_utils.py index 34ba78371d75..880417265063 100644 --- a/python/sglang/multimodal_gen/test/slack_utils.py +++ b/python/sglang/multimodal_gen/test/slack_utils.py @@ -130,7 +130,7 @@ def upload_file_to_slack( try: suffix = os.path.splitext(urlparse(path).path)[1] or ".tmp" with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf: - with urlopen(path) as response: + with urlopen(path, timeout=30) as response: 
tf.write(response.read()) temp_paths.append(tf.name) final_origin_paths.append(tf.name) @@ -155,7 +155,7 @@ def upload_file_to_slack( f"*Case ID:* `{case_id}`\n" f"*Model:* `{model}`\n" f"*Prompt:* {prompt}" ) - client = WebClient(token=token) + client = WebClient(token=token, timeout=60) channel_id = "C0A02NDF7UY" thread_ts = None diff --git a/python/sglang/multimodal_gen/test/unit/test_input_validation.py b/python/sglang/multimodal_gen/test/unit/test_input_validation.py new file mode 100644 index 000000000000..75bd30bf7dbe --- /dev/null +++ b/python/sglang/multimodal_gen/test/unit/test_input_validation.py @@ -0,0 +1,164 @@ +"""Unit tests for InputValidationStage.preprocess_condition_image resolution logic.""" + +import unittest +from unittest.mock import MagicMock, patch + +from PIL import Image + +from sglang.multimodal_gen.configs.pipeline_configs.wan import ( + WanI2V480PConfig, + WanI2V720PConfig, +) +from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.pipelines_core.stages.input_validation import ( + InputValidationStage, +) + +# Patch path for get_global_server_args used by Stage.__init__ +_GLOBAL_ARGS_PATCH = ( + "sglang.multimodal_gen.runtime.pipelines_core.stages.base.get_global_server_args" +) + + +def _make_batch(condition_image: Image.Image, width=None, height=None) -> Req: + """Create a minimal Req with a condition image and optional user dimensions.""" + sp = SamplingParams( + seed=42, + num_outputs_per_prompt=1, + width=width, + height=height, + ) + batch = Req(sampling_params=sp, condition_image=condition_image) + return batch + + +def _make_server_args(pipeline_config): + """Create a mock ServerArgs with the given pipeline config.""" + sa = MagicMock() + sa.pipeline_config = pipeline_config + return sa + + +class TestCalculateDimensionsFromArea(unittest.TestCase): + """Tests for 
InputValidationStage._calculate_dimensions_from_area.""" + + def test_square_aspect_ratio(self): + # area=921600, aspect=1.0, mod=16 → sqrt(921600)=~960 + w, h = InputValidationStage._calculate_dimensions_from_area(921600, 1.0, 16) + self.assertEqual(w % 16, 0) + self.assertEqual(h % 16, 0) + self.assertEqual((w, h), (960, 960)) + + def test_16_9_aspect_ratio(self): + # aspect = 720/1280 = 0.5625 + w, h = InputValidationStage._calculate_dimensions_from_area(921600, 9 / 16, 16) + self.assertEqual(w % 16, 0) + self.assertEqual(h % 16, 0) + self.assertEqual((w, h), (1280, 720)) + + def test_9_16_aspect_ratio(self): + w, h = InputValidationStage._calculate_dimensions_from_area(921600, 16 / 9, 16) + self.assertEqual(w % 16, 0) + self.assertEqual(h % 16, 0) + self.assertEqual((w, h), (720, 1280)) + + def test_mod_alignment(self): + # Ensure dimensions are always multiples of mod_value + w, h = InputValidationStage._calculate_dimensions_from_area(500000, 1.3, 16) + self.assertEqual(w % 16, 0) + self.assertEqual(h % 16, 0) + + +class TestPreprocessConditionImageResolution(unittest.TestCase): + """Tests for the WanI2V480PConfig branch of preprocess_condition_image. 
+ + Verifies that: + - Aspect ratio always comes from the condition image + - User-specified width/height controls target area (scale) + - Output is clamped to max_area when user dimensions exceed it + - Dimensions are always mod-aligned + """ + + def setUp(self): + with patch(_GLOBAL_ARGS_PATCH, return_value=MagicMock()): + self.stage = InputValidationStage() + + def _run(self, config, img_w, img_h, user_w=None, user_h=None): + """Run preprocess_condition_image and return (batch.width, batch.height).""" + img = Image.new("RGB", (img_w, img_h), color="red") + batch = _make_batch(img, width=user_w, height=user_h) + server_args = _make_server_args(config) + self.stage.preprocess_condition_image(batch, server_args, img_w, img_h) + return batch.width, batch.height + + def test_720p_no_user_dims_16_9_image(self): + """16:9 image, no user dims → 1280Ɨ720.""" + w, h = self._run(WanI2V720PConfig(), 1920, 1080) + self.assertEqual((w, h), (1280, 720)) + + def test_720p_no_user_dims_9_16_image(self): + """9:16 image, no user dims → 720Ɨ1280.""" + w, h = self._run(WanI2V720PConfig(), 1080, 1920) + self.assertEqual((w, h), (720, 1280)) + + def test_720p_no_user_dims_square_image(self): + """Square image, no user dims → ~960Ɨ960 (max_area=921600, sqrt≈960).""" + w, h = self._run(WanI2V720PConfig(), 1024, 1024) + self.assertEqual((w, h), (960, 960)) + self.assertEqual(w % 16, 0) + + def test_720p_user_dims_equal_max_area_16_9_image(self): + """16:9 image + user 1280Ɨ720 (=max_area) → 1280Ɨ720.""" + w, h = self._run(WanI2V720PConfig(), 1920, 1080, 1280, 720) + self.assertEqual((w, h), (1280, 720)) + + def test_720p_user_dims_equal_max_area_square_image(self): + """Square image + user 1280Ɨ720 → still square (~960Ɨ960) because + aspect ratio comes from image, not from user dimensions.""" + w, h = self._run(WanI2V720PConfig(), 1024, 1024, 1280, 720) + self.assertEqual((w, h), (960, 960)) + + def test_720p_user_dims_smaller_area(self): + """Square image + user 832Ɨ480 → smaller 
square (target_area=399360).""" + w, h = self._run(WanI2V720PConfig(), 1024, 1024, 832, 480) + self.assertEqual((w, h), (624, 624)) + self.assertEqual(w % 16, 0) + + def test_720p_user_dims_exceed_max_area(self): + """4K request clamped to max_area.""" + w, h = self._run(WanI2V720PConfig(), 1920, 1080, 3840, 2160) + self.assertEqual(w % 16, 0) + self.assertEqual(h % 16, 0) + self.assertEqual((w, h), (1280, 720)) + + def test_480p_no_user_dims_16_9_image(self): + """480p config, 16:9 image → area-based calc from max_area=399360.""" + w, h = self._run(WanI2V480PConfig(), 1920, 1080) + # max_area=480*832=399360, aspect=9/16 → (832, 464) due to rounding + self.assertEqual(w % 16, 0) + self.assertEqual(h % 16, 0) + self.assertEqual((w, h), (832, 464)) + + def test_condition_image_resized_to_output_dims(self): + """Condition image is resized to match output dimensions.""" + img = Image.new("RGB", (1920, 1080), color="blue") + batch = _make_batch(img) + server_args = _make_server_args(WanI2V720PConfig()) + self.stage.preprocess_condition_image(batch, server_args, 1920, 1080) + self.assertEqual(batch.condition_image.size, (batch.width, batch.height)) + + def test_list_condition_image_takes_first(self): + """List of condition images → uses first one.""" + img1 = Image.new("RGB", (1920, 1080), color="red") + img2 = Image.new("RGB", (800, 600), color="green") + batch = _make_batch(img1) + batch.condition_image = [img1, img2] + server_args = _make_server_args(WanI2V720PConfig()) + self.stage.preprocess_condition_image(batch, server_args, 1920, 1080) + self.assertIsInstance(batch.condition_image, Image.Image) + self.assertEqual((batch.width, batch.height), (1280, 720)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/sglang/multimodal_gen/test/unit/test_resolve_prompts.py b/python/sglang/multimodal_gen/test/unit/test_resolve_prompts.py new file mode 100644 index 000000000000..73e5cdb9eb2e --- /dev/null +++ 
b/python/sglang/multimodal_gen/test/unit/test_resolve_prompts.py @@ -0,0 +1,99 @@ +import os +import tempfile +import unittest +from types import SimpleNamespace + +from sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator + + +def _make_generator(prompt_file_path=None): + """Return a DiffGenerator-shaped object with only server_args populated.""" + obj = object.__new__(DiffGenerator) + obj.server_args = SimpleNamespace(prompt_file_path=prompt_file_path) + return obj + + +class TestResolvePrompts(unittest.TestCase): + # ---- inline prompt ---- + def test_none_prompt_returns_space(self): + gen = _make_generator() + self.assertEqual(gen._resolve_prompts(None), [" "]) + + def test_string_prompt(self): + gen = _make_generator() + self.assertEqual(gen._resolve_prompts("hello"), ["hello"]) + + def test_list_prompt(self): + gen = _make_generator() + self.assertEqual(gen._resolve_prompts(["a", "b"]), ["a", "b"]) + + # ---- prompt_path (SamplingParams) ---- + def test_prompt_path_single_line(self): + gen = _make_generator() + with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f: + f.write("sunset over the ocean\n") + path = f.name + try: + result = gen._resolve_prompts(None, prompt_path=path) + self.assertEqual(result, ["sunset over the ocean"]) + finally: + os.unlink(path) + + def test_prompt_path_multi_line(self): + gen = _make_generator() + with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f: + f.write("line one\n\nline two\n") + path = f.name + try: + result = gen._resolve_prompts(None, prompt_path=path) + self.assertEqual(result, ["line one", "line two"]) + finally: + os.unlink(path) + + def test_prompt_path_takes_priority_over_server_args(self): + with tempfile.NamedTemporaryFile( + "w", suffix=".txt", delete=False + ) as f1, tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f2: + f1.write("from prompt_path\n") + f2.write("from server_args\n") + path1, path2 = f1.name, f2.name + try: + 
gen = _make_generator(prompt_file_path=path2) + result = gen._resolve_prompts(None, prompt_path=path1) + self.assertEqual(result, ["from prompt_path"]) + finally: + os.unlink(path1) + os.unlink(path2) + + # ---- prompt_file_path (ServerArgs) ---- + def test_server_args_prompt_file_path(self): + with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f: + f.write("from server args\n") + path = f.name + try: + gen = _make_generator(prompt_file_path=path) + result = gen._resolve_prompts(None) + self.assertEqual(result, ["from server args"]) + finally: + os.unlink(path) + + # ---- error cases ---- + def test_missing_file_raises(self): + gen = _make_generator() + with self.assertRaises(FileNotFoundError): + gen._resolve_prompts(None, prompt_path="/nonexistent/file.txt") + + def test_empty_file_raises(self): + with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f: + f.write(" \n\n \n") + path = f.name + try: + gen = _make_generator() + with self.assertRaises(ValueError): + gen._resolve_prompts(None, prompt_path=path) + finally: + os.unlink(path) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py index 722d4d12ab26..4ac66ec175e1 100644 --- a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py +++ b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py @@ -7,7 +7,17 @@ ) from sglang.multimodal_gen.configs.sample.flux import FluxSamplingParams from sglang.multimodal_gen.configs.sample.qwenimage import QwenImageSamplingParams -from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams +from sglang.multimodal_gen.configs.sample.sampling_params import ( + SamplingParams, + _json_safe, +) +from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams +from sglang.multimodal_gen.configs.sample.wan import ( + WanI2V_14B_480P_SamplingParam, + 
WanI2V_14B_720P_SamplingParam, + WanT2V_1_3B_SamplingParams, + WanT2V_14B_SamplingParams, +) class TestSamplingParamsValidate(unittest.TestCase): @@ -68,6 +78,59 @@ def test_diffusers_generic_calls_base_post_init(self): with self.assertRaises(AssertionError): DiffusersGenericSamplingParams(num_frames=0) + def test_output_file_name_supports_callable_teacache_params(self): + def coefficients_callback(_: TeaCacheParams) -> list[float]: + return [1.0, 2.0, 3.0, 4.0, 5.0] + + params = SamplingParams( + prompt="callable teacache", + teacache_params=TeaCacheParams( + coefficients_callback=coefficients_callback, + ), + ) + + params._set_output_file_name() + + self.assertTrue(params.output_file_name.endswith(".mp4")) + self.assertIn( + "test_sampling_params.TestSamplingParamsSubclass.test_output_file_name_supports_callable_teacache_params", + _json_safe(coefficients_callback), + ) + + def test_teacache_callback_takes_precedence_over_static_coefficients(self): + def coefficients_callback(_: TeaCacheParams) -> list[float]: + return [9.0, 8.0, 7.0, 6.0, 5.0] + + params = TeaCacheParams( + coefficients=[1.0, 2.0, 3.0, 4.0, 5.0], + coefficients_callback=coefficients_callback, + ) + + self.assertEqual(params.get_coefficients(), [9.0, 8.0, 7.0, 6.0, 5.0]) + + def test_wan_teacache_boundaries_match_legacy_behavior(self): + legacy_equivalent_cases = [ + (WanT2V_1_3B_SamplingParams().teacache_params, False, (5, 50)), + (WanT2V_1_3B_SamplingParams().teacache_params, True, (10, 100)), + (WanT2V_14B_SamplingParams().teacache_params, False, (1, 49)), + (WanT2V_14B_SamplingParams().teacache_params, True, (2, 98)), + (WanI2V_14B_480P_SamplingParam().teacache_params, False, (5, 50)), + (WanI2V_14B_480P_SamplingParam().teacache_params, True, (10, 100)), + (WanI2V_14B_720P_SamplingParam().teacache_params, False, (5, 50)), + (WanI2V_14B_720P_SamplingParam().teacache_params, True, (10, 100)), + ] + + for teacache_params, do_cfg, expected in legacy_equivalent_cases: + with self.subTest( + 
use_ret_steps=teacache_params.use_ret_steps, + do_cfg=do_cfg, + expected=expected, + ): + self.assertEqual( + teacache_params.get_skip_boundaries(50, do_cfg), + expected, + ) + class TestSamplingParamsCliArgs(unittest.TestCase): def _parse_cli_kwargs(self, argv: list[str]) -> dict: diff --git a/python/sglang/profiler.py b/python/sglang/profiler.py index ebc7a100e24b..8424e7f54bbe 100644 --- a/python/sglang/profiler.py +++ b/python/sglang/profiler.py @@ -42,7 +42,7 @@ def run_profile( # Dump server args. file_path = Path(output_dir) / "server_args.json" if not file_path.exists(): - response = requests.get(url + "/get_server_info") + response = requests.get(url + "/server_info") response.raise_for_status() server_args_data = response.json() with open(file_path, "w") as file: diff --git a/python/sglang/srt/compilation/piecewise_context_manager.py b/python/sglang/srt/compilation/piecewise_context_manager.py index 20a08a9972b9..a49e9ad47a37 100644 --- a/python/sglang/srt/compilation/piecewise_context_manager.py +++ b/python/sglang/srt/compilation/piecewise_context_manager.py @@ -71,6 +71,7 @@ def __init__(self): self.quant_config = None self.moe_layers = None self.moe_fusions = None + self.num_tokens: Optional[int] = None def set_forward_batch(self, forward_batch: ForwardBatch): self.forward_batch = forward_batch @@ -104,6 +105,7 @@ def set_forward_context( quant_config: Any, moe_layers: List[Any], moe_fusions: List[Any], + num_tokens: Optional[int] = None, ): global _forward_context _forward_context = ForwardContext() @@ -112,6 +114,7 @@ def set_forward_context( _forward_context.set_quant_config(quant_config) _forward_context.set_moe_layers(moe_layers) _forward_context.set_moe_fusions(moe_fusions) + _forward_context.num_tokens = num_tokens try: yield finally: diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 443ba643d083..2a7a4ea2e33d 100644 --- a/python/sglang/srt/configs/load_config.py +++ 
b/python/sglang/srt/configs/load_config.py @@ -31,6 +31,7 @@ class LoadFormat(str, enum.Enum): LOCAL_CACHED = "local_cached" FASTSAFETENSORS = "fastsafetensors" PRIVATE = "private" + RUNAI_STREAMER = "runai_streamer" @dataclass diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a57cc863f6db..a7f66c8443f5 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -34,6 +34,7 @@ get_hf_text_config, get_sparse_attention_config, ) +from sglang.srt.utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from sglang.utils import is_in_ci logger = logging.getLogger(__name__) @@ -129,6 +130,7 @@ def __init__( self._validate_quantize_and_serve_config() # Get hf config + self._maybe_pull_model_for_runai(self.model_path) self._maybe_pull_model_tokenizer_from_remote() self.model_override_args = json.loads(model_override_args) kwargs = {} @@ -156,7 +158,10 @@ def __init__( "Llama4ForConditionalGeneration", "Step3VLForConditionalGeneration", ] - if self.hf_config.architectures[0] in mm_disabled_models: + if ( + self.hf_config.architectures[0] in mm_disabled_models + and self.model_impl != ModelImpl.TRANSFORMERS + ): enable_multimodal = False logger.info( f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal." 
@@ -175,14 +180,14 @@ def __init__( self.is_generation = is_generation_model( self.hf_config.architectures, is_embedding ) - self.is_multimodal = enable_multimodal and is_multimodal_model( - self.hf_config.architectures + has_multimodal_subconfig = ( + self.hf_config is not self.hf_text_config + or hasattr(self.hf_config, "vision_config") + or hasattr(self.hf_config, "audio_config") ) - self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model( - self.hf_config.architectures - ) - self.is_image_gen = enable_multimodal and is_image_gen_model( - self.hf_config.architectures + self.is_multimodal = enable_multimodal and ( + is_multimodal_model(self.hf_config.architectures) + or has_multimodal_subconfig ) self.is_audio_model = enable_multimodal and is_audio_model( self.hf_config.architectures @@ -497,6 +502,11 @@ def _derive_model_shapes(self): self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim self.v_head_dim = self.hf_config.v_head_dim self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim + self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) + if self.hf_config.rope_scaling: + self.scaling = compute_mla_mscale_scaling( + self.hf_config.rope_scaling, self.scaling + ) elif ( "BailingMoeV2_5ForCausalLM" in self.hf_config.architectures or "BailingMoeForCausalLMNextN" in self.hf_config.architectures @@ -779,6 +789,8 @@ def _parse_quant_hf_config(self): return quant_cfg def _find_quant_modelslim_config(self): + if self.is_draft_model: + return None quant_config_file = Path(self.model_path, "quant_model_description.json") quant_cfg = None if quant_config_file.is_file(): @@ -1147,6 +1159,13 @@ def get_default_sampling_params(self) -> dict[str, Any]: return default_sampling_params + def _maybe_pull_model_for_runai(self, model: str) -> None: + if is_runai_obj_uri(model): + # local path for loading the config + self.model_path = ObjectStorageModel.get_path(model) + # remote path for loading the weights + self.model_weights = model + 
def _maybe_pull_model_tokenizer_from_remote(self) -> None: """ Pull the model config files to a temporary @@ -1351,14 +1370,6 @@ def is_multimodal_model(model_architectures: List[str]): return False -def is_multimodal_gen_model(model_architectures: List[str]): - return False - - -def is_image_gen_model(model_architectures: List[str]): - return False - - def is_audio_model(model_architectures: List[str]): models = [ "WhisperForConditionalGeneration", diff --git a/python/sglang/srt/constants.py b/python/sglang/srt/constants.py index c9da6b6bb1d5..76d3b1cd3ad5 100644 --- a/python/sglang/srt/constants.py +++ b/python/sglang/srt/constants.py @@ -8,3 +8,5 @@ GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_CUDA_GRAPH, ] + +HEALTH_CHECK_RID_PREFIX = "HEALTH_CHECK" diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index 2309e8a83c3c..f7d4092d85dd 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -122,13 +122,23 @@ def __init__( @abstractmethod def init( + self, + prefill_dp_rank: int, + ): + """ + Resolve bootstrap metadata and mark the receiver ready for transfer metadata. + """ + ... + + @abstractmethod + def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): """ - Set req's index metadata locally or notify the prefill server about the kv indices, aux index, and state_indices. + Notify the prefill server about the kv indices, aux index, and state_indices. """ ... diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 399acde06d3b..26752d52dd54 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -148,6 +148,7 @@ def __init__( # These timeout requests should be aborted to release the tree cache. 
self.bootstrap_timeout = envs.SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT.get() elif self.disaggregation_mode == DisaggregationMode.DECODE: + self.enable_staging: bool = False self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.connection_lock = threading.Lock() self.required_prefill_response_num_table: Dict[int, int] = {} @@ -216,11 +217,19 @@ def try_ensure_parallel_info(self, bootstrap_addr: str) -> bool: # Sanity checks if info.page_size is not None and info.page_size != self.kv_args.page_size: - raise RuntimeError( - f"Page size mismatch: prefill server has page_size={info.page_size}, " - f"but decode server has page_size={self.kv_args.page_size}. " - f"Both servers must use the same --page-size value." - ) + if self.server_args.enable_hisparse: + # HiSparse: decode host pool page_size=1, prefill device pool page_size >= 1. + # Transfer will use send_kvcache_hisparse with per-token item_lens. + logger.info( + f"HiSparse PD transfer mode: prefill page_size={info.page_size}, " + f"decode host page_size={self.kv_args.page_size}" + ) + else: + raise RuntimeError( + f"Page size mismatch: prefill server has page_size={info.page_size}, " + f"but decode server has page_size={self.kv_args.page_size}. " + f"Both servers must use the same --page-size value." 
+ ) if ( info.kv_cache_dtype is not None @@ -326,7 +335,6 @@ def register_to_bootstrap(self): host = self.bootstrap_host bootstrap_na = NetworkAddress(host, self.bootstrap_port) - bootstrap_server_url = bootstrap_na.to_host_port_str() url = f"{bootstrap_na.to_url()}/route" payload = { "attn_tp_size": self.attn_tp_size, @@ -477,6 +485,14 @@ def poll(self) -> KVPoll: def failure_exception(self): raise Exception("Fake KVReceiver Exception") + def abort(self): + self.kv_mgr.record_failure( + self.bootstrap_room, + "Aborted by AbortReq.", + ) + # Explicitly set the status to failure since this request has been aborted + self.conclude_state = KVPoll.Failed + class CommonKVReceiver(BaseKVReceiver): _ctx = zmq.Context() @@ -489,20 +505,23 @@ def __init__( mgr: CommonKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr self.kv_mgr = mgr + self.conclude_state: Optional[KVPoll] = None + self.require_staging: bool = False + self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room) self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + def init(self, prefill_dp_rank: int): if self.bootstrap_addr not in self.kv_mgr.prefill_info_table: self.kv_mgr.record_failure( self.bootstrap_room, f"Prefill server with bootstrap_addr: {self.bootstrap_addr} is healthy before, but now it is down. 
Request (bootstrap_room: {self.bootstrap_room}) has been marked as failed.", ) + self.conclude_state = KVPoll.Failed self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) - self.bootstrap_infos = None return # Read pre-computed rank mapping from prefill_info (computed in try_ensure_parallel_info) @@ -520,11 +539,15 @@ def __init__( self.required_prefill_response_num ) - assert ( - prefill_dp_rank is not None - ), "prefill_dp_rank must be resolved before creating receiver" + if self.kv_mgr.enable_staging: + self.require_staging = ( + self.prefill_info.attn_tp_size != 0 + and self.prefill_info.attn_tp_size != self.kv_mgr.attn_tp_size + ) + self.prefill_dp_rank = prefill_dp_rank self._setup_bootstrap_infos() + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput) def _setup_bootstrap_infos(self): all_bootstrap_infos = [] @@ -562,6 +585,7 @@ def _setup_bootstrap_infos(self): self.bootstrap_room, f"Could not fetch bootstrap info for: prefill_dp_rank: {self.prefill_dp_rank} prefill_cp_rank: {target_cp_rank} target_tp_rank: {target_tp_rank} and target_pp_rank {target_pp_rank}", ) + self.conclude_state = KVPoll.Failed self.kv_mgr.update_status( self.bootstrap_room, KVPoll.Failed ) @@ -645,9 +669,25 @@ def _connect_to_bootstrap_server(cls, bootstrap_info: dict): def _register_kv_args(self): pass + def send_metadata( + self, + kv_indices: npt.NDArray[np.int32], + aux_index: Optional[int] = None, + state_indices: Optional[List[int]] = None, + ): + raise NotImplementedError + def failure_exception(self): raise Exception("Fake KVReceiver Exception") + def abort(self): + self.kv_mgr.record_failure( + self.bootstrap_room, + "Aborted by AbortReq.", + ) + # Explicitly set the status to failure since this request has been aborted + self.conclude_state = KVPoll.Failed + class CommonKVBootstrapServer(BaseKVBootstrapServer): def __init__(self, host: str, port: int): diff --git a/python/sglang/srt/disaggregation/common/staging_buffer.py 
b/python/sglang/srt/disaggregation/common/staging_buffer.py new file mode 100644 index 000000000000..4380c3ce737d --- /dev/null +++ b/python/sglang/srt/disaggregation/common/staging_buffer.py @@ -0,0 +1,768 @@ +""" +GPU Staging Buffer for heterogeneous TP KV cache transfer. + +When prefill attn_tp_size != decode attn_tp_size, the per-token RDMA approach +generates O(tokens * layers) small RDMA requests. This module provides a staging +buffer mechanism that gathers scattered head slices into contiguous GPU memory, +enabling bulk RDMA transfers that reduce request count to O(layers) or O(1). + +Usage: + Activated by setting SGLANG_DISAGG_STAGING_BUFFER=1. +""" + +from __future__ import annotations + +import logging +import os +import threading +from typing import List, Optional, Tuple + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + +# TODO(yangminl): remove torch fallback implementations once the Triton kernels +# have been validated in production across all configurations. 
+_USE_TRITON_STAGING = not bool(os.environ.get("SGLANG_STAGING_USE_TORCH", "")) + + +@triton.jit +def _fused_gather_to_staging_kernel( + layer_ptrs, + page_indices, + staging, + num_tokens, + stride_pool_token, + head_offset, + per_layer_elems, + ELEMS_PER_TOKEN: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + layer_id = tl.program_id(0) + block_id = tl.program_id(1) + + layer_ptr = tl.load(layer_ptrs + layer_id).to(staging.dtype) + + offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < per_layer_elems + + t_idx = offsets // ELEMS_PER_TOKEN + e_idx = offsets % ELEMS_PER_TOKEN + + page_id = t_idx // PAGE_SIZE + intra_page = t_idx % PAGE_SIZE + page_val = tl.load(page_indices + page_id, mask=mask, other=0) + pool_token = page_val * PAGE_SIZE + intra_page + + src_offsets = ( + pool_token * stride_pool_token.to(tl.int64) + head_offset.to(tl.int64) + e_idx + ) + vals = tl.load(layer_ptr + src_offsets, mask=mask) + + dst_offsets = tl.program_id(0).to(tl.int64) * per_layer_elems.to(tl.int64) + offsets + tl.store(staging + dst_offsets, vals, mask=mask) + + +@triton.jit +def _fused_scatter_from_staging_kernel( + layer_ptrs, + page_indices, + staging, + writer_head_offsets, + num_tokens, + stride_pool_token, + per_layer_elems, + ELEMS_PER_TOKEN: tl.constexpr, + PAGE_SIZE: tl.constexpr, + NUM_LAYERS_X2: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + prog_id = tl.program_id(0) + block_id = tl.program_id(1) + + writer_id = prog_id // NUM_LAYERS_X2 + layer_kv_id = prog_id % NUM_LAYERS_X2 + + layer_ptr = tl.load(layer_ptrs + layer_kv_id).to(staging.dtype) + head_offset = tl.load(writer_head_offsets + writer_id) + + offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < per_layer_elems + + t_idx = offsets // ELEMS_PER_TOKEN + e_idx = offsets % ELEMS_PER_TOKEN + + page_id = t_idx // PAGE_SIZE + intra_page = t_idx % PAGE_SIZE + page_val = tl.load(page_indices + page_id, mask=mask, other=0) + pool_token = 
class StagingBuffer:
    """Pre-allocated GPU staging buffer for bulk KV transfer.

    When a custom_mem_pool is provided (e.g., mooncake NVLink allocator),
    the buffer is allocated within that pool so it's compatible with
    NVLink/MNNVL transport (requires cuMemCreate-backed memory).
    """

    def __init__(
        self,
        size_bytes: int,
        device: str,
        gpu_id: int,
        custom_mem_pool=None,
    ):
        self.size_bytes = size_bytes
        self.device = device
        self.gpu_id = gpu_id

        torch.cuda.set_device(gpu_id)
        if custom_mem_pool is not None:
            # Allocate inside the cuMemCreate-backed pool so the region is
            # exportable over NVLink/MNNVL transports.
            with torch.cuda.use_mem_pool(custom_mem_pool):
                self.buffer = torch.empty(size_bytes, dtype=torch.uint8, device=device)
            alloc_method = "custom_mem_pool (cuMemCreate)"
        else:
            self.buffer = torch.empty(size_bytes, dtype=torch.uint8, device=device)
            alloc_method = "cudaMalloc (NVLink incompatible!)"
        self.data_ptr = self.buffer.data_ptr()

        logger.info(
            f"StagingBuffer allocated: {size_bytes / (1024*1024):.1f} MB "
            f"on {device}, method={alloc_method}, ptr=0x{self.data_ptr:x}"
        )

    def get_ptr(self) -> int:
        return self.data_ptr

    def get_size(self) -> int:
        return self.size_bytes

    def fits(self, required_bytes: int) -> bool:
        return required_bytes <= self.size_bytes


class StagingAllocator:
    """Decode-side dynamic staging ring buffer allocator with overcommit.

    One large pre-allocated GPU buffer used as a ring buffer. Each request
    gets a (alloc_id, offset, round) triple based on its actual byte
    requirement. Allocation (assign) is overcommit — it always succeeds
    as long as the request fits in the buffer. Overlap safety is enforced
    on the prefill side before RDMA, using a watermark that tracks the
    oldest un-freed allocation.

    The watermark (round, tail_offset) is periodically sent to prefill.
    Prefill transfer workers wait before writing if their target region
    overlaps with not-yet-freed data from a previous round.
    """

    # Permanent alloc failure: chunk exceeds ring buffer total size.
    ALLOC_OVERSIZED = -2

    def __init__(
        self,
        total_size_bytes: int,
        device: str,
        gpu_id: int,
        custom_mem_pool=None,
    ):
        self.buffer = StagingBuffer(total_size_bytes, device, gpu_id, custom_mem_pool)
        self.total_size = total_size_bytes
        self.base_ptr = self.buffer.data_ptr
        # Ring state: next free offset and current wrap-around generation.
        self.head = 0
        self.round = 0
        self.allocations: dict = {}  # alloc_id -> (offset, size, round)
        self.alloc_order: List[int] = []  # alloc_ids in allocation order
        self.next_alloc_id = 0
        self.watermark_round = 0
        self.watermark_tail = 0
        self.lock = threading.Lock()

        logger.info(
            f"StagingAllocator (ring+overcommit): "
            f"{total_size_bytes / (1024*1024):.1f} MB "
            f"on {device}, ptr=0x{self.base_ptr:x}"
        )

    def assign(self, required_bytes: int) -> Optional[Tuple[int, int, int]]:
        """Allocate a region. Returns (alloc_id, offset, round) or None."""
        with self.lock:
            # Only a request larger than the whole ring can fail.
            if required_bytes > self.total_size:
                return None

            remaining = self.total_size - self.head
            if required_bytes <= remaining:
                offset = self.head
                self.head += required_bytes
            else:
                # Wrap to the start of the ring and bump the generation.
                self.round += 1
                offset = 0
                self.head = required_bytes

            alloc_id = self.next_alloc_id
            self.next_alloc_id += 1
            self.allocations[alloc_id] = (offset, required_bytes, self.round)
            self.alloc_order.append(alloc_id)
            return (alloc_id, offset, self.round)

    def free(self, alloc_id: int):
        """Free an allocation and advance watermark past consecutive freed entries."""
        with self.lock:
            if alloc_id not in self.allocations:
                return
            self.allocations.pop(alloc_id)

            # Drop leading entries that were already freed so alloc_order[0]
            # is always the oldest live allocation.
            while self.alloc_order and self.alloc_order[0] not in self.allocations:
                self.alloc_order.pop(0)

            if not self.allocations:
                # Everything freed: the whole ring up to head is reusable.
                self.watermark_round = self.round
                self.watermark_tail = self.head
            elif self.alloc_order:
                off, _, rnd = self.allocations[self.alloc_order[0]]
                self.watermark_round = rnd
                self.watermark_tail = off

    def get_watermark(self) -> Tuple[int, int]:
        """Return (round, tail_offset). Everything before this is safe to write."""
        with self.lock:
            return (self.watermark_round, self.watermark_tail)

    def get_ptr(self, alloc_id: int) -> int:
        offset, _, _ = self.allocations[alloc_id]
        return self.base_ptr + offset

    def get_offset(self, alloc_id: int) -> int:
        offset, _, _ = self.allocations[alloc_id]
        return offset

    def get_round(self, alloc_id: int) -> int:
        _, _, rnd = self.allocations[alloc_id]
        return rnd

    def get_base_ptr(self) -> int:
        return self.base_ptr

    def get_total_size(self) -> int:
        return self.total_size


def gather_kv_head_slices(
    kv_buffer_tensor: torch.Tensor,
    gather_idx: torch.Tensor,
    head_start: int,
    num_heads: int,
    staging_tensor: torch.Tensor,
):
    """Gather KV head slices from scattered pages into contiguous staging buffer.

    Uses torch.gather(out=) to write directly into staging_tensor without
    allocating temporary tensors (avoids CUDA caching allocator stalls).

    Args:
        kv_buffer_tensor: [pool_size, head_num, head_dim], one layer.
        gather_idx: [num_tokens, num_heads, head_dim] int64, pre-computed
            token indices expanded for gather on dim=0.
        head_start: Starting head index for the slice.
        num_heads: Number of heads to gather.
        staging_tensor: Output tensor, shape [num_tokens, num_heads, head_dim].
    """
    src = kv_buffer_tensor[:, head_start : head_start + num_heads, :]
    torch.gather(src, 0, gather_idx, out=staging_tensor)
def scatter_kv_head_slices(
    staging_tensor: torch.Tensor,
    kv_buffer_tensor: torch.Tensor,
    page_indices: torch.Tensor,
    head_start: int,
    num_heads: int,
    page_size: int = 1,
):
    """Scatter KV head slices from contiguous staging buffer to KV cache.

    Args:
        staging_tensor: Input tensor from staging buffer (contiguous packed data).
        kv_buffer_tensor: The KV buffer for one layer, shape [pool_size, head_num, head_dim].
        page_indices: [num_pages] int32/int64 tensor of page indices.
        head_start: Starting head index for the slice.
        num_heads: Number of heads to scatter.
        page_size: Number of tokens per page.
    """
    head_dim = kv_buffer_tensor.shape[-1]
    if page_size == 1:
        # Pages are tokens: index the pool rows directly.
        num_tokens = page_indices.shape[0]
        payload = staging_tensor.reshape(num_tokens, num_heads, head_dim)
        kv_buffer_tensor[page_indices, head_start : head_start + num_heads, :] = payload
    else:
        # Expand each page index into its page_size consecutive token slots.
        num_tokens = page_indices.shape[0] * page_size
        intra = torch.arange(page_size, device=page_indices.device)
        token_indices = (page_indices.unsqueeze(1) * page_size + intra).reshape(-1)
        payload = staging_tensor.reshape(num_tokens, num_heads, head_dim)
        kv_buffer_tensor[token_indices, head_start : head_start + num_heads, :] = payload


def _gather_all_layers_torch(
    k_buffers: list,
    v_buffers: list,
    page_indices_np,
    staging_buffer: StagingBuffer,
    src_head_start: int,
    num_heads: int,
    page_size: int,
    gpu_id: int,
) -> int:
    """torch.gather path: zero per-layer allocation, one kernel per layer."""
    import numpy as np

    num_layers = len(k_buffers)
    head_dim = k_buffers[0].shape[-1]
    dtype_size = k_buffers[0].element_size()
    num_tokens = len(page_indices_np) * page_size
    per_layer_bytes = num_tokens * num_heads * head_dim * dtype_size

    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(gpu_id)
    page_idx_tensor = torch.from_numpy(page_indices_np.astype(np.int64)).to(device)

    if page_size == 1:
        token_indices = page_idx_tensor
    else:
        intra = torch.arange(page_size, device=device)
        token_indices = (page_idx_tensor.unsqueeze(1) * page_size + intra).reshape(-1)

    # Shared index tensor for gather over dim 0 of every layer buffer.
    gather_idx = token_indices.view(-1, 1, 1).expand(num_tokens, num_heads, head_dim)

    # Lazily create a dedicated stream so gathers overlap with the default
    # stream, then order it after outstanding default-stream work.
    if not hasattr(staging_buffer, "_gather_stream"):
        staging_buffer._gather_stream = torch.cuda.Stream(device=device)

    staging_buffer._gather_stream.wait_stream(
        torch.cuda.default_stream(torch.device(device))
    )

    staging_view = staging_buffer.buffer
    offset = 0
    with torch.cuda.stream(staging_buffer._gather_stream):
        # Pack all K layers first, then all V layers, back-to-back.
        for buffers in (k_buffers, v_buffers):
            for layer_id in range(num_layers):
                dst = (
                    staging_view[offset : offset + per_layer_bytes]
                    .view(buffers[layer_id].dtype)
                    .reshape(num_tokens, num_heads, head_dim)
                )
                gather_kv_head_slices(
                    buffers[layer_id],
                    gather_idx,
                    src_head_start,
                    num_heads,
                    dst,
                )
                offset += per_layer_bytes

    staging_buffer._gather_stream.synchronize()
    return offset


def _gather_all_layers_triton(
    k_buffers: list,
    v_buffers: list,
    page_indices_np,
    staging_buffer: StagingBuffer,
    src_head_start: int,
    num_heads: int,
    page_size: int,
    gpu_id: int,
) -> int:
    """Triton fused kernel path: single kernel launch for all layers."""
    import numpy as np

    num_layers = len(k_buffers)
    head_dim = k_buffers[0].shape[-1]
    total_heads = k_buffers[0].shape[1]
    dtype_size = k_buffers[0].element_size()
    num_tokens = len(page_indices_np) * page_size
    elems_per_token = num_heads * head_dim
    per_layer_elems = num_tokens * elems_per_token
    per_layer_bytes = per_layer_elems * dtype_size
    total_bytes = per_layer_bytes * num_layers * 2

    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(gpu_id)
    page_idx_tensor = torch.from_numpy(page_indices_np.astype(np.int64)).to(device)

    # Raw device addresses of every K pool followed by every V pool.
    layer_ptrs = torch.tensor(
        [buf.data_ptr() for buf in k_buffers] + [buf.data_ptr() for buf in v_buffers],
        dtype=torch.int64,
        device=device,
    )
    # Use integer dtype matching element size for bit-preserving copy
    int_dtype_map = {1: torch.int8, 2: torch.int16, 4: torch.int32}
    int_dtype = int_dtype_map.get(dtype_size, torch.int16)
    staging_typed = staging_buffer.buffer[:total_bytes].view(int_dtype)

    if not hasattr(staging_buffer, "_gather_stream"):
        staging_buffer._gather_stream = torch.cuda.Stream(device=device)

    staging_buffer._gather_stream.wait_stream(
        torch.cuda.default_stream(torch.device(device))
    )

    BLOCK_SIZE = 1024
    grid = (2 * num_layers, triton.cdiv(per_layer_elems, BLOCK_SIZE))

    with torch.cuda.stream(staging_buffer._gather_stream):
        _fused_gather_to_staging_kernel[grid](
            layer_ptrs,
            page_idx_tensor,
            staging_typed,
            num_tokens,
            total_heads * head_dim,
            src_head_start * head_dim,
            per_layer_elems,
            elems_per_token,
            page_size,
            BLOCK_SIZE,
        )

    staging_buffer._gather_stream.synchronize()
    return total_bytes
def gather_all_layers_to_staging(
    k_buffers: list,
    v_buffers: list,
    page_indices_np,
    staging_buffer: StagingBuffer,
    src_head_start: int,
    num_heads: int,
    page_size: int,
    gpu_id: int,
) -> int:
    """Gather all layers' K and V head slices into a staging buffer.

    Returns total bytes written.
    Dispatches to Triton fused kernel when available, falls back to torch.gather.
    """
    impl = _gather_all_layers_triton if _USE_TRITON_STAGING else _gather_all_layers_torch
    return impl(
        k_buffers,
        v_buffers,
        page_indices_np,
        staging_buffer,
        src_head_start,
        num_heads,
        page_size,
        gpu_id,
    )


def _scatter_staging_to_kv_torch(
    staging_buffer_view: torch.Tensor,
    k_buffers: list,
    v_buffers: list,
    page_idx_tensor: torch.Tensor,
    page_size: int,
    prefill_attn_tp_size: int,
    decode_attn_tp_size: int,
    dst_tp_rank: int,
    total_kv_heads: int,
) -> None:
    """torch path for scatter: one scatter call per (writer, layer, K/V)."""
    num_layers = len(k_buffers)
    head_dim = k_buffers[0].shape[-1]
    dtype_size = k_buffers[0].element_size()
    num_tokens = page_idx_tensor.shape[0] * page_size

    # Multiple prefill writers only exist when prefill TP exceeds decode TP.
    if prefill_attn_tp_size > decode_attn_tp_size:
        num_writers = prefill_attn_tp_size // max(1, decode_attn_tp_size)
    else:
        num_writers = 1

    for writer_rank in range(num_writers):
        _, num_heads, dst_head_start, _ = compute_head_slice_params(
            prefill_attn_tp_size,
            decode_attn_tp_size,
            writer_rank,
            dst_tp_rank,
            total_kv_heads,
        )
        per_layer_bytes = num_tokens * num_heads * head_dim * dtype_size
        per_rank_bytes = per_layer_bytes * num_layers * 2
        rank_base = writer_rank * per_rank_bytes

        # Walk the writer's region: all K layers, then all V layers.
        offset = rank_base
        for buffers in (k_buffers, v_buffers):
            for layer_id in range(num_layers):
                layer_data = (
                    staging_buffer_view[offset : offset + per_layer_bytes]
                    .view(buffers[layer_id].dtype)
                    .reshape(num_tokens, num_heads, head_dim)
                )
                scatter_kv_head_slices(
                    layer_data,
                    buffers[layer_id],
                    page_idx_tensor,
                    dst_head_start,
                    num_heads,
                    page_size,
                )
                offset += per_layer_bytes


def _scatter_staging_to_kv_triton(
    staging_buffer_view: torch.Tensor,
    k_buffers: list,
    v_buffers: list,
    page_idx_tensor: torch.Tensor,
    page_size: int,
    prefill_attn_tp_size: int,
    decode_attn_tp_size: int,
    dst_tp_rank: int,
    total_kv_heads: int,
) -> None:
    """Triton fused kernel path for scatter."""
    num_layers = len(k_buffers)
    head_dim = k_buffers[0].shape[-1]
    total_heads = k_buffers[0].shape[1]
    dtype_size = k_buffers[0].element_size()
    num_tokens = page_idx_tensor.shape[0] * page_size
    device = page_idx_tensor.device

    if prefill_attn_tp_size > decode_attn_tp_size:
        num_writers = prefill_attn_tp_size // max(1, decode_attn_tp_size)
    else:
        num_writers = 1

    # All writers share the same num_heads; only dst_head_start differs
    _, num_heads, _, _ = compute_head_slice_params(
        prefill_attn_tp_size,
        decode_attn_tp_size,
        0,
        dst_tp_rank,
        total_kv_heads,
    )
    elems_per_token = num_heads * head_dim
    per_layer_elems = num_tokens * elems_per_token

    layer_ptrs = torch.tensor(
        [buf.data_ptr() for buf in k_buffers] + [buf.data_ptr() for buf in v_buffers],
        dtype=torch.int64,
        device=device,
    )

    # Element offset (in head_dim units) where each writer's slice lands.
    writer_head_offsets = torch.tensor(
        [
            compute_head_slice_params(
                prefill_attn_tp_size,
                decode_attn_tp_size,
                wr,
                dst_tp_rank,
                total_kv_heads,
            )[2]
            * head_dim
            for wr in range(num_writers)
        ],
        dtype=torch.int64,
        device=device,
    )

    # Reinterpret staging bytes as a same-width integer for a bit-exact copy.
    int_dtype_map = {1: torch.int8, 2: torch.int16, 4: torch.int32}
    int_dtype = int_dtype_map.get(dtype_size, torch.int16)
    total_staging_bytes = (
        num_tokens * elems_per_token * dtype_size * num_layers * 2 * num_writers
    )
    staging_typed = staging_buffer_view[:total_staging_bytes].view(int_dtype)

    BLOCK_SIZE = 1024
    num_layers_x2 = 2 * num_layers
    grid = (num_writers * num_layers_x2, triton.cdiv(per_layer_elems, BLOCK_SIZE))

    _fused_scatter_from_staging_kernel[grid](
        layer_ptrs,
        page_idx_tensor,
        staging_typed,
        writer_head_offsets,
        num_tokens,
        total_heads * head_dim,
        per_layer_elems,
        elems_per_token,
        page_size,
        num_layers_x2,
        BLOCK_SIZE,
    )


def scatter_staging_to_kv(
    staging_buffer_view: torch.Tensor,
    k_buffers: list,
    v_buffers: list,
    page_idx_tensor: torch.Tensor,
    page_size: int,
    prefill_attn_tp_size: int,
    decode_attn_tp_size: int,
    dst_tp_rank: int,
    total_kv_heads: int,
) -> None:
    """Scatter data from a contiguous staging region into KV cache buffers."""
    impl = (
        _scatter_staging_to_kv_triton
        if _USE_TRITON_STAGING
        else _scatter_staging_to_kv_torch
    )
    return impl(
        staging_buffer_view,
        k_buffers,
        v_buffers,
        page_idx_tensor,
        page_size,
        prefill_attn_tp_size,
        decode_attn_tp_size,
        dst_tp_rank,
        total_kv_heads,
    )


def compute_head_slice_params(
    src_attn_tp_size: int,
    dst_attn_tp_size: int,
    src_tp_rank: int,
    dst_tp_rank: int,
    total_kv_heads: int,
) -> Tuple[int, int, int, int]:
    """Compute head slicing parameters for heterogeneous TP transfer.

    Returns:
        (src_head_start, num_heads_to_send, dst_head_start, num_heads_to_send)
    """
    src_heads_per_rank = max(1, total_kv_heads // src_attn_tp_size)
    dst_heads_per_rank = max(1, total_kv_heads // dst_attn_tp_size)

    local_tp_rank = src_tp_rank % src_attn_tp_size
    dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size

    if src_attn_tp_size > dst_attn_tp_size:
        # Source is more sharded: each source rank sends its full slice to
        # a sub-range of the destination rank's heads.
        src_head_start = 0
        num_heads_to_send = src_heads_per_rank
        src_replication = max(1, src_attn_tp_size // total_kv_heads)
        unique_head_idx = local_tp_rank // src_replication
        dst_head_start = (unique_head_idx * src_heads_per_rank) % dst_heads_per_rank
    else:
        # Destination is more (or equally) sharded: send the sub-range of
        # the source slice that the destination rank owns.
        src_head_start = (
            dst_tp_rank_in_group * dst_heads_per_rank
        ) % src_heads_per_rank
        num_heads_to_send = dst_heads_per_rank
        dst_head_start = 0

    return src_head_start, num_heads_to_send, dst_head_start, num_heads_to_send
def compute_staging_layout(
    src_attn_tp_size: int,
    dst_attn_tp_size: int,
    dst_tp_rank: int,
    total_kv_heads: int,
    num_tokens: int,
    bytes_per_head_token: int,
    num_layers: int,
) -> Tuple[int, List[int], int]:
    """Compute per-writer byte layout for a staging region.

    Returns:
        (num_writers, writer_bytes_list, total_bytes)
        where writer_bytes_list[i] = bytes for writer i covering all layers (K+V).
    """
    if src_attn_tp_size > dst_attn_tp_size:
        num_writers = src_attn_tp_size // max(1, dst_attn_tp_size)
    else:
        num_writers = 1

    writer_bytes = []
    for wr in range(num_writers):
        # Index 1 of the result is the per-writer head count.
        _, nh, _, _ = compute_head_slice_params(
            src_attn_tp_size,
            dst_attn_tp_size,
            wr,
            dst_tp_rank,
            total_kv_heads,
        )
        # K and V for every layer, hence the factor of 2.
        writer_bytes.append(num_tokens * nh * bytes_per_head_token * num_layers * 2)
    return num_writers, writer_bytes, sum(writer_bytes)


def resolve_total_kv_heads(
    kv_args,
    attn_tp_size: int,
) -> int:
    """Resolve the global total KV head count from kv_args metadata."""
    total = getattr(kv_args, "total_kv_head_num", 0)
    if total > 0:
        return total
    # Fall back to the per-rank head count scaled by attention TP size.
    per_rank = getattr(kv_args, "kv_head_num", 0)
    if per_rank > 0:
        return per_rank * attn_tp_size
    raise ValueError(
        "Cannot resolve total_kv_heads: kv_args has neither total_kv_head_num "
        "nor kv_head_num. "
        "Ensure DecodePreallocQueue._init_kv_manager sets kv_args.kv_head_num."
    )


# ======================================================================
# new file: python/sglang/srt/disaggregation/common/staging_handler.py
# ======================================================================
"""
Staging handler for heterogeneous TP KV cache transfer.

Isolates staging scatter lifecycle from decode.py and conn.py.
Generic (backend-agnostic) code is at the top; mooncake-specific
protocol code is at the bottom.
"""

from __future__ import annotations

import dataclasses
import logging
import struct
import threading
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.disaggregation.decode import DecodeRequest


# ======================================================================
# Generic staging state and handler (backend-agnostic)
# ======================================================================


@dataclasses.dataclass
class DecodeStagingContext:
    """Staging-specific context for decode mode."""

    allocator: object = None
    room_bootstrap: dict = dataclasses.field(default_factory=dict)
    room_receivers: dict = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class PrefillStagingContext:
    """Staging-specific context for prefill mode."""

    buffers: list = dataclasses.field(default_factory=list)
    remote_watermarks: dict = dataclasses.field(default_factory=dict)
    watermark_cv: threading.Condition = dataclasses.field(
        default_factory=threading.Condition
    )
    prefetch_requested: set = dataclasses.field(default_factory=set)
    prefetch_sockets: dict = dataclasses.field(default_factory=dict)


class DecodeStagingHandler:
    """Decode-side staging scatter lifecycle manager.

    Scatter submission can be called from the decode_thread (background) as
    soon as all writers/ranks have arrived, while event checking and freeing
    always run on the scheduler main thread.
    """

    def __init__(
        self,
        kv_manager,
        staging_allocator,
        kv_buffer_info: dict,
        decode_tp: int,
        total_kv_heads: int,
        tp_rank: int,
        scheduler,
    ):
        self.kv_manager = kv_manager
        self.staging_allocator = staging_allocator
        self.kv_buffer_info = kv_buffer_info
        self.decode_tp = decode_tp
        self.total_kv_heads = total_kv_heads
        self.tp_rank = tp_rank
        self.scheduler = scheduler
        # room -> DecodeRequest, maintained by (un)register_decode_req.
        self._room_to_decode_req: dict = {}
        # bootstrap-info key -> (receiver, session_id) for watermark broadcast.
        self._wm_subscribers: dict = {}

    def register_wm_subscriber(self, receiver, session_id: str) -> None:
        """Register a prefill's bootstrap connection for watermark broadcasts."""
        if receiver is None or not getattr(receiver, "bootstrap_infos", None):
            return
        key = tuple(str(bi) for bi in receiver.bootstrap_infos)
        if key not in self._wm_subscribers:
            self._wm_subscribers[key] = (receiver, session_id)

    def num_writers_for(self, decode_req) -> int:
        """Compute num_writers for a specific request based on its prefill TP."""
        prefill_tp = decode_req.kv_receiver.prefill_info.attn_tp_size
        if prefill_tp > self.decode_tp:
            return prefill_tp // max(1, self.decode_tp)
        return 1

    @classmethod
    def create(cls, kv_manager, scheduler, tp_rank: int) -> "DecodeStagingHandler":
        """Factory: create handler. Raises if staging infra is missing."""
        staging_allocator = kv_manager._staging_ctx.allocator
        if staging_allocator is None:
            raise RuntimeError(
                "Staging is enabled but kv_manager._staging_ctx.allocator is None. "
                "Check that the transfer backend correctly initializes the staging allocator."
            )
        kv_buffer_info = kv_manager.kv_buffer_tensors
        if kv_buffer_info is None:
            raise RuntimeError(
                "Staging is enabled but kv_manager.kv_buffer_tensors is None. "
                "Check that set_kv_buffer_tensors() was called during kv_manager init."
            )
        decode_tp = kv_manager.attn_tp_size

        from sglang.srt.disaggregation.common.staging_buffer import (
            resolve_total_kv_heads,
        )

        total_kv_heads = resolve_total_kv_heads(kv_manager.kv_args, decode_tp)
        return cls(
            kv_manager=kv_manager,
            staging_allocator=staging_allocator,
            kv_buffer_info=kv_buffer_info,
            decode_tp=decode_tp,
            total_kv_heads=total_kv_heads,
            tp_rank=tp_rank,
            scheduler=scheduler,
        )

    # ------------------------------------------------------------------
    # Registration: called from main thread (DecodeTransferQueue)
    # ------------------------------------------------------------------

    def register_decode_req(self, room: int, decode_req: "DecodeRequest") -> None:
        self._room_to_decode_req[room] = decode_req

    def unregister_decode_req(self, room: int) -> None:
        self._room_to_decode_req.pop(room, None)

    # ------------------------------------------------------------------
    # Scatter submission: called from decode_thread (background)
    # ------------------------------------------------------------------

    def submit_chunk_scatter(
        self, room: int, chunk_idx: int, page_start: int, num_pages: int
    ) -> bool:
        """Submit scatter for an intermediate chunk whose writers all arrived.

        Called from decode_thread. Records a CUDA event on decode_req so
        the main thread can later check completion and free the allocation.
        """
        decode_req = self._room_to_decode_req.get(room)
        if decode_req is None:
            logger.warning(
                "[STAGING] submit_chunk_scatter: room=%s not registered, "
                "chunk_idx=%s. This should not happen if register_decode_req "
                "is called at kv_receiver.init() time.",
                room,
                chunk_idx,
            )
            return False
        chunk_infos = getattr(decode_req.kv_receiver, "chunk_staging_infos", [])
        if chunk_idx >= len(chunk_infos):
            return False
        alloc_id, staging_offset, _, _, _ = chunk_infos[chunk_idx]
        if staging_offset < 0 or alloc_id < 0:
            return False

        ok = self._scatter_region(staging_offset, page_start, num_pages, decode_req)
        if ok:
            event = torch.cuda.Event()
            event.record(self.staging_allocator._scatter_stream)
            if not hasattr(decode_req, "_chunk_events"):
                decode_req._chunk_events = []
            decode_req._chunk_events.append((event, alloc_id))
            # Mark the chunk consumed so it is never scattered twice.
            chunk_infos[chunk_idx] = (-1, -1, 0, -1, 0)
        else:
            logger.warning(
                "submit_chunk_scatter failed room=%s chunk_idx=%s tp_rank=%s",
                room,
                chunk_idx,
                self.tp_rank,
            )
        return ok

    def is_staging_room(self, room: int) -> bool:
        """Check if a room is registered for staging scatter."""
        return room in self._room_to_decode_req

    def submit_last_scatter_async(self, room: int) -> bool:
        """Submit scatter for the last chunk when all ranks report Success.

        Called from decode_thread. Sets ``_scatter_event`` **before**
        ``_staging_last_scatter_submitted`` so the main thread sees the
        event when it checks the flag (CPython GIL guarantees ordering).
        """
        decode_req = self._room_to_decode_req.get(room)
        if decode_req is None:
            logger.warning(
                "[STAGING] submit_last_scatter_async: room=%s not registered. "
                "This should not happen if register_decode_req is called at "
                "kv_receiver.init() time.",
                room,
            )
            return False
        alloc_id = self._submit_last_scatter(decode_req)
        if alloc_id >= 0:
            event = torch.cuda.Event()
            event.record(self.staging_allocator._scatter_stream)
            decode_req._scatter_event = event
            decode_req._scatter_alloc_id = alloc_id
            decode_req._staging_last_scatter_submitted = True
        else:
            decode_req._staging_scatter_done = True
        return True

    # ------------------------------------------------------------------
    # Event check + free: called from main thread (pop_transferred)
    # ------------------------------------------------------------------

    def is_done(self, decode_req: "DecodeRequest") -> bool:
        """Return True if staging scatter is complete for this request."""
        if not getattr(decode_req, "_staging_scatter_done", False):
            return False
        return not getattr(decode_req, "_chunk_events", None)

    def advance_scatter(self, decode_req: "DecodeRequest") -> None:
        """Check CUDA events and free completed staging allocations.

        Scatter kernels have already been submitted by the decode_thread
        (via submit_chunk_scatter / submit_last_scatter_async). This
        method only polls the recorded events and releases staging memory.
        """
        chunk_events = getattr(decode_req, "_chunk_events", None)
        if chunk_events:
            # Iterate backwards so pops don't shift pending indices.
            for i in range(len(chunk_events) - 1, -1, -1):
                event, alloc_id = chunk_events[i]
                if event.query():
                    chunk_events.pop(i)
                    self._free_and_send_watermark(alloc_id, decode_req)

        if not getattr(decode_req, "_staging_last_scatter_submitted", False):
            return

        event = getattr(decode_req, "_scatter_event", None)
        if event is not None and event.query():
            self._free_and_send_watermark(decode_req._scatter_alloc_id, decode_req)
            decode_req._scatter_event = None
            decode_req._scatter_alloc_id = -1
            decode_req._staging_scatter_done = True

    # ------------------------------------------------------------------
    # Internal methods
    # ------------------------------------------------------------------

    def _scatter_region(
        self,
        staging_offset: int,
        page_start: int,
        num_pages: int,
        decode_req: "DecodeRequest",
    ) -> bool:
        """Submit scatter kernels for a staging region to scatter_stream.

        May be called from the decode_thread (background). All GPU work
        runs on scatter_stream so that the decode_thread never blocks on
        the default stream (which carries the main-thread forward pass).
        """
        from sglang.srt.disaggregation.common.staging_buffer import (
            scatter_staging_to_kv,
        )

        k_buffers = self.kv_buffer_info["k_buffers"]
        v_buffers = self.kv_buffer_info["v_buffers"]
        page_size = self.kv_buffer_info["page_size"]
        dst_tp_rank = self.kv_manager.kv_args.engine_rank % self.decode_tp

        device = k_buffers[0].device
        torch.cuda.set_device(device)

        # Lazily create the dedicated scatter stream on first use.
        if not hasattr(self.staging_allocator, "_scatter_stream"):
            self.staging_allocator._scatter_stream = torch.cuda.Stream(device=device)

        scatter_stream = self.staging_allocator._scatter_stream

        staging_view = self.staging_allocator.buffer.buffer[staging_offset:]

        req_pool_idx = decode_req.req.req_pool_idx
        token_start = page_start * page_size
        token_end = token_start + num_pages * page_size
        prefill_tp = decode_req.kv_receiver.prefill_info.attn_tp_size

        with torch.cuda.stream(scatter_stream):
            kv_indices = self.scheduler.req_to_token_pool.req_to_token[
                req_pool_idx, token_start:token_end
            ]
            if page_size > 1:
                # First token of each page, converted to a page index.
                page_idx_tensor = kv_indices[::page_size] // page_size
            else:
                page_idx_tensor = kv_indices

            scatter_staging_to_kv(
                staging_view,
                k_buffers,
                v_buffers,
                page_idx_tensor,
                page_size,
                prefill_tp,
                self.decode_tp,
                dst_tp_rank,
                self.total_kv_heads,
            )

        return True

    def _submit_last_scatter(self, decode_req: "DecodeRequest") -> int:
        """Submit scatter for the last chunk. Returns alloc_id >= 0, or -1."""
        receiver = decode_req.kv_receiver
        chunk_infos = getattr(receiver, "chunk_staging_infos", [])
        if not chunk_infos:
            return -1

        last_info = chunk_infos[-1]
        alloc_id, staging_offset, _, _, last_num_pages = last_info
        if staging_offset < 0 or alloc_id < 0:
            return -1

        # The last chunk covers the final last_num_pages pages of the prompt.
        seq_len = len(decode_req.req.origin_input_ids)
        ps = self.scheduler.token_to_kv_pool_allocator.page_size
        total_pages = (seq_len + ps - 1) // ps
        page_start = total_pages - last_num_pages

        ok = self._scatter_region(
            staging_offset, page_start, last_num_pages, decode_req
        )
        return alloc_id if ok else -1

    def _free_and_send_watermark(
        self, alloc_id: int, decode_req: "DecodeRequest"
    ) -> None:
        """Free a staging allocation and broadcast watermark to all prefills."""
        self.staging_allocator.free(alloc_id)
        wm_round, wm_tail = self.staging_allocator.get_watermark()
        wm_round_b = str(wm_round).encode("ascii")
        wm_tail_b = str(wm_tail).encode("ascii")
        for _key, (receiver, session_id) in list(self._wm_subscribers.items()):
            sid_b = session_id.encode("ascii")
            for bootstrap_info in receiver.bootstrap_infos:
                try:
                    sock, lock = receiver._connect_to_bootstrap_server(bootstrap_info)
                    with lock:
                        sock.send_multipart(
                            [b"WATERMARK", wm_round_b, wm_tail_b, sid_b]
                        )
                except Exception as e:
                    # Best-effort broadcast: a missed watermark only delays
                    # prefill writes; log at debug instead of swallowing.
                    logger.debug(
                        "Watermark broadcast to prefill failed (best-effort): %s", e
                    )


def is_watermark_ready(
    staging_state, session_id: str, alloc_round: int, alloc_end: int
) -> bool:
    """Non-blocking check: is the staging region safe to write?"""
    if alloc_round <= 0:
        return True
    # Safe once the previous round is fully freed, or freed up to alloc_end.
    prev_round = alloc_round - 1
    wm_round, wm_tail = staging_state.remote_watermarks.get(session_id, (0, 0))
    return prev_round < wm_round or (prev_round == wm_round and alloc_end <= wm_tail)
# ======================================================================
# Mooncake-specific staging protocol and utilities
# ======================================================================


@dataclasses.dataclass
class StagingTransferInfo:
    """Per-chunk staging allocation info attached to a TransferInfo."""

    offsets: List[int] = dataclasses.field(default_factory=lambda: [-1])
    rounds: List[int] = dataclasses.field(default_factory=lambda: [0])
    ends: List[int] = dataclasses.field(default_factory=lambda: [-1])

    def set_chunk(self, idx: int, offset: int, rnd: int, end: int):
        # Grow the parallel lists with sentinel entries up to idx.
        while len(self.offsets) <= idx:
            self.offsets.append(-1)
            self.rounds.append(0)
            self.ends.append(-1)
        self.offsets[idx] = offset
        self.rounds[idx] = rnd
        self.ends[idx] = end


@dataclasses.dataclass
class StagingRegisterInfo:
    """Staging buffer registration info attached to a KVArgsRegisterInfo."""

    base_ptr: int = 0
    total_size: int = 0

    @classmethod
    def from_zmq_fields(
        cls, msg: list, msg_start_offset: int
    ) -> Optional["StagingRegisterInfo"]:
        i = msg_start_offset
        # base_ptr is an 8-byte packed uint64; anything else means "absent".
        base_ptr = (
            struct.unpack("Q", msg[i])[0] if len(msg) > i and len(msg[i]) == 8 else 0
        )
        total_size = (
            int(msg[i + 1].decode("ascii"))
            if len(msg) > i + 1 and len(msg[i + 1]) > 0
            else 0
        )
        if base_ptr == 0 and total_size == 0:
            return None
        return cls(base_ptr=base_ptr, total_size=total_size)


class PrefillStagingStrategy:
    """Prefill-side staging transfer: readiness check + gather-RDMA execution.

    Encapsulates the decision logic (chunk index calculation, staging offset
    lookup, watermark readiness) and delegates actual RDMA to the kv_manager.
    """

    def __init__(self, kv_manager, staging_buffer):
        self.kv_manager = kv_manager
        self.staging_buffer = staging_buffer
        page_size = kv_manager.kv_buffer_tensors["page_size"]
        cps = kv_manager.server_args.chunked_prefill_size or 8192
        self.full_chunk_pages = max(1, cps // page_size)

    def check_ready(
        self,
        req,
        kv_chunk_index_start: int,
        num_chunk_pages: int,
    ) -> Tuple[bool, int, int, int, int]:
        """Check if staging offset and watermark are ready for this chunk.

        Returns (ready, chunk_idx, offset, round, end).
        offset == ALLOC_OVERSIZED means permanent failure (fall back to slice).
        offset == -1 means allocation pending (re-enqueue).
        """
        from sglang.srt.disaggregation.common.staging_buffer import StagingAllocator

        chunk_idx = (
            kv_chunk_index_start // self.full_chunk_pages
            if self.full_chunk_pages > 0
            else 0
        )

        stg = req.staging
        if stg is None or chunk_idx >= len(stg.offsets):
            return (False, chunk_idx, -1, 0, -1)

        c_offset = stg.offsets[chunk_idx]
        if c_offset == StagingAllocator.ALLOC_OVERSIZED:
            # Permanent failure: this chunk can never fit the remote ring.
            return (False, chunk_idx, StagingAllocator.ALLOC_OVERSIZED, 0, -1)
        if c_offset < 0:
            # Allocation not yet granted by the decode side.
            return (False, chunk_idx, -1, 0, -1)

        c_round = stg.rounds[chunk_idx]
        c_end = stg.ends[chunk_idx]

        if not self.kv_manager._is_watermark_ready(
            req.mooncake_session_id, c_round, c_end
        ):
            return (False, chunk_idx, c_offset, c_round, c_end)

        return (True, chunk_idx, c_offset, c_round, c_end)

    def transfer(
        self,
        session_id: str,
        prefill_kv_indices,
        dst_staging_ptr: int,
        dst_staging_size: int,
        target_info,
    ) -> int:
        """Execute staged transfer (gather + RDMA).

        Returns 0 on success, -1 to signal fallback to slice path.
        """
        try:
            return self.kv_manager.send_kvcache_staged(
                session_id,
                prefill_kv_indices,
                dst_staging_ptr,
                dst_staging_size,
                target_info.dst_tp_rank,
                target_info.dst_attn_tp_size,
                target_info.dst_kv_item_len,
                staging_buffer=self.staging_buffer,
            )
        except Exception as e:
            raise RuntimeError(
                f"[Staging] KV transfer via staging buffer failed: {e}. "
                f"session={session_id}"
            ) from e


def init_staging_buffers(engine, kv_args, count: int) -> list:
    """Create prefill-side staging buffers and register them with the engine.

    Returns list of StagingBuffer instances.
    """
    from sglang.srt.disaggregation.common.staging_buffer import StagingBuffer
    from sglang.srt.disaggregation.mooncake.utils import (
        init_mooncake_custom_mem_pool,
    )
    from sglang.srt.environ import envs

    size_mb = envs.SGLANG_DISAGG_STAGING_BUFFER_SIZE_MB.get()
    size_bytes = size_mb * 1024 * 1024
    gpu_id = kv_args.gpu_id
    device = f"cuda:{gpu_id}"

    _, custom_mem_pool, _pool_type = init_mooncake_custom_mem_pool(device)
    if custom_mem_pool is None:
        logger.warning(
            "No mooncake custom mem pool available for staging buffer. "
            "NVLink transport will NOT work. Set SGLANG_MOONCAKE_CUSTOM_MEM_POOL."
        )

    buffers = []
    for _ in range(count):
        buf = StagingBuffer(size_bytes, device, gpu_id, custom_mem_pool=custom_mem_pool)
        engine.batch_register([buf.get_ptr()], [buf.get_size()])
        buffers.append(buf)
    return buffers


def init_staging_allocator(engine, kv_args):
    """Create decode-side staging ring-buffer allocator and register with engine.

    Returns a StagingAllocator instance.
    """
    from sglang.srt.disaggregation.common.staging_buffer import StagingAllocator
    from sglang.srt.disaggregation.mooncake.utils import (
        init_mooncake_custom_mem_pool,
    )
    from sglang.srt.environ import envs

    pool_size_mb = envs.SGLANG_DISAGG_STAGING_POOL_SIZE_MB.get()
    pool_size_bytes = pool_size_mb * 1024 * 1024
    gpu_id = kv_args.gpu_id
    device = f"cuda:{gpu_id}"

    _, custom_mem_pool, _ = init_mooncake_custom_mem_pool(device)
    allocator = StagingAllocator(pool_size_bytes, device, gpu_id, custom_mem_pool)
    engine.batch_register([allocator.get_base_ptr()], [allocator.get_total_size()])
    return allocator
+ """ + from sglang.srt.disaggregation.common.staging_buffer import StagingAllocator + from sglang.srt.disaggregation.mooncake.utils import ( + init_mooncake_custom_mem_pool, + ) + from sglang.srt.environ import envs + + pool_size_mb = envs.SGLANG_DISAGG_STAGING_POOL_SIZE_MB.get() + pool_size_bytes = pool_size_mb * 1024 * 1024 + gpu_id = kv_args.gpu_id + device = f"cuda:{gpu_id}" + + _, custom_mem_pool, _ = init_mooncake_custom_mem_pool(device) + allocator = StagingAllocator(pool_size_bytes, device, gpu_id, custom_mem_pool) + engine.batch_register([allocator.get_base_ptr()], [allocator.get_total_size()]) + return allocator + + +def handle_staging_req( + msg, + staging_allocator, + kv_args, + attn_tp_size: int, + prefill_attn_tp_size: int, + kv_buffer_tensors, + room_receivers: dict, + room_bootstrap: dict, +): + """Allocate staging for a chunk on-demand and send STAGING_RSP to prefill. + + Deduplicates: multiple prefill TP ranks requesting the same (room, chunk_idx) + only allocate once. Sends ALLOC_OVERSIZED on permanent failure. 
+ """ + from sglang.srt.disaggregation.common.staging_buffer import StagingAllocator + + room = int(msg[1].decode("ascii")) + chunk_idx = int(msg[2].decode("ascii")) + chunk_num_pages = int(msg[3].decode("ascii")) + session_id = msg[4].decode("ascii") + + if staging_allocator is None: + logger.warning( + "STAGING_REQ ignored: allocator is None room=%s chunk=%s", + room, + chunk_idx, + ) + return + + receiver = room_receivers.get(room) + if receiver is None: + logger.warning( + "STAGING_REQ dropped: no receiver for room=%s chunk=%s session=%s", + room, + chunk_idx, + session_id, + ) + return + infos = getattr(receiver, "chunk_staging_infos", []) + + if chunk_idx < len(infos) and infos[chunk_idx][0] >= 0: + _, offset, rnd, end, _ = infos[chunk_idx] + elif ( + chunk_idx < len(infos) + and infos[chunk_idx][1] == StagingAllocator.ALLOC_OVERSIZED + ): + offset, rnd, end = StagingAllocator.ALLOC_OVERSIZED, 0, -1 + else: + from sglang.srt.disaggregation.common.staging_buffer import ( + compute_staging_layout, + resolve_total_kv_heads, + ) + + page_size = kv_args.page_size + kv_item_lens = kv_args.kv_item_lens + num_kv_layers = len(kv_item_lens) // 2 + decode_bytes_per_token = kv_item_lens[0] // page_size + total_kv_heads = resolve_total_kv_heads(kv_args, attn_tp_size) + dst_heads_per_rank = max(1, total_kv_heads // max(1, attn_tp_size)) + bytes_per_head_per_token = decode_bytes_per_token // dst_heads_per_rank + dst_tp_rank = kv_args.engine_rank % max(1, attn_tp_size) + + chunk_tokens = chunk_num_pages * page_size + _, _, required = compute_staging_layout( + prefill_attn_tp_size, + attn_tp_size, + dst_tp_rank, + total_kv_heads, + chunk_tokens, + bytes_per_head_per_token, + num_kv_layers, + ) + result = staging_allocator.assign(required) + if result is None: + logger.error( + "[STAGING_REQ] alloc failed room=%s chunk=%d (need %d bytes, " + "buffer total=%d bytes). 
Increase SGLANG_DISAGG_STAGING_POOL_SIZE_MB.", + room, + chunk_idx, + required, + staging_allocator.total_size, + ) + offset, rnd, end = StagingAllocator.ALLOC_OVERSIZED, 0, -1 + while len(infos) <= chunk_idx: + infos.append((-1, -1, 0, -1, 0)) + infos[chunk_idx] = ( + -1, + StagingAllocator.ALLOC_OVERSIZED, + 0, + -1, + chunk_num_pages, + ) + else: + alloc_id, offset, rnd = result + end = offset + required + while len(infos) <= chunk_idx: + infos.append((-1, -1, 0, -1, 0)) + infos[chunk_idx] = (alloc_id, offset, rnd, end, chunk_num_pages) + + bootstrap_infos = room_bootstrap.get(room) + if bootstrap_infos: + for bi in bootstrap_infos: + try: + sock, lock = receiver._connect_to_bootstrap_server(bi) + with lock: + sock.send_multipart( + [ + b"STAGING_RSP", + str(room).encode("ascii"), + str(chunk_idx).encode("ascii"), + str(offset).encode("ascii"), + str(rnd).encode("ascii"), + str(end).encode("ascii"), + session_id.encode("ascii"), + ] + ) + except Exception: + pass + + +def prefetch_staging_reqs( + room: int, + transfer_infos: dict, + kv_buffer_tensors: dict, + chunked_prefill_size: int, + staging_requested: set, + prefetch_sockets: dict, +) -> None: + """Send STAGING_REQ for all chunks before the prefill forward starts. + + Called from the scheduler right after batch formation, so that decode + allocates staging during the GPU forward pass. 
+ """ + import zmq + + from sglang.srt.utils.network import NetworkAddress + + page_size = kv_buffer_tensors["page_size"] + cps = chunked_prefill_size or 8192 + full_chunk_pages = max(1, cps // page_size) + + for session_id, tinfo in transfer_infos[room].items(): + if tinfo.is_dummy: + continue + total_pages = len(tinfo.dst_kv_indices) + if total_pages == 0: + continue + num_chunks = (total_pages + full_chunk_pages - 1) // full_chunk_pages + + for chunk_idx in range(num_chunks): + stg_key = (room, chunk_idx, session_id) + if stg_key in staging_requested: + continue + staging_requested.add(stg_key) + + remaining = total_pages - chunk_idx * full_chunk_pages + chunk_pages = min(full_chunk_pages, remaining) + try: + na = NetworkAddress(tinfo.endpoint, tinfo.dst_port) + ep = na.to_tcp() + if ep not in prefetch_sockets: + sock = zmq.Context().socket(zmq.PUSH) + if na.is_ipv6: + sock.setsockopt(zmq.IPV6, 1) + sock.connect(ep) + prefetch_sockets[ep] = sock + prefetch_sockets[ep].send_multipart( + [ + b"STAGING_REQ", + str(room).encode("ascii"), + str(chunk_idx).encode("ascii"), + str(chunk_pages).encode("ascii"), + session_id.encode("ascii"), + ] + ) + except Exception: + staging_requested.discard(stg_key) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index ce7ae1557447..f54c882cc2ec 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -27,6 +27,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +import numpy as np import torch from torch.distributed import ProcessGroup @@ -45,6 +46,7 @@ is_mla_backend, kv_to_page_indices, poll_and_all_reduce, + poll_and_all_reduce_with_staging, prepare_abort, ) from sglang.srt.environ import envs @@ -66,6 +68,7 @@ set_schedule_time_batch, set_time_batch, ) +from sglang.srt.utils.network import NetworkAddress from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter 
logger = logging.getLogger(__name__) @@ -87,7 +90,7 @@ def _is_fake_transfer(req: Req, server_args: ServerArgs) -> bool: def _bootstrap_addr(req: Req) -> str: # FIXME: make a property of a req - return f"{req.bootstrap_host}:{req.bootstrap_port}" + return NetworkAddress(req.bootstrap_host, req.bootstrap_port).to_host_port_str() class DecodeReqToTokenPool: @@ -174,11 +177,13 @@ def __init__( device: str, enable_memory_saver: bool, cache_params: "Mamba2CacheParams", + mamba_layer_ids: List[int], speculative_num_draft_tokens: int, enable_mamba_extra_buffer: bool, pre_alloc_size: int, enable_overlap_schedule: bool, mamba_size: int = None, + start_layer: int = None, ): DecodeReqToTokenPool.__init__( self, @@ -192,16 +197,25 @@ def __init__( self.mamba_ping_pong_track_buffer_size = 2 if enable_overlap_schedule else 1 self.enable_mamba_extra_buffer = enable_mamba_extra_buffer self.enable_memory_saver = enable_memory_saver - effective_mamba_size = ( - mamba_size if mamba_size is not None else size - ) + pre_alloc_size - # TODO: Support PP - self.start_layer = 0 + if mamba_size is not None: + effective_mamba_size = min(mamba_size, size + pre_alloc_size) + if mamba_size > size + pre_alloc_size: + logger.warning( + "mamba_size (%d) exceeds size + pre_alloc_size (%d), " + "capping effective_mamba_size to %d", + mamba_size, + size + pre_alloc_size, + effective_mamba_size, + ) + else: + effective_mamba_size = size + pre_alloc_size + self.start_layer = start_layer if start_layer is not None else 0 self.layer_transfer_counter = None self._init_mamba_pool( size=effective_mamba_size, mamba_spec_state_size=size + pre_alloc_size, cache_params=cache_params, + mamba_layer_ids=mamba_layer_ids, device=device, enable_mamba_extra_buffer=self.enable_mamba_extra_buffer, speculative_num_draft_tokens=speculative_num_draft_tokens, @@ -273,12 +287,15 @@ def __init__( # Queue for requests pending pre-allocation self.queue: List[DecodeRequest] = [] self.retracted_queue: List[Req] = [] - 
self.pending_reqs: List[Req] = [] + self.pending_reqs: List[DecodeRequest] = [] self._ensure_retry_count: Dict[str, int] = {} - self._max_ensure_retries: int = 20 # scheduling cycles + self._max_ensure_retries: int = 15 # scheduling cycles self._ensure_last_attempt_time: Dict[str, float] = {} self._ensure_retry_interval: float = 1.0 # seconds + self.enable_staging = envs.SGLANG_DISAGG_STAGING_BUFFER.get() self.kv_manager = self._init_kv_manager() + if self.enable_staging: + self.transfer_queue._init_staging_handler(self.kv_manager) if self.scheduler.tp_worker.is_hybrid_swa: # FIXME: current SWA allocation allocate full kv cache size in prefill @@ -296,9 +313,16 @@ def _init_kv_manager(self) -> CommonKVManager: kv_args.pp_rank = self.pp_rank kv_args.system_dp_rank = self.scheduler.dp_rank - kv_data_ptrs, kv_data_lens, kv_item_lens = ( - self.token_to_kv_pool.get_contiguous_buf_infos() - ) + if self.scheduler.enable_hisparse: + # Direct-to-host: register host pool pointers so P writes to D's host memory + host_pool = self.scheduler.hisparse_coordinator.mem_pool_host + kv_data_ptrs, kv_data_lens, kv_item_lens = ( + host_pool.get_contiguous_buf_infos() + ) + else: + kv_data_ptrs, kv_data_lens, kv_item_lens = ( + self.token_to_kv_pool.get_contiguous_buf_infos() + ) if self.draft_token_to_kv_pool is not None: # We should also transfer draft model kv cache. The indices are # always shared with a target model. 
@@ -312,7 +336,10 @@ def _init_kv_manager(self) -> CommonKVManager: kv_args.kv_data_ptrs = kv_data_ptrs kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens - kv_args.page_size = self.token_to_kv_pool.page_size + # HiSparse Host pool has page_size=1; use it when hisparse is enabled + kv_args.page_size = ( + 1 if self.scheduler.enable_hisparse else self.token_to_kv_pool.page_size + ) kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( self.metadata_buffers.get_buf_infos() @@ -354,6 +381,21 @@ def _init_kv_manager(self) -> CommonKVManager: self.scheduler.server_args, self.is_mla_backend, ) + # Staging buffer setup (only when heterogeneous TP staging is enabled) + if self.enable_staging and not self.is_mla_backend: + kv_pool_for_heads = self.token_to_kv_pool + if hasattr(kv_pool_for_heads, "full_kv_pool"): + kv_pool_for_heads = kv_pool_for_heads.full_kv_pool + per_rank_kv_heads = getattr(kv_pool_for_heads, "head_num", 0) + if per_rank_kv_heads > 0: + kv_args.kv_head_num = per_rank_kv_heads + kv_args.total_kv_head_num = per_rank_kv_heads * attn_tp_size + if hasattr(kv_manager, "set_kv_buffer_tensors"): + kv_pool = kv_pool_for_heads + if hasattr(kv_pool, "k_buffer") and hasattr(kv_pool, "v_buffer"): + kv_manager.set_kv_buffer_tensors( + kv_pool.k_buffer, kv_pool.v_buffer, kv_pool.page_size + ) return kv_manager def add(self, req: Req, is_retracted: bool = False) -> None: @@ -365,17 +407,20 @@ def add(self, req: Req, is_retracted: bool = False) -> None: req.retraction_mb_id = None self.retracted_queue.append(req) else: + decode_req = self._create_receiver_and_enqueue(req) + # NOTE: fake transfer does not need to resolve prefill dp rank in the pending queue if _is_fake_transfer(req, self.scheduler.server_args): - self._create_receiver_and_enqueue(req, 0) + decode_req.kv_receiver.init(0) return # Fast path: cache-only lookup, no network calls prefill_dp_rank = self._resolve_prefill_dp_rank(req) if prefill_dp_rank is not None: - 
self._create_receiver_and_enqueue(req, prefill_dp_rank) - else: - self.pending_reqs.append(req) + decode_req.kv_receiver.init(prefill_dp_rank) + return + + self.pending_reqs.append(decode_req) def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]: if req.disagg_prefill_dp_rank is not None: @@ -393,7 +438,7 @@ def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]: return None - def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None: + def _create_receiver_and_enqueue(self, req: Req) -> DecodeRequest: backend = ( TransferBackend.FAKE if _is_fake_transfer(req, self.scheduler.server_args) @@ -405,12 +450,11 @@ def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None: mgr=self.kv_manager, bootstrap_addr=_bootstrap_addr(req), bootstrap_room=req.bootstrap_room, - prefill_dp_rank=prefill_dp_rank, ) - self.queue.append( - DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) - ) + decode_req = DecodeRequest(req=req, kv_receiver=kv_receiver) + self.queue.append(decode_req) + return decode_req def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool: if len(req.origin_input_ids) > self.max_total_num_tokens: @@ -508,12 +552,12 @@ def _update_handshake_waiters( raise ValueError(f"Unexpected poll case: {poll}") def _ensure_prefill_info( - self, addr_to_reqs: Dict[str, List[Req]] - ) -> Tuple[Dict[str, List[Req]], List[Req]]: + self, addr_to_reqs: Dict[str, List[DecodeRequest]] + ) -> Tuple[Dict[str, List[DecodeRequest]], List[DecodeRequest]]: """Non-blocking ensure parallel info for each addr. 
Returns (ready_addrs, remaining_reqs).""" - ready: Dict[str, List[Req]] = {} - remaining: List[Req] = [] + ready: Dict[str, List[DecodeRequest]] = {} + remaining: List[DecodeRequest] = [] now = time.monotonic() for bootstrap_addr, reqs in addr_to_reqs.items(): @@ -540,13 +584,8 @@ def _ensure_prefill_info( if count >= self._max_ensure_retries: error_msg = f"Could not fetch prefill parallel info from {bootstrap_addr} after {count} attempts" logger.error(error_msg) - for req in reqs: - prepare_abort( - req, error_msg, status_code=HTTPStatus.INTERNAL_SERVER_ERROR - ) - if self.scheduler.enable_metrics: - self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() - self.scheduler.stream_output([req], req.return_logprob) + for decode_req in reqs: + decode_req.kv_receiver.abort() del self._ensure_retry_count[bootstrap_addr] del self._ensure_last_attempt_time[bootstrap_addr] else: @@ -555,46 +594,48 @@ def _ensure_prefill_info( return ready, remaining def _resolve_pending_reqs(self) -> None: - """Batch-resolve prefill_dp_ranks for pending requests and create receivers.""" + """Batch-resolve prefill_dp_ranks for pending requests and initialize receivers.""" if not self.pending_reqs: return # Group pending requests by bootstrap_addr - addr_to_reqs: Dict[str, List[Req]] = {} - for req in self.pending_reqs: - addr = _bootstrap_addr(req) - addr_to_reqs.setdefault(addr, []).append(req) + addr_to_reqs: Dict[str, List[DecodeRequest]] = {} + for decode_req in self.pending_reqs: + addr = _bootstrap_addr(decode_req.req) + addr_to_reqs.setdefault(addr, []).append(decode_req) # Pass 1: ensure parallel info for each addr ready_addrs, remaining = self._ensure_prefill_info(addr_to_reqs) - # Pass 2: resolve dp rank for addrs whose info is available - resolved = [] - for bootstrap_addr, reqs in ready_addrs.items(): - need_query: List[Req] = [] - for req in reqs: - prefill_dp_rank = self._resolve_prefill_dp_rank(req) + resolved: List[Tuple[DecodeRequest, int]] = [] + for 
bootstrap_addr, decode_reqs in ready_addrs.items(): + need_query: List[DecodeRequest] = [] + for decode_req in decode_reqs: + prefill_dp_rank = self._resolve_prefill_dp_rank(decode_req.req) if prefill_dp_rank is not None: - resolved.append((req, prefill_dp_rank)) + resolved.append((decode_req, prefill_dp_rank)) else: - need_query.append(req) + need_query.append(decode_req) + # Pass 2: resolve dp rank for addrs whose info is available if need_query: - rooms = [req.bootstrap_room for req in need_query] + rooms = [decode_req.req.bootstrap_room for decode_req in need_query] room_to_rank = CommonKVReceiver.query_prefill_dp_ranks( bootstrap_addr, rooms ) - for req in need_query: - prefill_dp_rank = room_to_rank.get(str(req.bootstrap_room)) + for decode_req in need_query: + prefill_dp_rank = room_to_rank.get( + str(decode_req.req.bootstrap_room) + ) if prefill_dp_rank is not None: - resolved.append((req, int(prefill_dp_rank))) + resolved.append((decode_req, int(prefill_dp_rank))) else: - remaining.append(req) + remaining.append(decode_req) self.pending_reqs = remaining - for req, prefill_dp_rank in resolved: - self._create_receiver_and_enqueue(req, prefill_dp_rank) + for decode_req, prefill_dp_rank in resolved: + decode_req.kv_receiver.init(prefill_dp_rank) def pop_preallocated( self, rids_to_check: Optional[List[str]] = None @@ -668,16 +709,21 @@ def pop_preallocated( break allocatable_tokens -= required_tokens_for_request - self._pre_alloc(decode_req.req) + dst_kv_indices = self._pre_alloc(decode_req.req) - kv_indices = ( - self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][ - : len(decode_req.req.origin_input_ids) - ] - .cpu() - .numpy() - ) - page_size = self.token_to_kv_pool_allocator.page_size + origin_input_len = len(decode_req.req.origin_input_ids) + if self.scheduler.enable_hisparse: + # Must cast to int32 for ZMQ serialization — from_zmq reads np.int32. 
+ kv_indices = ( + dst_kv_indices[:origin_input_len].cpu().numpy().astype(np.int32) + ) + page_size = 1 # host pool page_size + else: + kv_indices_full = self.req_to_token_pool.req_to_token[ + decode_req.req.req_pool_idx + ][:origin_input_len] + kv_indices = kv_indices_full.cpu().numpy() + page_size = self.token_to_kv_pool_allocator.page_size # Prepare extra pool indices for hybrid models if isinstance(self.token_to_kv_pool, HybridLinearKVPool): @@ -714,7 +760,9 @@ def pop_preallocated( decode_req.req.req_pool_idx, :seq_len ] state_indices = kv_indices_full.cpu().numpy() - state_indices = kv_to_page_indices(state_indices, page_size) + # Indexer lives on device pool; always use device page_size + device_page_size = self.token_to_kv_pool.page_size + state_indices = kv_to_page_indices(state_indices, device_page_size) else: state_indices = None @@ -723,9 +771,16 @@ def pop_preallocated( ) assert decode_req.metadata_buffer_index is not None page_indices = kv_to_page_indices(kv_indices, page_size) - decode_req.kv_receiver.init( + decode_req.kv_receiver.send_metadata( page_indices, decode_req.metadata_buffer_index, state_indices ) + if ( + self.transfer_queue.enable_staging + and decode_req.kv_receiver.require_staging + ): + self.transfer_queue.staging_handler.register_decode_req( + decode_req.req.bootstrap_room, decode_req + ) preallocated_reqs.append(decode_req) indices_to_remove.add(i) decode_req.req.time_stats.set_decode_transfer_queue_entry_time() @@ -804,7 +859,30 @@ def _pre_alloc(self, req: Req) -> torch.Tensor: fill_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) req.kv_allocated_len = fill_len req.kv_committed_len = fill_len - if self.token_to_kv_pool_allocator.page_size == 1: + + if self.scheduler.enable_hisparse: + # Direct-to-host path: only allocate logical indices (no hisparse + # device indices) and allocate host indices for RDMA destination. 
+ coordinator = self.scheduler.hisparse_coordinator + device = self.token_to_kv_pool_allocator.device + kv_loc = self.token_to_kv_pool_allocator.alloc_logical_only( + prefix_lens=torch.tensor([0], dtype=torch.int64, device=device), + prefix_lens_cpu=torch.tensor([0], dtype=torch.int64), + seq_lens=torch.tensor([fill_len], dtype=torch.int64, device=device), + seq_lens_cpu=torch.tensor([fill_len], dtype=torch.int64), + last_loc=torch.tensor([-1], dtype=torch.int64, device=device), + extend_num_tokens=fill_len, + ) + # Allocate host indices for the RDMA transfer target + host_indices = coordinator.mem_pool_host.alloc(fill_len) + if host_indices is None: + raise RuntimeError( + f"HiSparse host mem pool alloc failed for {fill_len} tokens " + f"in _pre_alloc (req {req.rid})" + ) + host_indices = host_indices.to(device=coordinator.device) + coordinator.req_to_host_pool[req.req_pool_idx, :fill_len] = host_indices + elif self.token_to_kv_pool_allocator.page_size == 1: kv_loc = self.token_to_kv_pool_allocator.alloc(fill_len) else: device = self.token_to_kv_pool_allocator.device @@ -827,6 +905,9 @@ def _pre_alloc(self, req: Req) -> torch.Tensor: req.fill_ids = req.origin_input_ids + req.output_ids req.set_extend_input_len(len(req.fill_ids)) + # Return the transfer destination indices: + if self.scheduler.enable_hisparse: + return host_indices return kv_loc @@ -852,12 +933,18 @@ def __init__( self.scheduler = scheduler self.tree_cache = tree_cache self.spec_algorithm = scheduler.spec_algorithm + self.enable_staging = envs.SGLANG_DISAGG_STAGING_BUFFER.get() + self.staging_handler = None def add(self, decode_req: DecodeRequest) -> None: self.queue.append(decode_req) def extend(self, decode_reqs: List[DecodeRequest]) -> None: self.queue.extend(decode_reqs) + if self.enable_staging: + for dr in decode_reqs: + if dr.kv_receiver.require_staging: + self.staging_handler.register_decode_req(dr.req.bootstrap_room, dr) def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> bool: 
""" @@ -916,6 +1003,9 @@ def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> bool: # Case 3: Success - commit the transfer decode_req.req.output_ids.append(output_id[0].item()) decode_req.req.cached_tokens = cached_tokens[0].item() + decode_req.req.cached_tokens_device = cached_tokens[1].item() + decode_req.req.cached_tokens_host = cached_tokens[2].item() + decode_req.req.cached_tokens_storage = cached_tokens[3].item() if not self.spec_algorithm.is_none(): decode_req.req.output_topk_p = output_topk_p decode_req.req.output_topk_index = output_topk_index @@ -940,18 +1030,39 @@ def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> bool: decode_req.req.time_stats.set_wait_queue_entry_time() return True + def _poll_with_staging(self) -> list: + return poll_and_all_reduce_with_staging( + self.queue, self.staging_handler, self.gloo_group + ) + + def _init_staging_handler(self, kv_manager): + """Create staging handler from kv_manager. Must be called exactly once.""" + from sglang.srt.disaggregation.common.staging_handler import ( + DecodeStagingHandler, + ) + + self.staging_handler = DecodeStagingHandler.create( + kv_manager, self.scheduler, self.tp_rank + ) + kv_manager._staging_handler = self.staging_handler + def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req]: if not self.queue: return [] - polls = poll_and_all_reduce( - [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group - ) + + if self.enable_staging: + polls = self._poll_with_staging() + else: + polls = poll_and_all_reduce( + [dr.kv_receiver for dr in self.queue], self.gloo_group + ) transferred_reqs = [] indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): if rids_to_check is not None and decode_req.req.rid not in rids_to_check: continue + if poll == KVPoll.Failed: error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}" try: @@ -967,6 
+1078,8 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req self.scheduler.stream_output( [decode_req.req], decode_req.req.return_logprob ) + if self.scheduler.enable_hisparse: + self.scheduler.hisparse_coordinator.request_finished(decode_req.req) # release pre-allocated kv cache, but don't insert into the tree since it's failed release_kv_cache(decode_req.req, self.tree_cache, is_insert=False) indices_to_remove.add(i) @@ -982,6 +1095,10 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req self.scheduler.stream_output( [decode_req.req], decode_req.req.return_logprob ) + if self.scheduler.enable_hisparse: + self.scheduler.hisparse_coordinator.request_finished( + decode_req.req + ) release_kv_cache( decode_req.req, self.tree_cache, is_insert=False ) @@ -999,6 +1116,12 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req raise ValueError(f"Unexpected poll case: {poll}") for i in indices_to_remove: + if self.enable_staging and self.staging_handler.is_staging_room( + self.queue[i].req.bootstrap_room + ): + self.staging_handler.unregister_decode_req( + self.queue[i].req.bootstrap_room + ) idx = self.queue[i].metadata_buffer_index assert idx != -1 self.req_to_metadata_buffer_idx_allocator.free(idx) @@ -1099,6 +1222,10 @@ def get_next_disagg_decode_batch_to_run( if not new_prebuilt_batch.is_empty(): if self.running_batch.is_empty(): self.running_batch = new_prebuilt_batch + if self.enable_hisparse: + self.running_batch.hisparse_coordinator = ( + self.hisparse_coordinator + ) else: self.running_batch.merge_batch(new_prebuilt_batch) @@ -1191,4 +1318,10 @@ def process_decode_queue(self: Scheduler): transferred_reqs = ( self.disagg_decode_transfer_queue.pop_transferred() ) # the requests which kv has arrived - self.waiting_queue.extend(transferred_reqs) + if self.enable_hisparse: + for req in transferred_reqs: + # Direct-to-host: KV data already in host pool, skip staging + 
self.hisparse_coordinator.admit_request_direct(req) + self.waiting_queue.extend(transferred_reqs) + else: + self.waiting_queue.extend(transferred_reqs) diff --git a/python/sglang/srt/disaggregation/encode_grpc_server.py b/python/sglang/srt/disaggregation/encode_grpc_server.py index b7d2e5f03919..033520093b58 100644 --- a/python/sglang/srt/disaggregation/encode_grpc_server.py +++ b/python/sglang/srt/disaggregation/encode_grpc_server.py @@ -258,7 +258,7 @@ async def serve_grpc_encoder(server_args: ServerArgs): ) reflection.enable_server_reflection(SERVICE_NAMES, server) - listen_addr = f"{server_args.host}:{server_args.port}" + listen_addr = NetworkAddress(server_args.host, server_args.port).to_host_port_str() server.add_insecure_port(listen_addr) await server.start() diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 6b70e15ef53d..391bab6eff5d 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -29,7 +29,11 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ImageData from sglang.srt.utils.hf_transformers_utils import get_processor -from sglang.srt.utils.network import get_local_ip_auto, get_zmq_socket_on_host +from sglang.srt.utils.network import ( + NetworkAddress, + get_local_ip_auto, + get_zmq_socket_on_host, +) logger = logging.getLogger(__name__) @@ -447,7 +451,9 @@ async def send_embedding_port(req_id, receive_count, host_name, embedding_port): payload = { "req_id": part_req_id, # use part_req_id to match encode request "receive_count": receive_count, - "receive_url": f"{host_name}:{embedding_port}", + "receive_url": NetworkAddress( + host_name, embedding_port + ).to_host_port_str(), "modality": modality.name, } logger.info( @@ -666,6 +672,11 @@ def __init__( server_args, _processor, transport_mode, + model_config=( + getattr(self.scheduler, "model_config", None) + if self.scheduler is 
not None + else None + ), skip_mm_pool=not enable_adaptive_dispatch_to_encoder, ) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 4da24e522c98..02d1054b48e4 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -435,8 +435,13 @@ def _load_single_item( return data try: if modality == Modality.IMAGE: - img, _ = load_image(data) - if discard_alpha_channel and img.mode != "RGB": + img, _ = load_image(data, False) + if ( + discard_alpha_channel + and not isinstance(img, torch.Tensor) + and img.mode != "RGB" + ): + # Needed only when `img` is a PIL image img = img.convert("RGB") return img elif modality == Modality.VIDEO: diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 4a3841e68208..03b79af189f7 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -82,28 +82,33 @@ def __init__( mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): - self.has_init = False + self.bootstrap_done = False + self.has_sent_metadata = False def poll(self) -> KVPoll: - if self.has_init is False: - # Assume handshake completed instantly + if not self.bootstrap_done: + return KVPoll.Bootstrapping + if not self.has_sent_metadata: return KVPoll.WaitingForInput - else: - # Assume transfer completed instantly - logger.debug("FakeKVReceiver poll success") - return KVPoll.Success + logger.debug("FakeKVReceiver poll success") + return KVPoll.Success def init( + self, + prefill_dp_rank: int, + ): + self.bootstrap_done = True + + def send_metadata( self, kv_indices: list[int], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): - self.has_init = True + self.has_sent_metadata = True logger.debug( - f"FakeKVReceiver init with kv_indices: {kv_indices}, 
aux_index: {aux_index}, state_indices: {state_indices}" + f"FakeKVReceiver send_metadata with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}" ) def failure_exception(self): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 15e815e696c2..64d97f5c6966 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -61,6 +61,14 @@ class TransferKVChunk: state_indices: Optional[List[int]] +from sglang.srt.disaggregation.common.staging_handler import ( + DecodeStagingContext, + PrefillStagingContext, + StagingRegisterInfo, + StagingTransferInfo, +) + + # decode @dataclasses.dataclass class TransferInfo: @@ -73,6 +81,8 @@ class TransferInfo: dst_state_indices: List[int] required_dst_info_num: int is_dummy: bool + # Note: always put the optional staging field at the final (it will be set through 'STAGING_RSP' pkg when needed) + staging: Optional[StagingTransferInfo] = None @classmethod def from_zmq(cls, msg: List[bytes]): @@ -118,6 +128,10 @@ class KVArgsRegisterInfo: # for mamba state different tp slice transfer dst_state_item_lens: list[int] dst_state_dim_per_tensor: list[int] + # HiSparse: decode host pool stores KV at token granularity + enable_hisparse: bool = False + # Note: always put the staging field at the final (since the staging field is optional and contains multiple inputs) + staging: Optional[StagingRegisterInfo] = None @classmethod def from_zmq(cls, msg: List[bytes]): @@ -142,6 +156,11 @@ def from_zmq(cls, msg: List[bytes]): if len(msg) > 11 and len(msg[11]) > 0 else [] ), + enable_hisparse=( + msg[12].decode("ascii") == "1" if len(msg) > 12 else False + ), + # Note: always put the staging field at the final + staging=StagingRegisterInfo.from_zmq_fields(msg, 13), ) @@ -178,6 +197,7 @@ def __init__( super().__init__(args, disaggregation_mode, server_args, is_mla_backend) self.init_engine() 
self.register_buffer_to_engine() + self.enable_staging = envs.SGLANG_DISAGG_STAGING_BUFFER.get() if self.disaggregation_mode == DisaggregationMode.PREFILL: self.start_prefill_thread() self.session_failures = defaultdict(int) @@ -204,14 +224,34 @@ def __init__( ) for _ in range(transfer_queue_size) ] - for queue, executor in zip(self.transfer_queues, self.executors): - threading.Thread( - target=self.transfer_worker, args=(queue, executor), daemon=True - ).start() self.enable_custom_mem_pool, self.custom_mem_pool_type = ( check_mooncake_custom_mem_pool_enabled() ) + self._staging_ctx = PrefillStagingContext() if self.enable_staging else None + if self.enable_staging: + self._init_staging_buffers(len(self.transfer_queues)) + for i, (queue, executor) in enumerate( + zip(self.transfer_queues, self.executors) + ): + threading.Thread( + target=self.transfer_worker, + args=( + queue, + executor, + ( + self._staging_ctx.buffers[i] + if self.enable_staging and self._staging_ctx.buffers + else None + ), + ), + daemon=True, + ).start() elif self.disaggregation_mode == DisaggregationMode.DECODE: + self._staging_ctx = DecodeStagingContext() if self.enable_staging else None + if self.enable_staging: + self._init_staging_allocator() + self._staging_handler = None + self._chunk_writer_counts: dict = defaultdict(lambda: defaultdict(list)) self.start_decode_thread() def init_engine(self): @@ -236,6 +276,297 @@ def register_buffer_to_engine(self): self.kv_args.state_data_ptrs, self.kv_args.state_data_lens ) + # ------------------------------------------------------------------ + # Staging buffer methods (all delegate to staging_handler.py) + # ------------------------------------------------------------------ + + def register_staging_room_bootstrap(self, room, bootstrap_infos, receiver): + self._staging_ctx.room_bootstrap[room] = bootstrap_infos + self._staging_ctx.room_receivers[room] = receiver + + def set_kv_buffer_tensors(self, k_buffers: list, v_buffers: list, page_size: int): + 
self.kv_buffer_tensors = { + "k_buffers": k_buffers, + "v_buffers": v_buffers, + "page_size": page_size, + } + + def _init_staging_buffers(self, count: int): + from sglang.srt.disaggregation.common.staging_handler import ( + init_staging_buffers, + ) + + self._staging_ctx.buffers = init_staging_buffers( + self.engine, self.kv_args, count + ) + self.kv_buffer_tensors = None + + def _init_staging_allocator(self): + from sglang.srt.disaggregation.common.staging_handler import ( + init_staging_allocator, + ) + + self._staging_ctx.allocator = init_staging_allocator(self.engine, self.kv_args) + self.kv_buffer_tensors = None + + def _handle_staging_req(self, msg): + from sglang.srt.disaggregation.common.staging_handler import ( + handle_staging_req, + ) + + room = int(msg[1].decode("ascii")) + session_id = msg[4].decode("ascii") + handler = self._staging_handler + assert ( + handler is not None + ), "STAGING_REQ received before staging handler initialized" + decode_req = handler._room_to_decode_req.get(room) + if decode_req is None: + logger.warning( + "STAGING_REQ received for unregistered room=%s, skipping", + room, + ) + return + prefill_tp = decode_req.kv_receiver.prefill_info.attn_tp_size + handle_staging_req( + msg, + self._staging_ctx.allocator, + self.kv_args, + self.attn_tp_size, + prefill_tp, + getattr(self, "kv_buffer_tensors", None), + self._staging_ctx.room_receivers, + self._staging_ctx.room_bootstrap, + ) + + receiver = self._staging_ctx.room_receivers.get(room) + if receiver is not None: + handler.register_wm_subscriber(receiver, session_id) + + def _is_watermark_ready( + self, session_id: str, alloc_round: int, alloc_end: int + ) -> bool: + from sglang.srt.disaggregation.common.staging_handler import ( + is_watermark_ready, + ) + + return is_watermark_ready(self._staging_ctx, session_id, alloc_round, alloc_end) + + def _try_create_staging_strategy(self, staging_buffer): + if not self.enable_staging or self.kv_buffer_tensors is None: + return None + from 
sglang.srt.disaggregation.common.staging_handler import ( + PrefillStagingStrategy, + ) + + return PrefillStagingStrategy(self, staging_buffer) + + def _send_chunk_ready(self, req, chunk_idx, kv_chunk, prefill_unique_rank): + """Notify decode that a non-last staging chunk RDMA is complete.""" + try: + na = NetworkAddress(req.endpoint, req.dst_port) + self._connect( + na.to_tcp(), + is_ipv6=na.is_ipv6, + ).send_multipart( + [ + b"CHUNK_READY", + str(req.room).encode("ascii"), + str(chunk_idx).encode("ascii"), + str(kv_chunk.index_slice.start).encode("ascii"), + str(len(kv_chunk.prefill_kv_indices)).encode("ascii"), + req.mooncake_session_id.encode("ascii"), + str(prefill_unique_rank).encode("ascii"), + ] + ) + except Exception: + pass + + def _do_staging_transfer( + self, + staging_strategy, + kv_chunk, + req, + target_info, + chunked_dst_kv_indice, + executor, + queue, + prefill_unique_rank, + ): + """Execute staging transfer for one chunk. Returns (ret, deferred). + + Handles readiness check, transfer, fallback, and CHUNK_READY notification. + deferred=True means caller should re-enqueue and break. + """ + _tp = self.attn_tp_rank + ready, chunk_idx, c_offset, _, _ = staging_strategy.check_ready( + req, + kv_chunk.index_slice.start, + len(kv_chunk.prefill_kv_indices), + ) + if not ready: + from sglang.srt.disaggregation.common.staging_buffer import StagingAllocator + + if c_offset == StagingAllocator.ALLOC_OVERSIZED: + raise RuntimeError( + f"[Staging] Chunk staging allocation permanently failed: " + f"chunk exceeds ring buffer total size (room={kv_chunk.room}). " + f"Increase SGLANG_DISAGG_STAGING_POOL_SIZE_MB." 
+ ) + queue.put(kv_chunk) + return (-1, True) + + ret = staging_strategy.transfer( + req.mooncake_session_id, + kv_chunk.prefill_kv_indices, + target_info.staging.base_ptr + c_offset, + target_info.staging.total_size - c_offset, + target_info, + ) + if ret == -1: + logger.warning( + f"[Staging][tp{_tp}] Falling back to per-token slice path " + f"(room={kv_chunk.room})" + ) + ret = self.send_kvcache_slice( + req.mooncake_session_id, + kv_chunk.prefill_kv_indices, + target_info.dst_kv_ptrs, + chunked_dst_kv_indice, + target_info.dst_tp_rank, + target_info.dst_attn_tp_size, + target_info.dst_kv_item_len, + executor, + ) + elif ret == 0 and not kv_chunk.is_last_chunk: + self._send_chunk_ready(req, chunk_idx, kv_chunk, prefill_unique_rank) + return (ret, False) + + def _prefetch_staging_reqs(self, room: int): + if not self.enable_staging or self.kv_buffer_tensors is None: + return + + room_infos = self.transfer_infos.get(room, {}) + needs_staging = any( + not tinfo.is_dummy + and self.decode_kv_args_table.get(tinfo.mooncake_session_id) is not None + and self.decode_kv_args_table[tinfo.mooncake_session_id].dst_attn_tp_size + != self.attn_tp_size + for tinfo in room_infos.values() + ) + if not needs_staging: + return + + from sglang.srt.disaggregation.common.staging_handler import ( + prefetch_staging_reqs, + ) + + prefetch_staging_reqs( + room, + self.transfer_infos, + self.kv_buffer_tensors, + self.server_args.chunked_prefill_size, + self._staging_ctx.prefetch_requested, + self._staging_ctx.prefetch_sockets, + ) + + def send_kvcache_staged( + self, + mooncake_session_id: str, + prefill_kv_indices: npt.NDArray[np.int32], + dst_staging_ptr: int, + dst_staging_size: int, + dst_tp_rank: int, + dst_attn_tp_size: int, + dst_kv_item_len: int, + staging_buffer=None, + ) -> int: + """Transfer KV cache via staging buffers (gather -> bulk RDMA -> scatter on decode).""" + from sglang.srt.disaggregation.common.staging_buffer import ( + compute_head_slice_params, + 
compute_staging_layout, + resolve_total_kv_heads, + ) + + if self.kv_buffer_tensors is None or staging_buffer is None: + return -1 + + k_buffers = self.kv_buffer_tensors["k_buffers"] + v_buffers = self.kv_buffer_tensors["v_buffers"] + page_size = self.kv_buffer_tensors["page_size"] + num_layers = len(k_buffers) + head_dim = k_buffers[0].shape[-1] + dtype_size = k_buffers[0].element_size() + + total_kv_heads = resolve_total_kv_heads(self.kv_args, self.attn_tp_size) + + local_tp_rank = self.kv_args.engine_rank % self.attn_tp_size + src_head_start, num_heads_to_send, _, _ = compute_head_slice_params( + self.attn_tp_size, + dst_attn_tp_size, + local_tp_rank, + dst_tp_rank, + total_kv_heads, + ) + + num_tokens = len(prefill_kv_indices) * page_size + per_layer_bytes = num_tokens * num_heads_to_send * head_dim * dtype_size + per_rank_bytes = per_layer_bytes * num_layers * 2 + + num_writers, writer_rank_bytes, total_staging_needed = compute_staging_layout( + self.attn_tp_size, + dst_attn_tp_size, + dst_tp_rank, + total_kv_heads, + num_tokens, + head_dim * dtype_size, + num_layers, + ) + writer_idx = local_tp_rank % num_writers if num_writers > 1 else 0 + rank_offset = sum(writer_rank_bytes[:writer_idx]) + + if not staging_buffer.fits(per_rank_bytes): + logger.warning( + f"Prefill staging too small for {per_rank_bytes} bytes, falling back" + ) + return -1 + if dst_staging_size < total_staging_needed: + logger.warning( + f"Decode staging too small: need {total_staging_needed} bytes " + f"({num_writers if self.attn_tp_size > dst_attn_tp_size else 1} writers " + f"x {per_rank_bytes} bytes/rank), have {dst_staging_size}, falling back" + ) + return -1 + + from sglang.srt.disaggregation.common.staging_buffer import ( + gather_all_layers_to_staging, + ) + + gather_all_layers_to_staging( + k_buffers, + v_buffers, + prefill_kv_indices, + staging_buffer, + src_head_start, + num_heads_to_send, + page_size, + self.kv_args.gpu_id, + ) + + dst_write_ptr = dst_staging_ptr + rank_offset + 
ret = self._transfer_data( + mooncake_session_id, + [(staging_buffer.get_ptr(), dst_write_ptr, per_rank_bytes)], + ) + if ret != 0: + raise RuntimeError( + f"[Staging] Bulk RDMA transfer failed with ret={ret}. " + f"src_ptr=0x{staging_buffer.get_ptr():x}, " + f"dst_ptr=0x{dst_write_ptr:x}, size={per_rank_bytes}. " + f"The decode staging buffer may not be properly registered." + ) + return ret + def _transfer_data(self, mooncake_session_id, transfer_blocks): if not transfer_blocks: return 0 @@ -371,6 +702,49 @@ def send_kvcache( executor=executor, ) + def send_kvcache_hisparse( + self, + mooncake_session_id: str, + prefill_kv_indices: npt.NDArray[np.int32], + dst_kv_ptrs: list[int], + dst_kv_indices: npt.NDArray[np.int32], + page_index_slice: slice, + executor: concurrent.futures.ThreadPoolExecutor, + ): + """HiSparse transfer: prefill page_size > decode host page_size=1. + + Receives page-level prefill_kv_indices and the full token-level + dst_kv_indices. Expands both to token granularity before transfer. 
+ """ + page_size = self.kv_args.page_size + per_token_item_lens = [il // page_size for il in self.kv_args.kv_item_lens] + + # Expand page-level src indices to token-level + base = np.repeat(prefill_kv_indices * page_size, page_size) + offsets = np.tile(np.arange(page_size, dtype=np.int32), len(prefill_kv_indices)) + expanded_src = base + offsets + + # Expand page-level index_slice to token-level for dst + token_start = page_index_slice.start * page_size + token_end = min(page_index_slice.stop * page_size, len(dst_kv_indices)) + expanded_dst = dst_kv_indices[token_start:token_end] + + # Clip src to match dst length (last page may be partial) + expanded_src = expanded_src[: len(expanded_dst)] + + logger.debug( + f"Send KVCache for hisparse: {expanded_src.shape} -> {expanded_dst.shape}" + ) + return self._send_kvcache_generic( + mooncake_session_id=mooncake_session_id, + src_data_ptrs=self.kv_args.kv_data_ptrs, + dst_data_ptrs=dst_kv_ptrs, + item_lens=per_token_item_lens, + prefill_data_indices=expanded_src, + dst_data_indices=expanded_dst, + executor=executor, + ) + def send_kvcache_slice( self, mooncake_session_id: str, @@ -770,11 +1144,22 @@ def sync_status_to_decode_endpoint( ) def transfer_worker( - self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor + self, + queue: FastQueue, + executor: concurrent.futures.ThreadPoolExecutor, + staging_buffer=None, ): + staging_strategy = None + while True: try: kv_chunk: TransferKVChunk = queue.get() + if ( + self.enable_staging + and staging_strategy is None + and staging_buffer is not None + ): + staging_strategy = self._try_create_staging_strategy(staging_buffer) reqs_to_be_processed = ( self.transfer_infos[kv_chunk.room].values() if kv_chunk.room in self.transfer_infos @@ -788,6 +1173,9 @@ def transfer_worker( + self.pp_rank * self.attn_cp_size + self.attn_cp_rank ) + # When staging transfer is not yet ready (watermark/allocation pending), + # the chunk is re-enqueued and we break out of the req loop 
to retry later. + staging_deferred = False for req in reqs_to_be_processed: if not req.is_dummy: # Early exit if the request has failed @@ -828,13 +1216,42 @@ def transfer_worker( self.attn_tp_size == target_rank_registration_info.dst_attn_tp_size ): - ret = self.send_kvcache( - req.mooncake_session_id, - kv_chunk.prefill_kv_indices, - target_rank_registration_info.dst_kv_ptrs, + if target_rank_registration_info.enable_hisparse: + ret = self.send_kvcache_hisparse( + req.mooncake_session_id, + kv_chunk.prefill_kv_indices, + target_rank_registration_info.dst_kv_ptrs, + req.dst_kv_indices, + kv_chunk.index_slice, + executor, + ) + else: + ret = self.send_kvcache( + req.mooncake_session_id, + kv_chunk.prefill_kv_indices, + target_rank_registration_info.dst_kv_ptrs, + chunked_dst_kv_indice, + executor, + ) + elif ( + self.enable_staging + and staging_strategy is not None + and target_rank_registration_info.staging is not None + ): + ret, deferred = self._do_staging_transfer( + staging_strategy, + kv_chunk, + req, + target_rank_registration_info, chunked_dst_kv_indice, executor, + queue, + prefill_unique_rank, ) + if deferred: + staging_deferred = True + # Chunk re-enqueued; stop processing remaining reqs for this chunk + break else: ret = self.send_kvcache_slice( req.mooncake_session_id, @@ -857,7 +1274,8 @@ def transfer_worker( ) self.record_failure( kv_chunk.room, - f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}", + f"Failed to send kv chunk of {kv_chunk.room} to " + f"{NetworkAddress(req.endpoint, req.dst_port).to_host_port_str()}", ) self.update_status(kv_chunk.room, KVPoll.Failed) self.sync_status_to_decode_endpoint( @@ -908,6 +1326,9 @@ def transfer_worker( if kv_chunk.is_last_chunk and req.room in self.request_status: self.update_status(req.room, KVPoll.Success) + if staging_deferred: + continue + if ( kv_chunk.room not in self.request_status or self.check_status(kv_chunk.room) == KVPoll.Success @@ -928,6 +1349,50 @@ def 
bootstrap_thread(): while True: waiting_req_bytes = self.server_socket.recv_multipart() room = waiting_req_bytes[0].decode("ascii") + # Staging: decode reports consumption watermark back to prefill + if room == "WATERMARK": + wm_round = int(waiting_req_bytes[1].decode("ascii")) + wm_tail = int(waiting_req_bytes[2].decode("ascii")) + wm_session = ( + waiting_req_bytes[3].decode("ascii") + if len(waiting_req_bytes) > 3 + else "" + ) + with self._staging_ctx.watermark_cv: + prev = self._staging_ctx.remote_watermarks.get( + wm_session, (0, 0) + ) + if (wm_round, wm_tail) > prev: + self._staging_ctx.remote_watermarks[wm_session] = ( + wm_round, + wm_tail, + ) + self._staging_ctx.watermark_cv.notify_all() + continue + # Staging: decode replies with allocated staging offset + if room == "STAGING_RSP": + stg_room = int(waiting_req_bytes[1].decode("ascii")) + stg_chunk_idx = int(waiting_req_bytes[2].decode("ascii")) + stg_offset = int(waiting_req_bytes[3].decode("ascii")) + stg_round = int(waiting_req_bytes[4].decode("ascii")) + stg_end = int(waiting_req_bytes[5].decode("ascii")) + stg_session = waiting_req_bytes[6].decode("ascii") + room_infos = self.transfer_infos.get(stg_room, {}) + tinfo = room_infos.get(stg_session) + if tinfo is not None: + if tinfo.staging is None: + tinfo.staging = StagingTransferInfo() + tinfo.staging.set_chunk( + stg_chunk_idx, stg_offset, stg_round, stg_end + ) + else: + logger.warning( + "STAGING_RSP RECV but tinfo=None room=%s chunk=%d session=%s", + stg_room, + stg_chunk_idx, + stg_session, + ) + continue mooncake_session_id = waiting_req_bytes[3].decode("ascii") if room == "None": self.decode_kv_args_table[mooncake_session_id] = ( @@ -965,6 +1430,42 @@ def decode_thread(): self._handle_aux_data(msg) continue + # Staging: prefill notifies a chunk written to staging buffer + if msg[0] == b"CHUNK_READY": + room = int(msg[1].decode("ascii")) + chunk_idx = int(msg[2].decode("ascii")) + page_start = int(msg[3].decode("ascii")) + num_pages = 
int(msg[4].decode("ascii")) + session_id = msg[5].decode("ascii") + self._chunk_writer_counts[room][chunk_idx].append( + (page_start, num_pages, session_id) + ) + handler = self._staging_handler + assert ( + handler is not None + ), "CHUNK_READY received before staging handler initialized" + writers_arrived = len(self._chunk_writer_counts[room][chunk_idx]) + decode_req = handler._room_to_decode_req.get(room) + if decode_req is None: + logger.warning( + "CHUNK_READY received for unregistered room=%s chunk=%d, skipping", + room, + chunk_idx, + ) + continue + num_writers = handler.num_writers_for(decode_req) + if writers_arrived >= num_writers: + handler.submit_chunk_scatter( + room, chunk_idx, page_start, num_pages + ) + del self._chunk_writer_counts[room][chunk_idx] + continue + + # Staging: prefill pre-requests staging allocation before forward + if msg[0] == b"STAGING_REQ": + self._handle_staging_req(msg) + continue + bootstrap_room, status, prefill_rank = msg status = int(status.decode("ascii")) bootstrap_room = int(bootstrap_room.decode("ascii")) @@ -980,6 +1481,11 @@ def decode_thread(): self.prefill_response_tracker[bootstrap_room] ) if arrived_response_num == expected_response_num: + if self.enable_staging: + handler = self._staging_handler + if handler.is_staging_room(bootstrap_room): + handler.submit_last_scatter_async(bootstrap_room) + self._chunk_writer_counts.pop(bootstrap_room, None) self.update_status(bootstrap_room, KVPoll.Success) elif status == KVPoll.Failed: self.record_failure( @@ -1237,15 +1743,10 @@ def __init__( mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): self.session_id = mgr.get_session_id() - self.conclude_state = None self.init_time = None - super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) - - self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room) - self.kv_mgr.update_status(self.bootstrap_room, 
KVPoll.WaitingForInput) + super().__init__(mgr, bootstrap_addr, bootstrap_room) def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: @@ -1275,6 +1776,18 @@ def _register_kv_args(self): dst_tp_rank = str(tp_rank).encode("ascii") dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii") dst_kv_item_len = str(kv_item_len).encode("ascii") + enable_hisparse = b"1" if self.kv_mgr.server_args.enable_hisparse else b"0" + + if ( + self.kv_mgr.enable_staging + and self.kv_mgr._staging_ctx.allocator is not None + ): + _alloc = self.kv_mgr._staging_ctx.allocator + packed_staging_base_ptr = struct.pack("Q", _alloc.get_base_ptr()) + staging_total_size_str = str(_alloc.get_total_size()).encode("ascii") + else: + packed_staging_base_ptr = b"" + staging_total_size_str = b"" sock, lock = self._connect_to_bootstrap_server(bootstrap_info) with lock: @@ -1292,10 +1805,19 @@ def _register_kv_args(self): dst_kv_item_len, packed_state_item_lens, packed_state_dim_per_tensor, + enable_hisparse, + packed_staging_base_ptr, + staging_total_size_str, ] ) def init( + self, + prefill_dp_rank: int, + ): + super().init(prefill_dp_rank) + + def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, @@ -1309,6 +1831,15 @@ def init( self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) return + if ( + self.kv_mgr.enable_staging + and self.kv_mgr._staging_ctx.allocator is not None + ): + self.chunk_staging_infos = [] + self.kv_mgr.register_staging_room_bootstrap( + self.bootstrap_room, self.bootstrap_infos, self + ) + for bootstrap_info in self.bootstrap_infos: sock, lock = self._connect_to_bootstrap_server(bootstrap_info) is_dummy = bootstrap_info["is_dummy"] diff --git a/python/sglang/srt/disaggregation/mooncake/utils.py b/python/sglang/srt/disaggregation/mooncake/utils.py index 63e7bd29de4e..279cf194de87 100644 --- a/python/sglang/srt/disaggregation/mooncake/utils.py +++ b/python/sglang/srt/disaggregation/mooncake/utils.py @@ -23,7 
+23,7 @@ logger = logging.getLogger(__name__) # Global constants for custom memory pool types -SUPPORTED_MOONCAKE_CUSTOM_MEM_POOL_TYPES = ["NVLINK", "BAREX", "INTRA_NVLINK"] +SUPPORTED_MOONCAKE_CUSTOM_MEM_POOL_TYPES = ["NVLINK", "BAREX", "INTRA_NODE_NVLINK"] def init_mooncake_custom_mem_pool( diff --git a/python/sglang/srt/disaggregation/mori/conn.py b/python/sglang/srt/disaggregation/mori/conn.py index 89b11f03e94e..70154f9e981c 100644 --- a/python/sglang/srt/disaggregation/mori/conn.py +++ b/python/sglang/srt/disaggregation/mori/conn.py @@ -972,10 +972,8 @@ def failure_exception(self): raise RuntimeError(failure_reason) def abort(self): - reason = "Aborted by AbortReq." - self.kv_mgr.record_failure(self.bootstrap_room, reason) - self._notify_decode(KVPoll.Failed, reason) - self.conclude_state = KVPoll.Failed + super().abort() + self._notify_decode(KVPoll.Failed, "Aborted by AbortReq.") class MoriKVReceiver(CommonKVReceiver): @@ -985,17 +983,18 @@ def __init__( mgr: MoriKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): - super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) - self.conclude_state: Optional[KVPoll] = None + super().__init__(mgr, bootstrap_addr, bootstrap_room) self.init_time: Optional[float] = None - if self.bootstrap_room is None or self.bootstrap_infos is None: + + def init( + self, + prefill_dp_rank: int, + ): + super().init(prefill_dp_rank) + if self.bootstrap_room is None: return - self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room) self.kv_mgr.room_to_bootstrap_addr[self.bootstrap_room] = self.bootstrap_addr - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput) - self._register_kv_args() def _register_kv_args(self): if self.bootstrap_infos is None: @@ -1029,7 +1028,7 @@ def _register_kv_args(self): ] ) - def init( + def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, @@ 
-1105,10 +1104,8 @@ def failure_exception(self): def abort(self): if self.bootstrap_room is None: return - reason = "Aborted by AbortReq." - self.kv_mgr.record_failure(self.bootstrap_room, reason) + super().abort() self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) - self.conclude_state = KVPoll.Failed self.clear() diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 764fd9e42689..38a4d15cf048 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -957,20 +957,18 @@ def __init__( mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - prefill_dp_rank: Optional[int] = None, ): self.started_transfer = False - self.conclude_state = None - super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) - - # Track this room with its bootstrap address for heartbeat monitoring - if hasattr(self.kv_mgr, "addr_to_rooms_tracker"): - self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add( - self.bootstrap_room - ) + super().__init__(mgr, bootstrap_addr, bootstrap_room) self.init_time = None def init( + self, + prefill_dp_rank: int, + ): + super().init(prefill_dp_rank) + + def send_metadata( self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, @@ -1026,7 +1024,7 @@ def poll(self) -> KVPoll: self.conclude_state = status return status if not self.started_transfer: - return KVPoll.WaitingForInput # type: ignore + return status now = time.time() elapsed = now - self.init_time diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 7d97f123bd1c..8eadf8195421 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -42,6 +42,7 @@ poll_and_all_reduce_attn_cp_tp_group, prepare_abort, ) +from sglang.srt.environ import envs from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, 
FINISH_LENGTH, @@ -201,6 +202,21 @@ def _init_kv_manager(self) -> CommonKVManager: self.scheduler.server_args, self.is_mla_backend, ) + # Pass KV pool tensor refs to the manager for GPU gather (staging mode) + if ( + envs.SGLANG_DISAGG_STAGING_BUFFER.get() + and hasattr(kv_manager, "set_kv_buffer_tensors") + and not self.is_mla_backend + ): + kv_pool = self.token_to_kv_pool + if hasattr(kv_pool, "full_kv_pool"): + kv_pool = kv_pool.full_kv_pool + if hasattr(kv_pool, "k_buffer") and hasattr(kv_pool, "v_buffer"): + kv_manager.set_kv_buffer_tensors( + kv_pool.k_buffer, + kv_pool.v_buffer, + kv_pool.page_size, + ) return kv_manager def add(self, req: Req, num_kv_heads: int) -> None: @@ -336,6 +352,17 @@ class SchedulerDisaggregationPrefillMixin: Mixin for Scheduler to handle disaggregation prefill """ + def maybe_prefetch_staging_for_batch(self: Scheduler, batch: ScheduleBatch) -> None: + """Pre-send STAGING_REQ so decode allocates staging during GPU forward.""" + kv_mgr = self.disagg_prefill_bootstrap_queue.kv_manager + prefetch = getattr(kv_mgr, "_prefetch_staging_reqs", None) + if prefetch is None: + return + for req in batch.reqs: + room = getattr(req, "bootstrap_room", None) + if room is not None and room in kv_mgr.transfer_infos: + prefetch(room) + def get_next_disagg_prefill_batch_to_run( self: Scheduler, ) -> Optional[ScheduleBatch]: @@ -356,6 +383,7 @@ def get_next_disagg_prefill_batch_to_run( @torch.no_grad() def event_loop_normal_disagg_prefill(self: Scheduler) -> None: """A normal scheduler loop for prefill worker in disaggregation mode.""" + self.enable_staging = envs.SGLANG_DISAGG_STAGING_BUFFER.get() while True: # Receive requests @@ -371,6 +399,8 @@ def event_loop_normal_disagg_prefill(self: Scheduler) -> None: # Launch the current batch if batch: + if self.enable_staging: + self.maybe_prefetch_staging_for_batch(batch) result = self.run_batch(batch) self.process_batch_result(batch, result) else: @@ -384,6 +414,7 @@ def 
event_loop_normal_disagg_prefill(self: Scheduler) -> None: @torch.no_grad() def event_loop_overlap_disagg_prefill(self: Scheduler) -> None: self.result_queue = deque() + self.enable_staging = envs.SGLANG_DISAGG_STAGING_BUFFER.get() while True: # Receive requests @@ -399,6 +430,8 @@ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None: # Launch the current batch if batch: + if self.enable_staging: + self.maybe_prefetch_staging_for_batch(batch) batch_result = self.run_batch(batch) self.result_queue.append((batch.copy(), batch_result)) else: diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index b7b3b0238861..d7956a60487b 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -80,6 +80,30 @@ def poll_and_all_reduce_attn_cp_tp_group( return tensor_to_reduce.tolist() +def poll_and_all_reduce_with_staging( + decode_reqs, staging_handler, gloo_group: dist.ProcessGroup +): + """Staging-aware polling: advance scatter, demote incomplete transfers, all_reduce.""" + from sglang.srt.disaggregation.base import KVPoll + + for decode_req in decode_reqs: + if decode_req.kv_receiver.require_staging and not staging_handler.is_done( + decode_req + ): + staging_handler.advance_scatter(decode_req) + + raw_polls = [int(dr.kv_receiver.poll()) for dr in decode_reqs] + for i, decode_req in enumerate(decode_reqs): + if raw_polls[i] == int(KVPoll.Success): + if decode_req.kv_receiver.require_staging and not staging_handler.is_done( + decode_req + ): + raw_polls[i] = int(KVPoll.Transferring) + poll_tensor = torch.tensor(raw_polls, dtype=torch.uint8, device="cpu") + dist.all_reduce(poll_tensor, op=dist.ReduceOp.MIN, group=gloo_group) + return poll_tensor.tolist() + + ######################### # Metadata Buffers ######################### @@ -227,6 +251,9 @@ def set_buf(self, req: Req): self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] 
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens + self.cached_tokens[req.metadata_buffer_index][1] = req.cached_tokens_device + self.cached_tokens[req.metadata_buffer_index][2] = req.cached_tokens_host + self.cached_tokens[req.metadata_buffer_index][3] = req.cached_tokens_storage if req.return_logprob: if req.output_token_logprobs_val: # not none or empty list self.output_token_logprobs_val[req.metadata_buffer_index][0] = ( diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index d7e0a314672f..c9e055bb26dd 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -63,6 +63,7 @@ def __init__( self._IS_CAPTURING = False self.disabled = True # This can be modified in-place by context manager in piecewise cuda graph runner self.original_disabled = True # To store the original state + self.use_amd_deterministic_impl = _use_amd_deterministic_impl() if not ops.IS_CUSTOM_AR_AVAILABLE: # disable because of missing custom allreduce library @@ -269,65 +270,36 @@ def should_custom_ar(self, inp: torch.Tensor): return False if _is_hip: + if self.use_amd_deterministic_impl: + return True if self.full_nvlink: return inp_size <= self.max_size return False return False - # all reduce, assuming inp tensor is IPC registered with register_buffer, - # or, in the context of cuda graphs, register_graph_buffers - def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): - if out is None: - out = torch.empty_like(inp) - ops.all_reduce_reg(self._ptr, inp, out) - return out - - # all reduce, assuming inp tensor is NOT IPC registered - def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): - if out is None: - out = torch.empty_like(inp) - ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) - return out - - def all_reduce( - 
self, - inp: torch.Tensor, - *, - out: torch.Tensor = None, - registered: bool = False, - ): - """Performs an out-of-place all reduce. - - If registered is True, this assumes inp's pointer is already - IPC-registered. Otherwise, inp is first copied into a pre-registered - buffer. - """ - if out is None: - out = torch.empty_like(inp) - if registered: - ops.all_reduce(self._ptr, inp, out, 0, 0) - else: - ops.all_reduce( - self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size - ) - return out - - def deterministic_all_reduce( - self, - inp: torch.Tensor, - *, - out: torch.Tensor = None, - registered: bool = False, - ): - """Deterministic all-reduce using 1-stage kernel with fixed ordering (AMD only).""" - if out is None: - out = torch.empty_like(inp) - if registered: - ops.deterministic_all_reduce_reg(self._ptr, inp, out) - else: - reg_buffer = self.buffer.view(inp.dtype)[: inp.numel()] - ops.deterministic_all_reduce_unreg(self._ptr, inp, reg_buffer, out) + def _all_reduce_impl(self, inp: torch.Tensor, registered: bool): + out = torch.empty_like(inp) + if not _is_hip: # CUDA-like + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce( + self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size + ) + elif self.use_amd_deterministic_impl: + inp_size = inp.numel() * inp.element_size() + if inp_size < self.max_size: + reg_buffer = self.buffer.view(inp.dtype)[: inp.numel()] + ops.deterministic_all_reduce_unreg(self._ptr, inp, reg_buffer, out) + else: + self.register_buffer(inp) + ops.deterministic_all_reduce_reg(self._ptr, inp, out) + else: # normal AMD ROCm path + if registered: + ops.all_reduce_reg(self._ptr, inp, out) + else: + ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -337,35 +309,20 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: return None if self._IS_CAPTURING: if 
torch.cuda.is_current_stream_capturing(): - if _is_hip: - if self.tms_cudagraph: - return self.all_reduce_unreg(input) - return self.all_reduce_reg(input) - else: - return self.all_reduce(input, registered=not self.tms_cudagraph) + return self._all_reduce_impl(input, registered=not self.tms_cudagraph) else: # Could be warmup OR piecewise cuda graph split op execution. # In piecewise cuda graph, split ops run eagerly outside the graph # but _IS_CAPTURING is still True. We need to do real all-reduce. if is_in_piecewise_cuda_graph(): # Split op execution - do real all-reduce - if _is_hip: - return self.all_reduce_unreg(input) - else: - return self.all_reduce(input, registered=False) + return self._all_reduce_impl(input, registered=False) else: # True warmup - mimic the allocation pattern since custom # allreduce is out-of-place. return torch.zeros_like(input) else: - if _is_hip: - # note: outside of cuda graph context, - # custom allreduce incurs a cost of cudaMemcpy, which should - # be small(<=1% of overall latency) compared to the performance - # gains of using custom kernels - return self.all_reduce_unreg(input) - else: - return self.all_reduce(input, registered=False) + return self._all_reduce_impl(input, registered=False) def close(self): if not self.disabled and self._ptr: @@ -382,7 +339,7 @@ def __del__(self): def dispatch_custom_allreduce(): """Return the CustomAllreduce class to use (aiter on ROCm if enabled). - On AMD with 1-stage AR enabled, use sglang's CustomAllreduce (has deterministic_all_reduce method). + On AMD with 1-stage AR enabled, use sglang's CustomAllreduce. Otherwise use AiterCustomAllreduce if available. Set SGLANG_USE_JIT_ALL_REDUCE=1 to use the JIT-compiled v2 implementation. 
@@ -414,15 +371,9 @@ def dispatch_custom_allreduce(): else: logger.debug("[AR] All-reduce: default") - # Check if 1-stage AR should be used - if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set(): - use_1stage = envs.SGLANG_USE_1STAGE_ALLREDUCE.get() - else: - use_1stage = envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get() - # On AMD with 1-stage AR, use sglang's CustomAllreduce # (AiterCustomAllreduce doesn't have deterministic_all_reduce method) - if use_1stage: + if _use_amd_deterministic_impl(): return CustomAllreduce if get_bool_env_var("SGLANG_USE_AITER_AR", default="true"): @@ -446,3 +397,12 @@ def dispatch_custom_allreduce(): return CustomAllreduce return CustomAllreduce + + +def _use_amd_deterministic_impl() -> bool: + if not _is_hip: # CUDA is always deterministic + return False + if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set(): + return envs.SGLANG_USE_1STAGE_ALLREDUCE.get() + else: + return envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get() diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index 660582ad3730..eccbc872e11e 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -31,7 +31,6 @@ def __init__( group: Union[ProcessGroup, StatelessProcessGroup], device: Union[int, str, torch.device], library_path: Optional[str] = None, - use_current_stream: bool = False, ): """ Args: @@ -62,7 +61,6 @@ def __init__( if self.world_size == 1: self.available = False self.disabled = True - self.stream = None return try: self.nccl = NCCLLibrary(library_path) @@ -71,12 +69,10 @@ def __init__( # e.g. 
in a non-GPU environment self.available = False self.disabled = True - self.stream = None return self.available = True self.disabled = False - self.use_current_stream = use_current_stream self.nccl_version = self.nccl.ncclGetRawVersion() if self.rank == 0: @@ -113,12 +109,13 @@ def __init__( self.comm: ncclComm_t = self.nccl.ncclCommInitRank( self.world_size, self.unique_id, self.rank ) - self.stream = torch.cuda.Stream() + warmup_stream = torch.cuda.Stream() # A small all_reduce for warmup. - data = torch.zeros(1, device=device) - self.all_reduce(data) - self.stream.synchronize() + with torch.cuda.stream(warmup_stream): + data = torch.zeros(1, device=device) + self.all_reduce(data) + warmup_stream.synchronize() del data # by default it is disabled, e.g. in profiling models and prefill phase. @@ -126,24 +123,11 @@ def __init__( # when we are using CUDA graph. self.disabled = True - def _resolve_stream(self, stream: Optional[torch.cuda.Stream]): - """Return the stream to use for NCCL calls. 
+ def _resolve_stream(self) -> torch.cuda.Stream: + """Return the current device stream used for NCCL calls.""" + return get_current_device_stream_fast() - Behavior mirrors the previous inline logic: - - if an explicit stream is provided, return it - - if stream is None and self.use_current_stream is True, return - torch.cuda.current_stream() - - otherwise return the communicator's default stream (self.stream) - """ - if stream is not None: - return stream - if self.use_current_stream: - return get_current_device_stream_fast() - return self.stream - - def all_reduce( - self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None - ): + def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM): if self.disabled: return # nccl communicator created on a specific device @@ -153,7 +137,7 @@ def all_reduce( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() self.nccl.ncclAllReduce( buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()), @@ -169,7 +153,6 @@ def outplace_all_reduce( in_tensor: torch.Tensor, out_tensor: Optional[torch.Tensor] = None, op: ReduceOp = ReduceOp.SUM, - stream=None, ) -> Optional[torch.Tensor]: if self.disabled: return None @@ -181,7 +164,7 @@ def outplace_all_reduce( if out_tensor is None: out_tensor = torch.empty_like(in_tensor) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() self.nccl.ncclAllReduce( buffer_type(in_tensor.data_ptr()), # sendbuff buffer_type(out_tensor.data_ptr()), # recvbuff - DIFFERENT pointer @@ -197,7 +180,6 @@ def all_gather( self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, - stream=None, sizes: Optional[list[int]] = None, ): if self.disabled: @@ -209,7 +191,7 @@ def all_gather( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}" ) - stream = 
self._resolve_stream(stream) + stream = self._resolve_stream() if sizes is not None: split_offset = 0 @@ -242,7 +224,7 @@ def cp_all_gather_into_tensor( self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, - stream=None, + stream: torch.cuda.Stream, sizes: Optional[list[int]] = None, ): """ @@ -256,7 +238,6 @@ def cp_all_gather_into_tensor( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}" ) - stream = self._resolve_stream(stream) self.nccl.ncclAllGather( buffer_type(input_tensor.data_ptr()), buffer_type(output_tensor.data_ptr()), @@ -271,7 +252,6 @@ def reduce_scatter( output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, - stream=None, sizes: Optional[list[int]] = None, ): if self.disabled: @@ -283,7 +263,7 @@ def reduce_scatter( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() if sizes is not None: split_offset = 0 @@ -314,14 +294,14 @@ def reduce_scatter( cudaStream_t(stream.cuda_stream), ) - def send(self, tensor: torch.Tensor, dst: int, stream=None): + def send(self, tensor: torch.Tensor, dst: int): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() self.nccl.ncclSend( buffer_type(tensor.data_ptr()), tensor.numel(), @@ -331,14 +311,14 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None): cudaStream_t(stream.cuda_stream), ) - def recv(self, tensor: torch.Tensor, src: int, stream=None): + def recv(self, tensor: torch.Tensor, src: int): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) - 
stream = self._resolve_stream(stream) + stream = self._resolve_stream() self.nccl.ncclRecv( buffer_type(tensor.data_ptr()), tensor.numel(), @@ -348,14 +328,14 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): cudaStream_t(stream.cuda_stream), ) - def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + def broadcast(self, tensor: torch.Tensor, src: int): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() if src == self.rank: sendbuff = buffer_type(tensor.data_ptr()) @@ -387,25 +367,17 @@ def group_end(self): self.nccl.ncclGroupEnd() @contextmanager - def change_state( - self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None - ): + def change_state(self, enable: Optional[bool] = None): """ - A context manager to change the state of the communicator. + A context manager to change the enabled state of the communicator. 
""" if enable is None: # guess a default value when not specified enable = self.available - if stream is None: - stream = self.stream - old_disable = self.disabled - old_stream = self.stream - - self.stream = stream self.disabled = not enable - yield - - self.disabled = old_disable - self.stream = old_stream + try: + yield + finally: + self.disabled = old_disable diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 72311f9d3ffe..a80c6da5dae6 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -246,7 +246,6 @@ def __init__( use_npu_communicator: bool, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, - pynccl_use_current_stream: bool = False, gloo_timeout: timedelta = timedelta(seconds=120 * 60), ): # Set group info @@ -316,7 +315,6 @@ def __init__( # Import communicators self.use_pynccl = use_pynccl - self.pynccl_use_current_stream = pynccl_use_current_stream self.use_pymscclpp = use_pymscclpp self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce @@ -358,7 +356,6 @@ def __init__( self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, - use_current_stream=pynccl_use_current_stream, ) self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None @@ -533,9 +530,7 @@ def graph_capture( if not pynccl_comm: maybe_pynccl_context = nullcontext() else: - maybe_pynccl_context = pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ) + maybe_pynccl_context = pynccl_comm.change_state(enable=True) pymscclpp_comm = self.pymscclpp_comm maybe_pymscclpp_context: Any @@ -565,26 +560,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.world_size == 1: return input_ - # On AMD, use the deterministic 1-stage kernel when: - # - SGLANG_USE_1STAGE_ALLREDUCE=1 (explicitly enabled), OR - # - 
SGLANG_USE_1STAGE_ALLREDUCE not set AND --enable-deterministic-inference is on - if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set(): - use_1stage_ar = envs.SGLANG_USE_1STAGE_ALLREDUCE.get() - else: - use_1stage_ar = envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get() - use_deterministic_ar = is_hip() and use_1stage_ar - if use_deterministic_ar: - if not input_.is_cpu and self.ca_comm is not None: - inp_size = input_.numel() * input_.element_size() - # Try unregistered mode first (faster for smaller tensors) - if inp_size < self.ca_comm.max_size: - return self.ca_comm.deterministic_all_reduce( - input_, registered=False - ) - # Use registered mode for larger tensors - self.ca_comm.register_buffer(input_) - return self.ca_comm.deterministic_all_reduce(input_, registered=True) - if input_.is_cpu: if is_shm_available(input_.dtype, self.world_size, self.local_size): torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM) @@ -602,9 +577,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return self.npu_communicator.all_reduce(input_) if self.pynccl_comm is not None and self.is_symmetric_memory_enabled(): - with self.pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with self.pynccl_comm.change_state(enable=True): self.pynccl_comm.all_reduce(input_) return input_ @@ -686,6 +659,7 @@ def fused_allreduce_rmsnorm( 512, 1024, 2048, + 2880, 4096, } @@ -720,9 +694,7 @@ def _all_reduce_out_place( assert not pymscclpp_comm.disabled out = pymscclpp_comm.all_reduce(input_) elif outplace_all_reduce_method == "pynccl": - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): out = pynccl_comm.outplace_all_reduce(input_) assert out is not None return out @@ -746,9 +718,7 @@ def _reduce_scatter_tensor( if pynccl_comm is not None and ( not pynccl_comm.disabled or self.is_symmetric_memory_enabled() ): - with pynccl_comm.change_state( - enable=True, 
stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): pynccl_comm.reduce_scatter(output, input) else: torch.distributed.reduce_scatter_tensor( @@ -780,9 +750,7 @@ def reduce_scatterv( world_size = self.world_size pynccl_comm = self.pynccl_comm - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): assert ( pynccl_comm is not None and not pynccl_comm.disabled ), "pynccl is required for reduce_scatterv" @@ -811,9 +779,7 @@ def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): if pynccl_comm is not None and ( not pynccl_comm.disabled or self.is_symmetric_memory_enabled() ): - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): pynccl_comm.all_gather(output, input) else: torch.distributed.all_gather_into_tensor( @@ -827,7 +793,7 @@ def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): reg_all_gather_into_tensor(output, input, group_name=self.unique_name) def cp_all_gather_into_tensor_async( - self, output: torch.Tensor, input: torch.Tensor, stream=None + self, output: torch.Tensor, input: torch.Tensor, stream: torch.cuda.Stream ): """ Implement an asynchronous `allgather` operation on a specified stream. @@ -835,9 +801,6 @@ def cp_all_gather_into_tensor_async( eliminating the CPU-side launch-kernel blocking issue caused by synchronization problems. The specific implementation uses the interface provided by pynccl to remove the synchronization logic of events. 
""" - assert ( - stream is not None - ), f"Invalid params stream ({stream}, Please specify the stream to use when calling cp_all_gather_into_tensor_async.)" pynccl_comm = self.pynccl_comm if pynccl_comm is None or pynccl_comm.disabled: self.all_gather_into_tensor(output, input) @@ -930,9 +893,7 @@ def all_gatherv( world_size = self.world_size pynccl_comm = self.pynccl_comm - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): assert ( pynccl_comm is not None and not pynccl_comm.disabled ), "pynccl is required for all_gatherv" @@ -1439,7 +1400,6 @@ def init_model_parallel_group( use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, use_mscclpp_allreduce: Optional[bool] = None, - pynccl_use_current_stream: bool = True, use_torch_symm_mem_allreduce: Optional[bool] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: @@ -1465,7 +1425,6 @@ def init_model_parallel_group( use_npu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, - pynccl_use_current_stream=pynccl_use_current_stream, ) @@ -1835,7 +1794,6 @@ def initialize_model_parallel( "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" ), group_name="tp", - pynccl_use_current_stream=duplicate_tp_group, ) if duplicate_tp_group: @@ -1851,7 +1809,6 @@ def initialize_model_parallel( "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" ), group_name="pdmux_prefill_tp", - pynccl_use_current_stream=True, ) if _TP.pynccl_comm: _TP.pynccl_comm.disabled = False diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 35ebbf1bcc71..30480ed01e87 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -17,6 +17,8 @@ This file implements python APIs for the inference engine. 
""" +from __future__ import annotations + import asyncio import atexit import dataclasses @@ -47,6 +49,9 @@ import zmq from sglang.srt.elastic_ep.expert_backup_manager import run_expert_backup_manager +from sglang.srt.entrypoints.engine_info_bootstrap_server import ( + EngineInfoBootstrapServer, +) from sglang.srt.entrypoints.EngineBase import EngineBase from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, @@ -78,9 +83,6 @@ from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager_multiitem_mixin import ScoreResult -from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( - parse_remote_instance_transfer_engine_info_from_scheduler_infos, -) from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( @@ -96,8 +98,9 @@ set_prometheus_multiproc_dir, set_ulimit, ) -from sglang.srt.utils.network import get_zmq_socket +from sglang.srt.utils.network import get_zmq_socket, is_port_available from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils.watchdog import SubprocessWatchdog from sglang.version import __version__ logger = logging.getLogger(__name__) @@ -113,6 +116,7 @@ class SchedulerInitResult: scheduler_infos: List[Dict[str, Any]] wait_for_ready: Callable[[], None] = lambda: None wait_for_completion: Callable[[], None] = lambda: None + engine_info_bootstrap_server: Optional[Any] = None def init_tokenizer_manager( @@ -176,6 +180,10 @@ def __init__(self, **kwargs): self.server_args = server_args logger.info(f"{server_args=}") + # Pre-initialize tokenizer_manager so the atexit handler in + # shutdown() won't hit AttributeError. 
+ self.tokenizer_manager = None + # Shutdown the subprocesses automatically when the program exits atexit.register(self.shutdown) @@ -185,6 +193,7 @@ def __init__(self, **kwargs): template_manager, port_args, scheduler_init_result, + subprocess_watchdog, ) = self._launch_subprocesses( server_args=server_args, init_tokenizer_manager_func=self.init_tokenizer_manager_func, @@ -194,12 +203,14 @@ def __init__(self, **kwargs): self.tokenizer_manager = tokenizer_manager self.template_manager = template_manager self._scheduler_init_result = scheduler_init_result + if tokenizer_manager is not None: + tokenizer_manager._subprocess_watchdog = subprocess_watchdog self.port_args = port_args - self.remote_instance_transfer_engine_info = ( - parse_remote_instance_transfer_engine_info_from_scheduler_infos( - scheduler_init_result.scheduler_infos + # Access transfer engine info if bootstrap server is started. + if scheduler_init_result.engine_info_bootstrap_server is not None: + self.remote_instance_transfer_engine_info = ( + scheduler_init_result.engine_info_bootstrap_server.transfer_engine_info ) - ) # Initialize ZMQ sockets context = zmq.Context(2) @@ -505,9 +516,13 @@ def _launch_scheduler_processes( server_args: ServerArgs, port_args: PortArgs, run_scheduler_process_func: Callable, - ) -> SchedulerInitResult: + ) -> Tuple[SchedulerInitResult, Optional[List]]: """Launch scheduler processes using multiprocessing. Override in subclasses for different backends (e.g. Ray). + + Returns: + Tuple of (SchedulerInitResult, scheduler_procs). + scheduler_procs is None for RayEngine (uses Ray actors instead). 
""" scheduler_procs = [] @@ -592,10 +607,13 @@ def wait_for_completion(): f"terminated with {proc.exitcode}" ) - return SchedulerInitResult( - scheduler_infos=scheduler_infos, - wait_for_ready=wait_for_ready, - wait_for_completion=wait_for_completion, + return ( + SchedulerInitResult( + scheduler_infos=scheduler_infos, + wait_for_ready=wait_for_ready, + wait_for_completion=wait_for_completion, + ), + scheduler_procs, ) @classmethod @@ -606,26 +624,53 @@ def _launch_subprocesses( run_scheduler_process_func: Callable, run_detokenizer_process_func: Callable, port_args: Optional[PortArgs] = None, - ) -> Tuple[TokenizerManager, TemplateManager, PortArgs, SchedulerInitResult]: + ) -> Tuple[ + TokenizerManager, + TemplateManager, + PortArgs, + SchedulerInitResult, + Optional[SubprocessWatchdog], + ]: """Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. Returns: - Tuple of (tokenizer_manager, template_manager, port_args, scheduler_init_result). + Tuple of (tokenizer_manager, template_manager, port_args, scheduler_init_result, subprocess_watchdog). """ # Configure global environment configure_logger(server_args) _set_envs_and_config(server_args) server_args.check_server_args() + _set_gc(server_args) # Allocate ports for inter-process communications if port_args is None: port_args = PortArgs.init_new(server_args) logger.info(f"{server_args=}") + # Start the engine info bootstrap server if per-rank info is needed. + engine_info_bootstrap_server = None + if ( + server_args.remote_instance_weight_loader_start_seed_via_transfer_engine + and server_args.node_rank == 0 + ): + bootstrap_port = server_args.engine_info_bootstrap_port + if not is_port_available(bootstrap_port): + raise RuntimeError( + f"engine_info_bootstrap_port {bootstrap_port} is already in use. " + f"When running multiple instances on the same node, each instance must use a " + f"different --engine-info-bootstrap-port." 
+ ) + engine_info_bootstrap_server = EngineInfoBootstrapServer( + host=server_args.host, port=bootstrap_port + ) + # Launch scheduler processes - scheduler_init_result = cls._launch_scheduler_processes( + scheduler_init_result, scheduler_procs = cls._launch_scheduler_processes( server_args, port_args, run_scheduler_process_func ) + scheduler_init_result.engine_info_bootstrap_server = ( + engine_info_bootstrap_server + ) if ( server_args.enable_elastic_expert_backup @@ -645,6 +690,7 @@ def _launch_subprocesses( None, port_args, scheduler_init_result, + None, ) launch_dummy_health_check_server( @@ -657,6 +703,7 @@ def _launch_subprocesses( None, port_args, scheduler_init_result, + None, ) # Launch detokenizer process @@ -687,15 +734,32 @@ def _launch_subprocesses( "max_req_input_len" ] + # Set up subprocess liveness watchdog to detect crashes + # Note: RayEngine returns scheduler_procs=None as it uses Ray actors instead of mp.Process + processes = list(scheduler_procs or []) + names = [f"scheduler_{i}" for i in range(len(processes))] + processes.append(detoken_proc) + names.append("detokenizer") + subprocess_watchdog = SubprocessWatchdog( + processes=processes, process_names=names + ) + subprocess_watchdog.start() + return ( tokenizer_manager, template_manager, port_args, scheduler_init_result, + subprocess_watchdog, ) def shutdown(self): """Shutdown the engine""" + if ( + self.tokenizer_manager is not None + and self.tokenizer_manager._subprocess_watchdog is not None + ): + self.tokenizer_manager._subprocess_watchdog.stop() kill_process_tree(os.getpid(), include_parent=False) def __enter__(self): @@ -1135,7 +1199,7 @@ def _set_envs_and_config(server_args: ServerArgs): if server_args.attention_backend == "flashinfer": assert_pkg_version( "flashinfer_python", - "0.6.6", + "0.6.7", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", @@ -1179,28 +1243,53 @@ def 
launch_phase_sigquit_handler(signum, frame): mp.set_start_method("spawn", force=True) +def _set_gc(server_args: ServerArgs): + if gc_threshold := server_args.gc_threshold: + import gc + + gc.set_threshold(*gc_threshold) + + +def _scheduler_died_error(rank: int, proc) -> RuntimeError: + """Build a descriptive error for a scheduler process that died during init.""" + proc.join(timeout=10) + return RuntimeError( + f"Rank {rank} scheduler died during initialization " + f"(exit code: {proc.exitcode}). " + f"If exit code is -9 (SIGKILL), a common cause is the OS OOM killer. " + f"Run `dmesg -T | grep -i oom` to check." + ) + + def _wait_for_scheduler_ready( scheduler_pipe_readers: List, scheduler_procs: List, ) -> List[Dict]: - """Wait for the model to finish loading and return scheduler infos.""" + """Wait for the model to finish loading and return scheduler infos. + + Uses poll() with timeout instead of blocking recv(), so that child process + death (e.g. OOM SIGKILL) is detected promptly instead of hanging forever. + """ scheduler_infos = [] for i in range(len(scheduler_pipe_readers)): - try: - data = scheduler_pipe_readers[i].recv() - except EOFError: - logger.error( - f"Rank {i} scheduler is dead. Please check if there are relevant logs." - ) - scheduler_procs[i].join() - logger.error(f"Exit code: {scheduler_procs[i].exitcode}") - raise + while True: + if scheduler_pipe_readers[i].poll(timeout=5.0): + try: + data = scheduler_pipe_readers[i].recv() + except EOFError: + raise _scheduler_died_error(i, scheduler_procs[i]) + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + break + + # Poll timed out — check all processes for early death + for j in range(len(scheduler_procs)): + if not scheduler_procs[j].is_alive(): + raise _scheduler_died_error(j, scheduler_procs[j]) - if data["status"] != "ready": - raise RuntimeError( - "Initialization failed. 
Please see the error messages above." - ) - scheduler_infos.append(data) return scheduler_infos diff --git a/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py b/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py new file mode 100644 index 000000000000..77de7fc7d030 --- /dev/null +++ b/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py @@ -0,0 +1,105 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import logging +import threading +from typing import Dict, Optional, Tuple + +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.responses import PlainTextResponse + +logger = logging.getLogger(__name__) + + +class EngineInfoBootstrapServer: + """Lightweight HTTP server for per-rank model info registration. + + Runs in a daemon thread on node_rank==0. Each ModelRunner registers its + info via HTTP PUT after model initialization. The Engine + accesses the collected info directly in-process; external consumers can + query via HTTP GET. + + Currently supports transfer engine memory registration info. 
+ """ + + def __init__(self, host: str, port: int): + self.host = host + self.port = port + + # Storage: {tp_rank: (session_id, weights_info_dict)} + self.transfer_engine_info: Dict[int, Tuple] = {} + self.lock = threading.Lock() + + app = FastAPI() + + @app.get("/health") + def health(): + return PlainTextResponse("OK") + + @app.put("/register_transfer_engine_info") + def register_transfer_engine_info(data: dict): + try: + tp_rank = data["tp_rank"] + info = data["transfer_engine_info"] + session_id = info["session_id"] + weights_info_dict = info["weights_info_dict"] + + with self.lock: + self.transfer_engine_info[tp_rank] = ( + session_id, + weights_info_dict, + ) + + logger.info( + f"Registered transfer engine info for tp_rank={tp_rank}, " + f"session_id={session_id}" + ) + return PlainTextResponse("OK") + except Exception as e: + logger.error(f"Failed to register engine info: {e}") + raise HTTPException(status_code=400, detail=str(e)) + + @app.get("/get_transfer_engine_info") + def get_transfer_engine_info(rank: int): + if rank < 0: + raise HTTPException(status_code=400, detail="Invalid rank parameter") + + with self.lock: + info = self.transfer_engine_info.get(rank) + + if info is None: + raise HTTPException( + status_code=404, + detail=f"No transfer engine info for rank {rank}", + ) + + return {"rank": rank, "remote_instance_transfer_engine_info": list(info)} + + config = uvicorn.Config(app, host=host, port=port, log_level="warning") + self._server = uvicorn.Server(config) + self._thread = threading.Thread( + target=self._server.run, + daemon=True, + ) + self._thread.start() + logger.info(f"EngineInfoBootstrapServer started on {host}:{port}") + + def close(self): + self._server.should_exit = True + self._thread.join(timeout=5) + + def get_transfer_engine_info(self, rank: int) -> Optional[Tuple]: + """Direct in-process access for co-located HTTP server (no HTTP round-trip).""" + return self.transfer_engine_info.get(rank) diff --git 
a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index 674431bf497a..b0188a1371ab 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -7,9 +7,11 @@ async def serve_grpc(server_args, model_info=None): """Start the standalone gRPC server with integrated scheduler.""" try: from smg_grpc_servicer.sglang.server import serve_grpc as _serve_grpc - except ImportError: + except ImportError as e: raise ImportError( "gRPC mode requires the smg-grpc-servicer package. " - "Install it with: pip install smg-grpc-servicer[sglang]" - ) from None + "If not installed, run: pip install smg-grpc-servicer[sglang]. " + "If already installed, there may be a broken import due to a " + "version mismatch — see the chained exception above for details." + ) from e await _serve_grpc(server_args, model_info) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index eeefe3f9ba53..0f1aee9294c0 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -59,6 +59,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse +from sglang.srt.constants import HEALTH_CHECK_RID_PREFIX from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode from sglang.srt.entrypoints.anthropic.protocol import ( AnthropicCountTokensRequest, @@ -126,7 +127,6 @@ OpenSessionReqInput, ParseFunctionCallReq, PauseGenerationReqInput, - PinPrefixReqInput, ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, @@ -152,9 +152,6 @@ ) from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager -from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( - 
parse_remote_instance_transfer_engine_info_from_scheduler_infos, -) from sglang.srt.observability.func_timer import enable_func_timer from sglang.srt.observability.trace import ( process_tracing_init, @@ -177,6 +174,7 @@ dumps_json, orjson_response, ) +from sglang.srt.utils.watchdog import SubprocessWatchdog from sglang.utils import get_exception_traceback from sglang.version import __version__ @@ -194,15 +192,6 @@ class _GlobalState: tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker] template_manager: TemplateManager scheduler_info: Dict - # Dict{ - # rank: Tuple( - # session_id, - # Dict{ - # name: Tuple (d_ptr, numel, element_size) - # } - # ) - # } - remote_instance_transfer_engine_info: Optional[Dict] = None _global_state: Optional[_GlobalState] = None @@ -509,13 +498,9 @@ async def health_generate(request: Request) -> Response: return Response(status_code=200) sampling_params = {"max_new_tokens": 1, "temperature": 0.0} - rid = f"HEALTH_CHECK_{time.time()}" + rid = f"{HEALTH_CHECK_RID_PREFIX}_{time.time()}" - if _global_state.tokenizer_manager.is_image_gen: - gri = _global_state.tokenizer_manager.get_image_gen_health_check_request( - rid, sampling_params - ) - elif _global_state.tokenizer_manager.is_generation: + if _global_state.tokenizer_manager.is_generation: gri = GenerateReqInput( rid=rid, input_ids=[0], @@ -859,25 +844,6 @@ async def hicache_storage_backend_status(): } -@app.api_route("/hicache/pin_prefix", methods=["POST"]) -@auth_level(AuthLevel.ADMIN_OPTIONAL) -async def pin_prefix(obj: PinPrefixReqInput): - """Pin a prefix by token_ids to resist eviction.""" - if not _global_state.tokenizer_manager.server_args.admin_api_key: - return _admin_api_key_missing_response() - ret = await _global_state.tokenizer_manager.pin_prefix( - obj.token_ids, obj.ttl_seconds - ) - return ORJSONResponse( - content={ - "status": "ok" if ret.success else "error", - "nodes_pinned": ret.nodes_pinned, - "message": ret.message, - }, - 
status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, - ) - - @app.api_route("/start_profile", methods=["GET", "POST"]) @auth_level(AuthLevel.ADMIN_OPTIONAL) async def start_profile_async(obj: Optional[ProfileReqInput] = None): @@ -1032,26 +998,39 @@ async def send_weights_to_remote_instance( @app.get("/get_remote_instance_transfer_engine_info") @auth_level(AuthLevel.ADMIN_OPTIONAL) async def get_remote_instance_transfer_engine_info(rank: int = None): - if rank is None or rank < 0: - return Response(status_code=HTTPStatus.BAD_REQUEST) + """Get the server information (deprecated - use /remote_instance_transfer_engine_info instead).""" + logger.warning( + "Endpoint '/get_remote_instance_transfer_engine_info' is deprecated and will be removed in a future version. " + "Please use '/remote_instance_transfer_engine_info' instead." + ) + return await remote_instance_transfer_engine_info(rank=rank) - if ( - _global_state.remote_instance_transfer_engine_info is None - or len(_global_state.remote_instance_transfer_engine_info) == 0 - ): - return Response(status_code=HTTPStatus.BAD_REQUEST) +@app.get("/remote_instance_transfer_engine_info") +@auth_level(AuthLevel.ADMIN_OPTIONAL) +async def remote_instance_transfer_engine_info(rank: int = None): + if rank is None or rank < 0: + return ORJSONResponse( + {"error": {"message": "Missing or invalid rank parameter"}}, + status_code=HTTPStatus.BAD_REQUEST, + ) + + server_args = _global_state.tokenizer_manager.server_args try: - result = { - "rank": rank, - "remote_instance_transfer_engine_info": _global_state.remote_instance_transfer_engine_info[ - rank - ], - } - return result - except Exception as e: - logger.error(f"Exception: {e}") - return Response(status_code=HTTPStatus.BAD_REQUEST) + resp = requests.get( + f"{server_args.engine_info_bootstrap_url}/get_transfer_engine_info", + params={"rank": rank}, + timeout=5, + ) + if resp.status_code == 200: + return resp.json() + except (requests.exceptions.RequestException, 
ValueError) as e: + logger.warning(f"Failed to get transfer engine info for rank {rank}: {e}") + + return ORJSONResponse( + {"error": {"message": f"Failed to get transfer engine info for rank {rank}"}}, + status_code=HTTPStatus.BAD_REQUEST, + ) @app.post("/init_weights_update_group") @@ -1486,11 +1465,18 @@ async def openai_v1_audio_transcriptions( response_format: str = Form(default="json"), temperature: float = Form(default=0.0), stream: bool = Form(default=False), + timestamp_granularities: Optional[List[str]] = Form( + default=None, alias="timestamp_granularities[]" + ), ): """OpenAI-compatible audio transcription endpoint.""" - if response_format not in ["json", "text"]: + if response_format not in ["json", "text", "verbose_json"]: return ORJSONResponse( - content={"error": {"message": "Only 'json' and 'text' formats supported"}}, + content={ + "error": { + "message": "Only 'json', 'text', and 'verbose_json' formats supported" + } + }, status_code=400, ) @@ -1504,6 +1490,7 @@ async def openai_v1_audio_transcriptions( response_format=response_format, temperature=temperature, stream=stream, + timestamp_granularities=timestamp_granularities, raw_request=raw_request, ) ) @@ -1979,6 +1966,7 @@ def _setup_and_run_http_server( template_manager, port_args: PortArgs, scheduler_infos: List[Dict], + subprocess_watchdog: Optional[SubprocessWatchdog], execute_warmup_func: Callable = _execute_server_warmup, launch_callback: Optional[Callable[[], None]] = None, ): @@ -1986,21 +1974,19 @@ def _setup_and_run_http_server( Called by launch_server after subprocesses have been launched. 
""" - # Parse info got from the schedulers - remote_instance_transfer_engine_info = ( - parse_remote_instance_transfer_engine_info_from_scheduler_infos(scheduler_infos) - ) - # Set global states set_global_state( _GlobalState( tokenizer_manager=tokenizer_manager, template_manager=template_manager, scheduler_info=scheduler_infos[0], - remote_instance_transfer_engine_info=remote_instance_transfer_engine_info, ) ) + # Store watchdog on tokenizer_manager (single source of truth for SIGQUIT handler) + if tokenizer_manager is not None: + tokenizer_manager._subprocess_watchdog = subprocess_watchdog + if server_args.enable_metrics: add_prometheus_track_response_middleware(app) @@ -2171,13 +2157,17 @@ def launch_server( 2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library. """ # Launch subprocesses - tokenizer_manager, template_manager, port_args, scheduler_init_result = ( - Engine._launch_subprocesses( - server_args=server_args, - init_tokenizer_manager_func=init_tokenizer_manager_func, - run_scheduler_process_func=run_scheduler_process_func, - run_detokenizer_process_func=run_detokenizer_process_func, - ) + ( + tokenizer_manager, + template_manager, + port_args, + scheduler_init_result, + subprocess_watchdog, + ) = Engine._launch_subprocesses( + server_args=server_args, + init_tokenizer_manager_func=init_tokenizer_manager_func, + run_scheduler_process_func=run_scheduler_process_func, + run_detokenizer_process_func=run_detokenizer_process_func, ) _setup_and_run_http_server( @@ -2186,6 +2176,7 @@ def launch_server( template_manager, port_args, scheduler_init_result.scheduler_infos, + subprocess_watchdog, execute_warmup_func=execute_warmup_func, launch_callback=launch_callback, ) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 16fa9cbc807c..c40bd37d9f58 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ 
b/python/sglang/srt/entrypoints/openai/protocol.py @@ -585,6 +585,7 @@ class ChatCompletionRequest(BaseModel): tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field( default="auto", examples=["none"] ) # noqa + parallel_tool_calls: bool = True return_hidden_states: bool = False return_routed_experts: bool = False return_cached_tokens_details: bool = False @@ -1443,6 +1444,7 @@ class TranscriptionRequest(BaseModel): language: Optional[str] = None response_format: str = "json" temperature: float = 0.0 + timestamp_granularities: Optional[List[str]] = None stream: bool = False # Internal fields (not from API) audio_data: Optional[bytes] = None @@ -1463,6 +1465,26 @@ class TranscriptionResponse(BaseModel): usage: Optional[TranscriptionUsage] = None +class TranscriptionSegment(BaseModel): + """A segment with timestamp information.""" + + id: int + start: float + end: float + text: str + + +class TranscriptionVerboseResponse(BaseModel): + """Verbose transcription response with timestamps (OpenAI-compatible).""" + + task: str = "transcribe" + language: Optional[str] = None + duration: Optional[float] = None + text: str + segments: List[TranscriptionSegment] = [] + usage: Optional[TranscriptionUsage] = None + + class TranscriptionStreamChoice(BaseModel): """Delta content for streaming transcription.""" diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 67cfef0a94d4..b4cee0bd0d5a 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -352,14 +352,17 @@ def _process_messages( if self.tool_call_parser: parser = FunctionCallParser(request.tools, self.tool_call_parser) tool_call_constraint = parser.get_structure_constraint( - request.tool_choice + request.tool_choice, + parallel_tool_calls=request.parallel_tool_calls, ) # Handle JSON schema constraint directly for required or named tool choice if 
request.tool_choice == "required" or isinstance( request.tool_choice, ToolChoice ): json_schema = get_json_schema_constraint( - request.tools, request.tool_choice + request.tools, + request.tool_choice, + parallel_tool_calls=request.parallel_tool_calls, ) tool_call_constraint = ("json_schema", json_schema) @@ -602,10 +605,25 @@ async def _handle_streaming_request( adapted_request: GenerateReqInput, request: ChatCompletionRequest, raw_request: Request, - ) -> StreamingResponse: + ) -> Union[StreamingResponse, ErrorResponse]: """Handle streaming chat completion request""" + generator = self._generate_chat_stream(adapted_request, request, raw_request) + + # Kick-start the generator to trigger validation before HTTP 200 is sent. + # If validation fails (e.g., context length exceeded), we can still return + # a proper HTTP 400 error response instead of streaming it as SSE payload. + try: + first_chunk = await generator.__anext__() + except ValueError as e: + return self.create_error_response(str(e)) + + async def prepend_first_chunk(): + yield first_chunk + async for chunk in generator: + yield chunk + return StreamingResponse( - self._generate_chat_stream(adapted_request, request, raw_request), + prepend_first_chunk(), media_type="text/event-stream", background=self.tokenizer_manager.create_abort_task(adapted_request), ) @@ -635,6 +653,7 @@ async def _generate_chat_stream( hidden_states = {} routed_experts = {} + stream_started = False try: async for content in self.tokenizer_manager.generate_request( adapted_request, raw_request @@ -699,6 +718,7 @@ async def _generate_chat_stream( model=request.model, ) yield f"data: {chunk.model_dump_json()}\n\n" + stream_started = True stream_buffer = stream_buffers.get(index, "") delta = content["text"][len(stream_buffer) :] @@ -879,6 +899,8 @@ async def _generate_chat_stream( yield f"data: {usage_chunk.model_dump_json()}\n\n" except ValueError as e: + if not stream_started: + raise error = 
self.create_streaming_error_response(str(e)) yield f"data: {error}\n\n" @@ -1268,6 +1290,12 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: not request.chat_template_kwargs or request.chat_template_kwargs.get("enable_thinking") is not False ) + if self.reasoning_parser in ["mimo"]: + # Models that require explicit enable thinking (enable_thinking=True) + return ( + request.chat_template_kwargs is not None + and request.chat_template_kwargs.get("enable_thinking") is True + ) if self.reasoning_parser in ["mistral"]: # Mistral models only reason when reasoning_effort is explicitly # set to a value other than None/"none" (typically "high"). diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 0bda21907bdc..8c4f79c1b81b 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -180,10 +180,25 @@ async def _handle_streaming_request( adapted_request: GenerateReqInput, request: CompletionRequest, raw_request: Request, - ) -> StreamingResponse: + ) -> Union[StreamingResponse, ErrorResponse]: """Handle streaming completion request""" + generator = self._generate_completion_stream( + adapted_request, request, raw_request + ) + + # Kick-start the generator to trigger validation before HTTP 200 is sent. 
+ try: + first_chunk = await generator.__anext__() + except ValueError as e: + return self.create_error_response(str(e)) + + async def prepend_first_chunk(): + yield first_chunk + async for chunk in generator: + yield chunk + return StreamingResponse( - self._generate_completion_stream(adapted_request, request, raw_request), + prepend_first_chunk(), media_type="text/event-stream", background=self.tokenizer_manager.create_abort_task(adapted_request), ) @@ -208,6 +223,7 @@ async def _generate_completion_stream( hidden_states = {} routed_experts = {} + stream_started = False try: async for content in self.tokenizer_manager.generate_request( adapted_request, raw_request @@ -312,6 +328,7 @@ async def _generate_completion_stream( ) yield f"data: {chunk.model_dump_json()}\n\n" + stream_started = True if request.return_hidden_states and hidden_states: for index, choice_hidden_states in hidden_states.items(): @@ -373,6 +390,8 @@ async def _generate_completion_stream( yield f"data: {final_usage_data}\n\n" except Exception as e: + if not stream_started: + raise error = self.create_streaming_error_response(str(e)) yield f"data: {error}\n\n" diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription.py b/python/sglang/srt/entrypoints/openai/serving_transcription.py index 2b5661f4967a..bfbad1e0d321 100644 --- a/python/sglang/srt/entrypoints/openai/serving_transcription.py +++ b/python/sglang/srt/entrypoints/openai/serving_transcription.py @@ -22,7 +22,7 @@ import math import time import uuid -from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union +from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union from fastapi import Request from fastapi.responses import ORJSONResponse, Response, StreamingResponse @@ -32,9 +32,11 @@ ErrorResponse, TranscriptionRequest, TranscriptionResponse, + TranscriptionSegment, TranscriptionStreamChoice, TranscriptionStreamResponse, TranscriptionUsage, + TranscriptionVerboseResponse, ) from 
sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.managers.io_struct import GenerateReqInput @@ -44,6 +46,10 @@ logger = logging.getLogger(__name__) +# Whisper timestamp token constants +TIMESTAMP_BASE_TOKEN_ID = 50365 # <|0.00|> +TIMESTAMP_BASE_OFFSET = 0.02 # Each token step = 0.02 seconds + class OpenAIServingTranscription(OpenAIServingBase): """Handler for /v1/audio/transcriptions requests""" @@ -72,6 +78,9 @@ def _convert_to_internal_request( "language": request.language, # Pass to WhisperProcessor for language-specific decoding } + if request.timestamp_granularities: + sampling_params["timestamp_granularities"] = request.timestamp_granularities + # For Whisper, we pass audio_data and let the processor handle it adapted_request = GenerateReqInput( text="", # Empty text - Whisper processor will set proper decoder tokens @@ -89,13 +98,83 @@ def _get_audio_duration(self, audio_data: bytes) -> float: try: import soundfile as sf - audio_array, sr = sf.read(io.BytesIO(audio_data)) - duration = len(audio_array) / sr - return duration + info = sf.info(io.BytesIO(audio_data)) + return info.duration except Exception as e: logger.warning(f"Could not calculate audio duration: {e}") return 0.0 + def _parse_segments( + self, output_ids: List[int], tokenizer + ) -> tuple[str, List[TranscriptionSegment]]: + """Parse timestamp tokens from output_ids into segments. + + The decoder prompt ends with <|0.00|>, so the first segment starts at + t=0. The model then outputs: + text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...] + Each timestamp token marks the end of the current segment; its value + also becomes the start of the next segment. 
+ """ + # Token IDs for special tokens we want to strip from segment text + eos_token_id = getattr(tokenizer, "eos_token_id", 50257) + + segments = [] + full_text_parts = [] + current_text_tokens = [] + current_start = 0.0 # First segment starts at 0.0 (from prompt <|0.00|>) + seg_id = 0 + + for token_id in output_ids: + if token_id >= TIMESTAMP_BASE_TOKEN_ID: + # This is a timestamp token — marks the end of current segment + timestamp = (token_id - TIMESTAMP_BASE_TOKEN_ID) * TIMESTAMP_BASE_OFFSET + + if current_text_tokens: + text = tokenizer.decode( + current_text_tokens, skip_special_tokens=True + ).strip() + if text: + segments.append( + TranscriptionSegment( + id=seg_id, + start=round(current_start, 2), + end=round(timestamp, 2), + text=text, + ) + ) + full_text_parts.append(text) + seg_id += 1 + current_text_tokens = [] + + # Next segment starts at this timestamp + current_start = timestamp + + elif token_id == eos_token_id: + # Skip end-of-text token + continue + else: + # Regular text token + current_text_tokens.append(token_id) + + # Handle any trailing text tokens without a closing timestamp + if current_text_tokens: + text = tokenizer.decode( + current_text_tokens, skip_special_tokens=True + ).strip() + if text: + segments.append( + TranscriptionSegment( + id=seg_id, + start=round(current_start, 2), + end=round(current_start, 2), + text=text, + ) + ) + full_text_parts.append(text) + + full_text = " ".join(full_text_parts) + return full_text, segments + async def create_transcription( self, audio_data: bytes, @@ -105,7 +184,14 @@ async def create_transcription( temperature: float, stream: bool, raw_request: Request, - ) -> Union[TranscriptionResponse, StreamingResponse, Response, ORJSONResponse]: + timestamp_granularities: Optional[List[str]] = None, + ) -> Union[ + TranscriptionResponse, + TranscriptionVerboseResponse, + StreamingResponse, + Response, + ORJSONResponse, + ]: """Main entry point for transcription requests.""" # Calculate audio duration for 
usage reporting audio_duration_s = self._get_audio_duration(audio_data) @@ -117,6 +203,7 @@ async def create_transcription( language=language, response_format=response_format, temperature=temperature, + timestamp_granularities=timestamp_granularities, stream=stream, audio_duration_s=audio_duration_s, ) @@ -129,7 +216,13 @@ async def _handle_non_streaming_request( adapted_request: GenerateReqInput, request: TranscriptionRequest, raw_request: Request, - ) -> Union[TranscriptionResponse, ErrorResponse, ORJSONResponse, Response]: + ) -> Union[ + TranscriptionResponse, + TranscriptionVerboseResponse, + ErrorResponse, + ORJSONResponse, + Response, + ]: """Handle non-streaming transcription request.""" try: ret = await self.tokenizer_manager.generate_request( @@ -139,14 +232,26 @@ async def _handle_non_streaming_request( return self.create_error_response(str(e)) text = ret.get("text", "") + usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s))) # Build response based on format if request.response_format == "text": return Response(content=text, media_type="text/plain") - # JSON format - usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s))) + if request.response_format == "verbose_json": + output_ids = ret.get("output_ids", []) + tokenizer = self.tokenizer_manager.tokenizer + parsed_text, segments = self._parse_segments(output_ids, tokenizer) + + return TranscriptionVerboseResponse( + language=request.language or "en", + duration=round(request.audio_duration_s, 2), + text=parsed_text or text, + segments=segments, + usage=usage, + ) + # Default JSON format return TranscriptionResponse(text=text, usage=usage) async def _handle_streaming_request( diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 4dcf0613bd91..4d0964efd44f 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -243,6 +243,10 @@ class Envs: SGLANG_DISAGGREGATION_WAITING_TIMEOUT = EnvInt(300) 
SGLANG_DISAGGREGATION_NIXL_BACKEND = EnvStr("UCX") SGLANG_DISAGGREGATION_ALL_CP_RANKS_TRANSFER = EnvBool(False) + # Extra slots in req_to_token_pool for decode workers (only effective when + # max_num_reqs > 32). Increases pool capacity so more KV cache transfers + # can overlap with decode execution without raising max_running_requests. + SGLANG_DISAGGREGATION_NUM_PRE_ALLOCATE_REQS = EnvInt(0) # Scheduler: others: SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1) # in seconds. Set if you observe high memory accumulation over a long serving period. @@ -280,9 +284,13 @@ class Envs: SGLANG_HICACHE_DECODE_OFFLOAD_STRIDE = EnvInt(None) SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR = EnvStr(None) SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR = EnvStr(None) - # Max fraction of cache (by token count) that can be pinned; 0 = disable pinning. - SGLANG_HICACHE_MAX_PINNED_RATIO = EnvFloat(0.0) - + # Staging buffer for heterogeneous TP KV transfer + SGLANG_DISAGG_STAGING_BUFFER = EnvBool(False) + SGLANG_DISAGG_STAGING_BUFFER_SIZE_MB = EnvInt(64) + SGLANG_DISAGG_STAGING_POOL_SIZE_MB = EnvInt(4096) + # TODO(yangminl): remove SGLANG_STAGING_USE_TORCH and the torch fallback in + # staging_buffer.py once Triton kernels are fully validated in production. 
+ SGLANG_STAGING_USE_TORCH = EnvBool(False) # Mooncake KV Transfer SGLANG_MOONCAKE_CUSTOM_MEM_POOL = EnvStr(None) ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE = EnvBool(False) @@ -328,6 +336,7 @@ class Envs: SGLANG_CPU_QUANTIZATION = EnvBool(False) SGLANG_USE_DYNAMIC_MXFP4_LINEAR = EnvBool(False) SGLANG_FORCE_FP8_MARLIN = EnvBool(False) + SGLANG_FORCE_NVFP4_MARLIN = EnvBool(False) SGLANG_MOE_NVFP4_DISPATCH = EnvBool(False) SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN = EnvBool(False) SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2 = EnvBool(False) @@ -337,10 +346,12 @@ class Envs: # Flashinfer SGLANG_IS_FLASHINFER_AVAILABLE = EnvBool(True) - SGLANG_ENABLE_FLASHINFER_FP8_GEMM = EnvBool(False) # Default to the pick from flashinfer - SGLANG_FLASHINFER_FP4_GEMM_BACKEND = EnvStr("") SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024) + # Skip-softmax threshold scale factor for TRT-LLM attention (prefill and decode separately). + # None = standard attention. See https://arxiv.org/abs/2512.12087 + SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None) + SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None) # TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion # transport issue on GB200/GB300 platforms is fixed and verified resolved. 
SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None) @@ -408,7 +419,6 @@ class Envs: DISABLE_OPENAPI_DOC = EnvBool(False) SGLANG_ENABLE_TORCH_INFERENCE_MODE = EnvBool(False) SGLANG_IS_FIRST_RANK_ON_NODE = EnvBool(True) - SGLANG_SUPPORT_CUTLASS_BLOCK_FP8 = EnvBool(False) SGLANG_SYNC_TOKEN_IDS_ACROSS_TP = EnvBool(False) SGLANG_ENABLE_COLOCATED_BATCH_GEN = EnvBool(False) @@ -450,12 +460,10 @@ class Envs: # VLM Item CUDA IPC Transport SGLANG_USE_CUDA_IPC_TRANSPORT = EnvBool(False) + SGLANG_USE_IPC_POOL_HANDLE_CACHE = EnvBool(False) SGLANG_MM_FEATURE_CACHE_MB = EnvInt(4 * 1024) SGLANG_MM_ITEM_MEM_POOL_RECYCLE_INTERVAL_SEC = EnvFloat(0.05) - # MM splitting behavior control - SGLANG_ENABLE_MM_SPLITTING = EnvBool(False) - # Mamba SGLANG_MAMBA_CONV_DTYPE = EnvStr("bfloat16") SGLANG_MAMBA_SSM_DTYPE = EnvStr(None) @@ -548,9 +556,6 @@ def _warn_deprecated_env_to_cli_flag(env_name: str, suggestion: str): def _convert_SGL_to_SGLANG(): _print_deprecated_env("SGLANG_LOG_GC", "SGLANG_GC_LOG") - _print_deprecated_env( - "SGLANG_ENABLE_FLASHINFER_FP8_GEMM", "SGLANG_ENABLE_FLASHINFER_GEMM" - ) _print_deprecated_env( "SGLANG_MOE_NVFP4_DISPATCH", "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH" ) @@ -581,23 +586,6 @@ def _convert_SGL_to_SGLANG(): _convert_SGL_to_SGLANG() - -_warn_deprecated_env_to_cli_flag( - "SGLANG_ENABLE_FLASHINFER_FP8_GEMM", - "It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=flashinfer_trtllm' instead.", -) -_warn_deprecated_env_to_cli_flag( - "SGLANG_ENABLE_FLASHINFER_GEMM", - "It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=flashinfer_trtllm' instead.", -) -_warn_deprecated_env_to_cli_flag( - "SGLANG_SUPPORT_CUTLASS_BLOCK_FP8", - "It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=cutlass' instead.", -) -_warn_deprecated_env_to_cli_flag( - "SGLANG_FLASHINFER_FP4_GEMM_BACKEND", - "It will be completely removed in 0.5.9. 
Please use '--fp4-gemm-backend' instead.", -) _warn_deprecated_env_to_cli_flag( "SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE", "Please use '--enable-prefill-delayer' instead.", diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index 8022dbe076e2..3163867bcd05 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -171,12 +171,13 @@ def parse_streaming_increment( # parallel tool calls because the bot_token (e.g., '[') can also # appear inside array parameters of the current tool, and we must not # mistakenly identify that as the start of a new tool. + used_separator_branch = False if self.current_tool_id > 0 and current_text.startswith( self.tool_call_separator ): start_idx = len(self.tool_call_separator) + used_separator_branch = True else: - # Only search for bot_token if not processing subsequent tool tool_call_pos = current_text.find(self.bot_token) if tool_call_pos != -1: start_idx = tool_call_pos + len(self.bot_token) @@ -186,7 +187,23 @@ def parse_streaming_increment( if start_idx >= len(current_text): return StreamingParseResult() - obj, end_idx = _partial_json_loads(current_text[start_idx:], flags) + try: + obj, end_idx = _partial_json_loads(current_text[start_idx:], flags) + except (MalformedJSON, json.JSONDecodeError): + # Separator landed on non-JSON markup; fall back to + # bot_token which skips past all inter-object markup. + # e.g. Qwen25: separator "," matches between eot/bot tags. 
+ if used_separator_branch and self.bot_token in current_text: + start_idx = current_text.find(self.bot_token) + len( + self.bot_token + ) + if start_idx >= len(current_text): + return StreamingParseResult() + obj, end_idx = _partial_json_loads( + current_text[start_idx:], flags + ) + else: + raise is_current_complete = _is_complete_json( current_text[start_idx : start_idx + end_idx] @@ -212,7 +229,7 @@ def parse_streaming_increment( current_tool_call = obj - except MalformedJSON: + except (MalformedJSON, json.JSONDecodeError): return StreamingParseResult() if not current_tool_call: diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index 2f562192e219..ca066e196d0f 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -184,7 +184,9 @@ def get_structure_tag(self) -> LegacyStructuralTagResponseFormat: ) def get_structure_constraint( - self, tool_choice: Union[ToolChoice, Literal["auto", "required"]] + self, + tool_choice: Union[ToolChoice, Literal["auto", "required"]], + parallel_tool_calls: bool = True, ) -> Optional[ToolCallConstraint]: """ Returns the appropriate structure constraint for tool calls based on the tool_choice. 
@@ -210,5 +212,7 @@ def get_structure_constraint( tag = self.get_structure_tag() return ("structural_tag", tag) elif tool_choice == "required" or isinstance(tool_choice, ToolChoice): - json_schema = get_json_schema_constraint(self.tools, tool_choice) + json_schema = get_json_schema_constraint( + self.tools, tool_choice, parallel_tool_calls=parallel_tool_calls + ) return ("json_schema", json_schema) diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py index 567ca583bb8f..1ef93e05197d 100644 --- a/python/sglang/srt/function_call/utils.py +++ b/python/sglang/srt/function_call/utils.py @@ -205,13 +205,16 @@ def infer_type_from_json_schema(schema: Dict[str, Any]) -> Optional[str]: def get_json_schema_constraint( - tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]] + tools: List[Tool], + tool_choice: Union[ToolChoice, Literal["required"]], + parallel_tool_calls: bool = True, ) -> Optional[dict]: """ Get the JSON schema constraint for the specified tool choice. 
Args: tool_choice: The tool choice specification + parallel_tool_calls: If False, constrain to exactly one tool call (maxItems=1) Returns: JSON schema dict, or None if no valid tools found @@ -222,12 +225,14 @@ def get_json_schema_constraint( fn_name = tool_choice.function.name for tool in tools: if tool.function.name == fn_name: - return { + schema = { "type": "array", "minItems": 1, - "maxItems": 1, "items": _get_tool_schema(tool), } + if not parallel_tool_calls: + schema["maxItems"] = 1 + return schema return None elif tool_choice == "required": json_schema = { @@ -238,6 +243,8 @@ def get_json_schema_constraint( "anyOf": [_get_tool_schema(tool) for tool in tools], }, } + if not parallel_tool_calls: + json_schema["maxItems"] = 1 json_schema_defs = _get_tool_schema_defs(tools) if json_schema_defs: json_schema["$defs"] = json_schema_defs diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 1ac339d0bf22..d93cc5f62f1c 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -467,6 +467,8 @@ def init_forward_metadata_capture_cuda_graph( self.q_head_num_padding is not None and self.q_head_num_padding > self.tp_q_head_num ): + # In the MLA architecture, the FIA kernel requires the head count to be a power of 2. + # Therefore, we pad the head dimension accordingly and initialize an empty tensor for padding. metadata.nope_padding = torch.empty( [ bs, @@ -1050,6 +1052,9 @@ def forward_extend( -1, layer.tp_q_head_num * layer.v_head_dim ) elif sum(forward_batch.extend_prefix_lens_cpu) > 0: + # This branch adds support for prefix cache for GLM-4.7-Flash. + # When using the MLA architecture, if qk head dim equals v head dim and the head count is not a power of 2, + # we use the FIA kernel for computation. 
if layer.qk_head_dim == layer.v_head_dim: q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) @@ -1694,6 +1699,8 @@ def forward_decode_graph( self.q_head_num_padding is not None and self.q_head_num_padding > layer.tp_q_head_num ): + # The FIA kernel only supports head counts that are powers of 2. + # Therefore, we pad the head dimension when it is not a power of 2. q_nope = torch.cat( [q_nope, self.forward_metadata.nope_padding], dim=2 ).contiguous() diff --git a/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py b/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py index 16dc169ab0b9..689c8c95f111 100644 --- a/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py +++ b/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py @@ -3,6 +3,7 @@ import torch import torch_npu +from sgl_kernel_npu.norm.fused_split_qk_norm import fused_split_qk_norm from sglang.srt.environ import envs from sglang.srt.hardware_backend.npu.attention.mla_preprocess import ( @@ -323,39 +324,63 @@ def forward_dsa_prepare_npu( ) else: fused_qkv_a_proj_out = m.fused_qkv_a_proj_with_mqa(hidden_states)[0] - q, latent_cache = fused_qkv_a_proj_out.split( - [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1 - ) - - # overlap qk norm - q = m.q_a_layernorm(q) - if ( - _use_ag_after_qlora - and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED - and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL - ): - q = scattered_to_tp_attn_full(q, forward_batch) - latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch) - q_lora = q.clone() # required for topk_indices - - q_event = None - if m.alt_stream is not None: - m.alt_stream.wait_stream(torch.npu.current_stream()) - with torch.npu.stream(m.alt_stream): + if m.rotary_emb.is_neox_style: + q, latent_cache = fused_qkv_a_proj_out.split( + [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1 + ) + # 
overlap qk norm + q = m.q_a_layernorm(q) + if ( + _use_ag_after_qlora + and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED + and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL + ): + q = scattered_to_tp_attn_full(q, forward_batch) + latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch) + q_lora = q.clone() # required for topk_indices + + q_event = None + if m.alt_stream is not None: + m.alt_stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(m.alt_stream): + q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim) + # record q to ensure memory space will not be released + q.record_stream(m.alt_stream) + q_event = m.alt_stream.record_event() + else: q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim) - # record q to ensure memory space will not be released - q.record_stream(m.alt_stream) - q_event = m.alt_stream.record_event() + + k_nope, k_pe = latent_cache.unsqueeze(1).split( + [m.kv_lora_rank, m.qk_rope_head_dim], dim=-1 + ) + k_nope = m.kv_a_layernorm(k_nope) + # main stream waits for the completion of the event on the alt stream to ensure data dependency is complete + if q_event is not None: + torch.npu.current_stream().wait_event(q_event) else: - q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim) + if fused_qkv_a_proj_out.shape[0] < 65535: + q_lora, k_nope, k_pe = fused_split_qk_norm( + fused_qkv_a_proj_out, + m.q_a_layernorm, + m.kv_a_layernorm, + m.q_lora_rank, + m.kv_lora_rank, + m.qk_rope_head_dim, + eps=m.q_a_layernorm.variance_epsilon, + ) + else: + q, latent_cache = fused_qkv_a_proj_out.split( + [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1 + ) + # overlap qk norm + q = m.q_a_layernorm(q) - k_nope, k_pe = latent_cache.unsqueeze(1).split( - [m.kv_lora_rank, m.qk_rope_head_dim], dim=-1 - ) - k_nope = m.kv_a_layernorm(k_nope) - # main stream waits for the completion of the event on the alt stream to ensure data dependency is complete - 
if q_event is not None: - torch.npu.current_stream().wait_event(q_event) + q_lora = q.clone() # required for topk_indices + k_nope, k_pe = latent_cache.unsqueeze(1).split( + [m.kv_lora_rank, m.qk_rope_head_dim], dim=-1 + ) + k_nope = m.kv_a_layernorm(k_nope) + q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim) q_nope, q_pe = q.split([m.qk_nope_head_dim, m.qk_rope_head_dim], dim=-1) @@ -363,6 +388,11 @@ def forward_dsa_prepare_npu( q_nope_out = q_nope_out.transpose(0, 1) + if m.layer_id == 0: + m.rotary_emb.sin_cos_cache = m.rotary_emb.cos_sin_cache.index_select( + 0, positions + ) + q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe) if nsa_use_prefill_cp(forward_batch): diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index cdc5ad2fda5f..a0515f4f80e5 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -107,6 +107,28 @@ def init_npu_backend(): torch_npu.npu.set_compile_mode(jit_compile=False) +def _is_nz_aligned(tensor: torch.Tensor) -> bool: + """Check whether the last two dims satisfy FRACTAL_NZ alignment rules. + + Ascend FRACTAL_NZ requires: + BF16 / FP16 : both dims divisible by 16 + INT8 : k % 16 == 0 and n % 32 == 0 + INT4 : k % 16 == 0 and n % 64 == 0 + FP4 : both dims divisible by 64 + """ + if tensor.dim() < 2: + return False + k, n = tensor.shape[-2], tensor.shape[-1] + if tensor.dtype in (torch.bfloat16, torch.float16): + return k % 16 == 0 and n % 16 == 0 + if tensor.dtype == torch.int8: + return k % 16 == 0 and n % 32 == 0 + if tensor.dtype in (torch.uint8, torch.int32): + # INT4 is typically packed into uint8/int32; be conservative + return k % 16 == 0 and n % 64 == 0 + return True + + def npu_format_cast( tensor: torch.Tensor, acl_format: NPUACLFormat = NPUACLFormat.ACL_FORMAT_FRACTAL_NZ, @@ -135,8 +157,20 @@ def npu_format_cast( "significantly reduced." 
) return tensor - else: - return torch.ops.npu.npu_format_cast(tensor, acl_format.value) + + if acl_format == NPUACLFormat.ACL_FORMAT_FRACTAL_NZ and not _is_nz_aligned(tensor): + k, n = tensor.shape[-2], tensor.shape[-1] + logger.warning_once( + "Skipping FRACTAL_NZ format cast: tensor shape (%d, %d) dtype %s " + "is not aligned to NZ requirements. Falling back to 'ND' format, " + "which may reduce NPU performance.", + k, + n, + tensor.dtype, + ) + return tensor + + return torch.ops.npu.npu_format_cast(tensor, acl_format.value) def get_indexer_weight_stream(): @@ -144,3 +178,67 @@ def get_indexer_weight_stream(): if indexer_weight_stream is None: indexer_weight_stream = torch.npu.Stream() return indexer_weight_stream + + +share_stream = None +routed_stream = None + + +def get_share_stream(): + global share_stream + return share_stream + + +def set_share_stream(stream): + global share_stream + share_stream = stream + # TODO LKL: set stream limit has impact on precision + # torch.npu.set_stream_limit(share_stream, 8, 16) + + +def get_routed_stream(): + global routed_stream + return routed_stream + + +def set_routed_stream(stream): + global routed_stream + routed_stream = stream + # TODO LKL: set stream limit has impact on precision + # torch.npu.set_stream_limit(routed_stream, 16, 32) + + +def wait_share_stream(): + stream = get_share_stream() + if stream is not None: + cur_stream = torch.get_device_module().current_stream() + cur_stream.wait_stream(stream) + + +def wait_routed_stream(): + stream = get_routed_stream() + if stream is not None: + cur_stream = torch.get_device_module().current_stream() + cur_stream.wait_stream(stream) + + +def process_shared_expert(hidden_states, forward_func): + stream = get_share_stream() + if stream is None: + stream = torch.get_device_module().Stream() + set_share_stream(stream) + stream.wait_stream(torch.get_device_module().current_stream()) + with torch.get_device_module().stream(stream): + shared_output = 
forward_func(hidden_states) + return shared_output + + +def process_routed_expert(hidden_states, topk_output, forward_func): + stream = get_routed_stream() + if stream is None: + stream = torch.get_device_module().Stream() + set_routed_stream(stream) + stream.wait_stream(torch.get_device_module().current_stream()) + with torch.get_device_module().stream(stream): + shared_output = forward_func(hidden_states, topk_output) + return shared_output diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 23875a653912..7833c2494cb1 100755 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -2012,6 +2012,7 @@ def forward_extend( self.use_triton_unified_attention and self.use_sliding_window_kv_pool ): + token_to_kv_pool = forward_batch.token_to_kv_pool k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id @@ -2036,7 +2037,6 @@ def forward_extend( k_scale=k_descale, v_scale=v_descale, ) - elif self.use_mla: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) else: @@ -2419,6 +2419,7 @@ def forward_decode( # use standard set_kv_buffer, as they lack SWA-specific attributes # like full_to_swa_index_mapping. 
if self.use_triton_unified_attention and self.use_sliding_window_kv_pool: + token_to_kv_pool = forward_batch.token_to_kv_pool k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id @@ -2502,7 +2503,7 @@ def forward_decode( o = torch.empty_like(q, dtype=self.input_dtype) - max_kv_len = page_table.shape[1] + max_kv_len = page_table.shape[1] * self.page_size unified_attention( q=q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), diff --git a/python/sglang/srt/layers/attention/fla/chunk.py b/python/sglang/srt/layers/attention/fla/chunk.py index 28fc166f44e8..f715afdc8418 100644 --- a/python/sglang/srt/layers/attention/fla/chunk.py +++ b/python/sglang/srt/layers/attention/fla/chunk.py @@ -8,19 +8,20 @@ from einops import rearrange from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h +from sglang.srt.layers.attention.fla.chunk_fwd import chunk_gated_delta_rule_fwd_intra from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o -from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import ( - chunk_scaled_dot_kkt_fwd, -) from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum +from sglang.srt.layers.attention.fla.index import ( + prepare_chunk_indices, +) from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd -from sglang.srt.layers.attention.fla.solve_tril import solve_tril from sglang.srt.layers.attention.fla.utils import ( SUPPRESS_LEVEL, autocast_custom_fwd, input_guard, ) -from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd + +CHUNK_SIZE = 64 def chunk_gated_delta_rule_fwd( @@ -33,21 +34,20 @@ def chunk_gated_delta_rule_fwd( initial_state: torch.Tensor, initial_state_indices: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, + chunk_indices: torch.LongTensor | None = None, ): - g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) - # obtain WY representation. u is actually the new v. 
- A = chunk_scaled_dot_kkt_fwd( - k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32 - ) - A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) - w, u = recompute_w_u_fwd( + g = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE, cu_seqlens=cu_seqlens) + + # fused kkt + solve_tril + recompute_w_u + w, u, A = chunk_gated_delta_rule_fwd_intra( k=k, v=v, + g=g, beta=beta, - A=A, - g_cumsum=g, cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, ) + h, v_new = chunk_gated_delta_rule_fwd_h( k=k, w=w, @@ -97,6 +97,11 @@ def forward( q = l2norm_fwd(q) k = l2norm_fwd(k) + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, CHUNK_SIZE) + if cu_seqlens is not None + else None + ) g, o, A, w, h, v_new = chunk_gated_delta_rule_fwd( q=q, k=k, @@ -107,6 +112,7 @@ def forward( initial_state=initial_state, initial_state_indices=initial_state_indices, cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, ) return o.to(q.dtype), h diff --git a/python/sglang/srt/layers/attention/fla/chunk_fwd.py b/python/sglang/srt/layers/attention/fla/chunk_fwd.py new file mode 100644 index 000000000000..432a274cd5e5 --- /dev/null +++ b/python/sglang/srt/layers/attention/fla/chunk_fwd.py @@ -0,0 +1,416 @@ +# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk_fwd.py +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.fla.index import prepare_chunk_indices +from sglang.srt.layers.attention.fla.op import safe_exp +from sglang.srt.layers.attention.fla.utils import ( + autotune_cache_kwargs, + is_tf32_supported, +) +from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd + +# TF32 for the block-merge dot products (16x16 matmuls) is safe and ~2x faster on SM90. +# The numerically sensitive forward-substitution uses scalar ops, not tl.dot. 
+if is_tf32_supported: + _MERGE_DOT_PRECISION = tl.constexpr("tf32") +else: + _MERGE_DOT_PRECISION = tl.constexpr("ieee") + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK}, num_warps=num_warps) + for BK in [32, 64] + for num_warps in [1, 2, 4] + ], + key=["H", "Hg", "K", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kkt_solve_kernel( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """ + Fused kernel: compute beta * K @ K^T (lower triangular) + solve_tril (I+A)^{-1} in one pass. + + This kernel fuses chunk_scaled_dot_kkt_fwd and solve_tril into a single kernel, + avoiding the HBM round-trip for the intermediate A matrix. + + Steps: + 1. Compute all 10 lower-triangular [BC, BC] blocks of beta * K @ K^T in registers + 2. Apply gate and beta scaling + 3. Forward substitution on diagonal blocks + 4. Block merge to get full (I+A)^{-1} + 5. 
Write result to A (output) + """ + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load( + chunk_indices + i_t * 2 + 1 + ).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 1 + ).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + k += (bos * Hg + i_h // (H // Hg)) * K + A += (bos * H + i_h) * BT + + o_i = tl.arange(0, BC) + m_tc0 = (i_tc0 + o_i) < T + m_tc1 = (i_tc1 + o_i) < T + m_tc2 = (i_tc2 + o_i) < T + m_tc3 = (i_tc3 + o_i) < T + + # load beta for each sub-chunk + p_b0 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc0,), (BC,), (0,)) + p_b1 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,)) + p_b2 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,)) + p_b3 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,)) + b_b0 = tl.load(p_b0, boundary_check=(0,)).to(tl.float32) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + + # load gate if used + if USE_G: + p_g0 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc0,), (BC,), (0,)) + p_g1 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,)) + p_g2 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,)) + p_g3 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,)) + + b_g0 = tl.load(p_g0, boundary_check=(0,)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0,)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0,)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0,)).to(tl.float32) + + 
############################################################################ + # Step 1: compute all 10 lower-triangular [BC, BC] blocks of K @ K^T + ############################################################################ + + # 4 diagonal blocks + b_A00 = tl.zeros([BC, BC], dtype=tl.float32) + b_A11 = tl.zeros([BC, BC], dtype=tl.float32) + b_A22 = tl.zeros([BC, BC], dtype=tl.float32) + b_A33 = tl.zeros([BC, BC], dtype=tl.float32) + + # 6 off-diagonal blocks + b_A10 = tl.zeros([BC, BC], dtype=tl.float32) + b_A20 = tl.zeros([BC, BC], dtype=tl.float32) + b_A21 = tl.zeros([BC, BC], dtype=tl.float32) + b_A30 = tl.zeros([BC, BC], dtype=tl.float32) + b_A31 = tl.zeros([BC, BC], dtype=tl.float32) + b_A32 = tl.zeros([BC, BC], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k0 = tl.make_block_ptr( + k, (T, K), (Hg * K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0) + ) + b_k0 = tl.load(p_k0, boundary_check=(0, 1)) + # diagonal block 0 + b_A00 += tl.dot(b_k0, tl.trans(b_k0)) + + if i_tc1 < T: + p_k1 = tl.make_block_ptr( + k, (T, K), (Hg * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0) + ) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)) + # diagonal block 1 + b_A11 += tl.dot(b_k1, tl.trans(b_k1)) + # off-diagonal (1,0) + b_A10 += tl.dot(b_k1, tl.trans(b_k0)) + + if i_tc2 < T: + p_k2 = tl.make_block_ptr( + k, (T, K), (Hg * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0) + ) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)) + # diagonal block 2 + b_A22 += tl.dot(b_k2, tl.trans(b_k2)) + # off-diagonal (2,0), (2,1) + b_A20 += tl.dot(b_k2, tl.trans(b_k0)) + b_A21 += tl.dot(b_k2, tl.trans(b_k1)) + + if i_tc3 < T: + p_k3 = tl.make_block_ptr( + k, (T, K), (Hg * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0) + ) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)) + # diagonal block 3 + b_A33 += tl.dot(b_k3, tl.trans(b_k3)) + # off-diagonal (3,0), (3,1), (3,2) + b_A30 += tl.dot(b_k3, tl.trans(b_k0)) + b_A31 += tl.dot(b_k3, tl.trans(b_k1)) + b_A32 += tl.dot(b_k3, tl.trans(b_k2)) + + 
############################################################################ + # Step 2: apply gate and beta scaling + ############################################################################ + + if USE_G: + # diagonal blocks: g_diff = g_i - g_j within sub-chunk + b_A00 *= safe_exp(b_g0[:, None] - b_g0[None, :]) + b_A11 *= safe_exp(b_g1[:, None] - b_g1[None, :]) + b_A22 *= safe_exp(b_g2[:, None] - b_g2[None, :]) + b_A33 *= safe_exp(b_g3[:, None] - b_g3[None, :]) + + # off-diagonal blocks: g_diff = g_row - g_col (cross sub-chunk) + b_A10 *= safe_exp(b_g1[:, None] - b_g0[None, :]) + b_A20 *= safe_exp(b_g2[:, None] - b_g0[None, :]) + b_A21 *= safe_exp(b_g2[:, None] - b_g1[None, :]) + b_A30 *= safe_exp(b_g3[:, None] - b_g0[None, :]) + b_A31 *= safe_exp(b_g3[:, None] - b_g1[None, :]) + b_A32 *= safe_exp(b_g3[:, None] - b_g2[None, :]) + + # apply beta to row dimension and mask + m_d = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + # diagonal blocks: strictly lower triangular within sub-chunk, scaled by beta + b_A00 = ( + tl.where(m_d & (m_tc0[:, None] & m_tc0[None, :]), b_A00, 0.0) * b_b0[:, None] + ) + b_A11 = ( + tl.where(m_d & (m_tc1[:, None] & m_tc1[None, :]), b_A11, 0.0) * b_b1[:, None] + ) + b_A22 = ( + tl.where(m_d & (m_tc2[:, None] & m_tc2[None, :]), b_A22, 0.0) * b_b2[:, None] + ) + b_A33 = ( + tl.where(m_d & (m_tc3[:, None] & m_tc3[None, :]), b_A33, 0.0) * b_b3[:, None] + ) + + # off-diagonal blocks: full block, scaled by beta + b_A10 = b_A10 * b_b1[:, None] + b_A20 = b_A20 * b_b2[:, None] + b_A21 = b_A21 * b_b2[:, None] + b_A30 = b_A30 * b_b3[:, None] + b_A31 = b_A31 * b_b3[:, None] + b_A32 = b_A32 * b_b3[:, None] + + ############################################################################ + # Step 3: forward substitution on diagonal blocks -> (I + A_diag)^{-1} + # + # Same algorithm as solve_tril, but rows are extracted from in-register + # [BC, BC] tensor via tl.sum(tl.where(mask, tensor, 0), 0) instead of + # tl.load from HBM. 
+ ############################################################################ + + b_Ai00 = -b_A00 + b_Ai11 = -b_A11 + b_Ai22 = -b_A22 + b_Ai33 = -b_A33 + + for i in range(2, min(BC, T - i_tc0)): + b_a00 = tl.sum(tl.where((o_i == i)[:, None], -b_A00, 0.0), 0) + b_a00 = tl.where(o_i < i, b_a00, 0.0) + b_a00 = b_a00 + tl.sum(b_a00[:, None] * b_Ai00, 0) + b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) + for i in range(2, min(BC, T - i_tc1)): + b_a11 = tl.sum(tl.where((o_i == i)[:, None], -b_A11, 0.0), 0) + b_a11 = tl.where(o_i < i, b_a11, 0.0) + b_a11 = b_a11 + tl.sum(b_a11[:, None] * b_Ai11, 0) + b_Ai11 = tl.where((o_i == i)[:, None], b_a11, b_Ai11) + for i in range(2, min(BC, T - i_tc2)): + b_a22 = tl.sum(tl.where((o_i == i)[:, None], -b_A22, 0.0), 0) + b_a22 = tl.where(o_i < i, b_a22, 0.0) + b_a22 = b_a22 + tl.sum(b_a22[:, None] * b_Ai22, 0) + b_Ai22 = tl.where((o_i == i)[:, None], b_a22, b_Ai22) + for i in range(2, min(BC, T - i_tc3)): + b_a33 = tl.sum(tl.where((o_i == i)[:, None], -b_A33, 0.0), 0) + b_a33 = tl.where(o_i < i, b_a33, 0.0) + b_a33 = b_a33 + tl.sum(b_a33[:, None] * b_Ai33, 0) + b_Ai33 = tl.where((o_i == i)[:, None], b_a33, b_Ai33) + + b_Ai00 += m_I + b_Ai11 += m_I + b_Ai22 += m_I + b_Ai33 += m_I + + ############################################################################ + # Step 4: block merge -> full (I + A)^{-1} + ############################################################################ + + b_Ai10 = -tl.dot( + tl.dot(b_Ai11, b_A10, input_precision=_MERGE_DOT_PRECISION), + b_Ai00, + input_precision=_MERGE_DOT_PRECISION, + ) + b_Ai21 = -tl.dot( + tl.dot(b_Ai22, b_A21, input_precision=_MERGE_DOT_PRECISION), + b_Ai11, + input_precision=_MERGE_DOT_PRECISION, + ) + b_Ai32 = -tl.dot( + tl.dot(b_Ai33, b_A32, input_precision=_MERGE_DOT_PRECISION), + b_Ai22, + input_precision=_MERGE_DOT_PRECISION, + ) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_A20, b_Ai00, input_precision=_MERGE_DOT_PRECISION) + + tl.dot(b_A21, b_Ai10, 
input_precision=_MERGE_DOT_PRECISION), + input_precision=_MERGE_DOT_PRECISION, + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_A31, b_Ai11, input_precision=_MERGE_DOT_PRECISION) + + tl.dot(b_A32, b_Ai21, input_precision=_MERGE_DOT_PRECISION), + input_precision=_MERGE_DOT_PRECISION, + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_A30, b_Ai00, input_precision=_MERGE_DOT_PRECISION) + + tl.dot(b_A31, b_Ai10, input_precision=_MERGE_DOT_PRECISION) + + tl.dot(b_A32, b_Ai20, input_precision=_MERGE_DOT_PRECISION), + input_precision=_MERGE_DOT_PRECISION, + ) + + ############################################################################ + # Step 5: store full (I + A)^{-1} to output A + ############################################################################ + + p_A00 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_A10 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_A11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_tc1, BC), (BC, BC), (1, 0)) + p_A20 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_A21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + p_A22 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_tc2, 2 * BC), (BC, BC), (1, 0) + ) + p_A30 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_A31 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_A32 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0) + ) + p_A33 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_tc3, 3 * BC), (BC, BC), (1, 0) + ) + + tl.store(p_A00, b_Ai00.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A10, b_Ai10.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A11, b_Ai11.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A20, b_Ai20.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A21, b_Ai21.to(A.dtype.element_ty), boundary_check=(0, 1)) + 
tl.store(p_A22, b_Ai22.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A30, b_Ai30.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A31, b_Ai31.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A32, b_Ai32.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A33, b_Ai33.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_intra( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + GDN intra-chunk forward: fused kkt + solve_tril + recompute_w_u. + + Equivalent to: + A = chunk_scaled_dot_kkt_fwd(k, g, beta, ...) # kernel 1 + A = solve_tril(A, ...) # kernel 2 + w, u = recompute_w_u_fwd(k, v, beta, A, g, ...) # kernel 3 + + Fuses kernels 1+2 into a single kernel, reducing from 3 to 2 kernel launches + and eliminating the HBM round-trip for the intermediate A matrix. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + v (torch.Tensor): + The value tensor of shape `[B, T, H, V]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths. Default: `None`. + chunk_size (int): + The chunk size. Default: 64. + chunk_indices (torch.LongTensor): + Precomputed chunk indices. Default: `None`. 
+ + Returns: + w (torch.Tensor): shape `[B, T, H, K]` + u (torch.Tensor): shape `[B, T, H, V]` + A (torch.Tensor): shape `[B, T, H, BT]`, the solved (I+A)^{-1} matrix + """ + B, T, Hg, K = k.shape + H = beta.shape[-1] + BT = chunk_size + BC = 16 + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + # Step 1: fused kkt + solve_tril + A = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype) + chunk_gated_delta_rule_fwd_kkt_solve_kernel[(NT, B * H)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + BC=BC, + ) + + # Step 2: recompute_w_u + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + return w, u, A diff --git a/python/sglang/srt/layers/attention/fla/chunk_intra.py b/python/sglang/srt/layers/attention/fla/chunk_intra.py new file mode 100644 index 000000000000..344de6117ba4 --- /dev/null +++ b/python/sglang/srt/layers/attention/fla/chunk_intra.py @@ -0,0 +1,661 @@ +# Adapted from flash-linear-attention project. 
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.fla.chunk_intra_token_parallel import ( + chunk_kda_fwd_intra_token_parallel, +) +from sglang.srt.layers.attention.fla.index import ( + prepare_chunk_indices, +) +from sglang.srt.layers.attention.fla.op import exp2, gather +from sglang.srt.layers.attention.fla.utils import ( + autotune_cache_kwargs, + is_gather_supported, + is_tf32_supported, +) + +if is_tf32_supported: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32") +else: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee") + + +################################################################################ +# Fused inter + solve_tril kernel: compute off-diagonal Akk and solve in one pass +################################################################################ + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK}, num_warps=num_warps) + for BK in [32, 64] + for num_warps in [1, 2, 4] + ], + key=["H", "K", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_fwd_kernel_inter_solve_fused( + q, + k, + g, + beta, + Aqk, + Akkd, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_SAFE_GATE: tl.constexpr, +): + """ + Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. + Prerequisite: token_parallel has already computed diagonal Akk blocks in Akkd. + + This kernel: + 1. Computes off-diagonal Aqk blocks -> writes to global + 2. Computes off-diagonal Akk blocks -> keeps in registers + 3. Loads diagonal Akk blocks from Akkd (fp32) + 4. Does forward substitution on diagonals + 5. Computes merged Akk_inv + 6. 
Writes Akk_inv to Akk + """ + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load( + chunk_indices + i_t * 2 + 1 + ).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 1 + ).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + Aqk += (bos * H + i_h) * BT + Akk += (bos * H + i_h) * BT + Akkd += (bos * H + i_h) * BC + + o_i = tl.arange(0, BC) + m_tc1 = (i_tc1 + o_i) < T + m_tc2 = (i_tc2 + o_i) < T + m_tc3 = (i_tc3 + o_i) < T + + b_Aqk10 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk10 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk21 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk21 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk32 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk32 = tl.zeros([BC, BC], dtype=tl.float32) + + ################################################################################ + # off-diagonal blocks + ################################################################################ + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k0 = tl.make_block_ptr( + k, (T, K), (H * K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0) + ) + p_g0 = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0) + ) + b_k0 = tl.load(p_k0, boundary_check=(0, 1)).to(tl.float32) + b_g0 = tl.load(p_g0, boundary_check=(0, 
1)).to(tl.float32) + + if i_tc1 < T: + p_q1 = tl.make_block_ptr( + q, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0) + ) + p_k1 = tl.make_block_ptr( + k, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0) + ) + p_g1 = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0) + ) + # [BC, BK] + b_q1 = tl.load(p_q1, boundary_check=(0, 1)).to(tl.float32) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn1 = tl.load(g + i_tc1 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + b_gqn = tl.where(m_tc1[:, None], exp2(b_g1 - b_gn1[None, :]), 0) + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn1[None, :] - b_g0)) + # [BC, BC] + b_Aqk10 += tl.dot(b_q1 * b_gqn, b_kgt) + b_Akk10 += tl.dot(b_k1 * b_gqn, b_kgt) + + if i_tc2 < T: + p_q2 = tl.make_block_ptr( + q, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0) + ) + p_k2 = tl.make_block_ptr( + k, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0) + ) + p_g2 = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0) + ) + # [BC, BK] + b_q2 = tl.load(p_q2, boundary_check=(0, 1)).to(tl.float32) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn2 = tl.load(g + i_tc2 * H * K + o_k, mask=m_k, other=0).to( + tl.float32 + ) + # [BC, BK] + b_gqn2 = tl.where(m_tc2[:, None], exp2(b_g2 - b_gn2[None, :]), 0) + b_qg2 = b_q2 * b_gqn2 + b_kg2 = b_k2 * b_gqn2 + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn2[None, :] - b_g0)) + b_Aqk20 += tl.dot(b_qg2, b_kgt) + b_Akk20 += tl.dot(b_kg2, b_kgt) + # [BC, BC] + b_kgt = tl.trans(b_k1 * exp2(b_gn2[None, :] - b_g1)) + # [BC, BC] + b_Aqk21 += tl.dot(b_qg2, b_kgt) + b_Akk21 += tl.dot(b_kg2, b_kgt) + + if i_tc3 < T: + p_q3 = tl.make_block_ptr( + q, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0) + ) + p_k3 = tl.make_block_ptr( + k, (T, K), (H * K, 
1), (i_tc3, i_k * BK), (BC, BK), (1, 0) + ) + p_g3 = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0) + ) + # [BC, BK] + b_q3 = tl.load(p_q3, boundary_check=(0, 1)).to(tl.float32) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn3 = tl.load(g + i_tc3 * H * K + o_k, mask=m_k, other=0).to( + tl.float32 + ) + # [BC, BK] + b_gqn3 = tl.where(m_tc3[:, None], exp2(b_g3 - b_gn3[None, :]), 0) + b_qg3 = b_q3 * b_gqn3 + b_kg3 = b_k3 * b_gqn3 + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn3[None, :] - b_g0)) + # [BC, BC] + b_Aqk30 += tl.dot(b_qg3, b_kgt) + b_Akk30 += tl.dot(b_kg3, b_kgt) + # [BK, BC] + b_kgt = tl.trans(b_k1 * exp2(b_gn3[None, :] - b_g1)) + # [BC, BC] + b_Aqk31 += tl.dot(b_qg3, b_kgt) + b_Akk31 += tl.dot(b_kg3, b_kgt) + # [BK, BC] + b_kgt = tl.trans(b_k2 * exp2(b_gn3[None, :] - b_g2)) + # [BC, BC] + b_Aqk32 += tl.dot(b_qg3, b_kgt) + b_Akk32 += tl.dot(b_kg3, b_kgt) + + ################################################################################ + # save off-diagonal Aqk blocks and prepare Akk + ################################################################################ + if i_tc1 < T: + p_Aqk10 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0) + ) + tl.store( + p_Aqk10, (b_Aqk10 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + + p_b1 = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,) + ) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_Akk10 = b_Akk10 * b_b1[:, None] + if i_tc2 < T: + p_Aqk20 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0) + ) + p_Aqk21 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0) + ) + tl.store( + p_Aqk20, (b_Aqk20 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + tl.store( + p_Aqk21, (b_Aqk21 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + + 
p_b2 = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,) + ) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_Akk20 = b_Akk20 * b_b2[:, None] + b_Akk21 = b_Akk21 * b_b2[:, None] + if i_tc3 < T: + p_Aqk30 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0) + ) + p_Aqk31 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0) + ) + p_Aqk32 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0) + ) + tl.store( + p_Aqk30, (b_Aqk30 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + tl.store( + p_Aqk31, (b_Aqk31 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + tl.store( + p_Aqk32, (b_Aqk32 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + + p_b3 = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,) + ) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + b_Akk30 = b_Akk30 * b_b3[:, None] + b_Akk31 = b_Akk31 * b_b3[:, None] + b_Akk32 = b_Akk32 * b_b3[:, None] + + p_Akk00 = tl.make_block_ptr( + Akkd, (T, BC), (H * BC, 1), (i_tc0, 0), (BC, BC), (1, 0) + ) + p_Akk11 = tl.make_block_ptr( + Akkd, (T, BC), (H * BC, 1), (i_tc1, 0), (BC, BC), (1, 0) + ) + p_Akk22 = tl.make_block_ptr( + Akkd, (T, BC), (H * BC, 1), (i_tc2, 0), (BC, BC), (1, 0) + ) + p_Akk33 = tl.make_block_ptr( + Akkd, (T, BC), (H * BC, 1), (i_tc3, 0), (BC, BC), (1, 0) + ) + b_Ai00 = tl.load(p_Akk00, boundary_check=(0, 1)).to(tl.float32) + b_Ai11 = tl.load(p_Akk11, boundary_check=(0, 1)).to(tl.float32) + b_Ai22 = tl.load(p_Akk22, boundary_check=(0, 1)).to(tl.float32) + b_Ai33 = tl.load(p_Akk33, boundary_check=(0, 1)).to(tl.float32) + + ################################################################################ + # forward substitution on diagonals + ################################################################################ + + if not USE_SAFE_GATE: + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == 
o_i[None, :] + + b_Ai00 = -tl.where(m_A, b_Ai00, 0) + b_Ai11 = -tl.where(m_A, b_Ai11, 0) + b_Ai22 = -tl.where(m_A, b_Ai22, 0) + b_Ai33 = -tl.where(m_A, b_Ai33, 0) + + for i in range(2, min(BC, T - i_tc0)): + b_a00 = -tl.load(Akkd + (i_tc0 + i) * H * BC + o_i) + b_a00 = tl.where(o_i < i, b_a00, 0.0) + b_a00 += tl.sum(b_a00[:, None] * b_Ai00, 0) + b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) + for i in range(BC + 2, min(2 * BC, T - i_tc0)): + b_a11 = -tl.load(Akkd + (i_tc0 + i) * H * BC + o_i) + b_a11 = tl.where(o_i < i - BC, b_a11, 0.0) + b_a11 += tl.sum(b_a11[:, None] * b_Ai11, 0) + b_Ai11 = tl.where((o_i == i - BC)[:, None], b_a11, b_Ai11) + for i in range(2 * BC + 2, min(3 * BC, T - i_tc0)): + b_a22 = -tl.load(Akkd + (i_tc0 + i) * H * BC + o_i) + b_a22 = tl.where(o_i < i - 2 * BC, b_a22, 0.0) + b_a22 += tl.sum(b_a22[:, None] * b_Ai22, 0) + b_Ai22 = tl.where((o_i == i - 2 * BC)[:, None], b_a22, b_Ai22) + for i in range(3 * BC + 2, min(4 * BC, T - i_tc0)): + b_a33 = -tl.load(Akkd + (i_tc0 + i) * H * BC + o_i) + b_a33 = tl.where(o_i < i - 3 * BC, b_a33, 0.0) + b_a33 += tl.sum(b_a33[:, None] * b_Ai33, 0) + b_Ai33 = tl.where((o_i == i - 3 * BC)[:, None], b_a33, b_Ai33) + + b_Ai00 += m_I + b_Ai11 += m_I + b_Ai22 += m_I + b_Ai33 += m_I + + ################################################################################ + # compute merged inverse using off-diagonals + ################################################################################ + + # we used tf32 to maintain matrix inverse's precision whenever possible. 
+ b_Ai10 = -tl.dot( + tl.dot(b_Ai11, b_Akk10, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai00, + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai21 = -tl.dot( + tl.dot(b_Ai22, b_Akk21, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai11, + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai32 = -tl.dot( + tl.dot(b_Ai33, b_Akk32, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai22, + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_Akk20, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk21, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_Akk31, b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai21, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_Akk30, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk31, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai20, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + + ################################################################################ + # store full Akk_inv to Akk + ################################################################################ + + p_Akk00 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk10 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc1, BC), (BC, BC), (1, 0) + ) + p_Akk20 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk21 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0) + ) + p_Akk22 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc2, 2 * BC), (BC, BC), (1, 0) + ) + p_Akk30 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + 
p_Akk31 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0) + ) + p_Akk32 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0) + ) + p_Akk33 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc3, 3 * BC), (BC, BC), (1, 0) + ) + + tl.store(p_Akk00, b_Ai00.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk10, b_Ai10.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk11, b_Ai11.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk20, b_Ai20.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk21, b_Ai21.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk22, b_Ai22.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk30, b_Ai30.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk31, b_Ai31.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk32, b_Ai32.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk33, b_Ai33.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_fwd_kernel_intra_sub_chunk( + q, + k, + g, + beta, + Aqk, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_GATHER: tl.constexpr, +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load( + chunk_indices + i_t * 2 + 1 + ).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 
1 + ).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + i_ti = i_t * BT + i_i * BC + if i_ti >= T: + return + + o_c = i_ti + tl.arange(0, BC) + m_c = o_c < T + + q = q + (bos * H + i_h) * K + k = k + (bos * H + i_h) * K + g = g + (bos * H + i_h) * K + beta = beta + bos * H + i_h + Aqk = Aqk + (bos * H + i_h) * BT + Akk = Akk + (bos * H + i_h) * BC + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_ti, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_ti, 0), (BC, BK), (1, 0)) + + p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + if USE_GATHER: + b_gn = gather( + b_g, tl.full([1, BK], min(BC // 2, T - i_ti - 1), dtype=tl.int16), axis=0 + ) + else: + # calculate offset + p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H * K + tl.arange(0, BK) + b_gn = tl.load(p_gn, mask=tl.arange(0, BK) < K, other=0.0) + b_gn = b_gn[None, :] + + # current block, keep numerical stability by subtracting the left boundary + # less than 85 to avoid overflow in exp2 + b_gm = (b_g - b_gn).to(tl.float32) + + b_gq = tl.where(m_c[:, None], exp2(b_gm), 0.0) + b_gk = tl.where(m_c[:, None], exp2(-b_gm), 0.0) + + b_kgt = tl.trans(b_k * b_gk) + + b_Aqk = tl.dot(b_q * b_gq, b_kgt) * scale + b_Akk = tl.dot(b_k * b_gq, b_kgt) * b_beta[:, None] + + o_i = tl.arange(0, BC) + m_Aqk = o_i[:, None] >= o_i[None, :] + m_Akk = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + b_Aqk = tl.where(m_Aqk, b_Aqk, 0.0) + b_Akk = tl.where(m_Akk, b_Akk, 0.0) + + p_Aqk = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0) + ) + p_Akk = tl.make_block_ptr(Akk, (T, BC), (H * BC, 1), (i_ti, 0), (BC, BC), (1, 0)) + tl.store(p_Aqk, 
b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk, b_Akk.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + + ################################################################################ + # forward substitution + ################################################################################ + + b_Ai = -b_Akk + for i in range(2, min(BC, T - i_ti)): + b_a = -tl.load(Akk + (i_ti + i) * H * BC + o_i) + b_a = tl.where(o_i < i, b_a, 0.0) + b_a += tl.sum(b_a[:, None] * b_Ai, 0) + b_Ai = tl.where((o_i == i)[:, None], b_a, b_Ai) + b_Ai += m_I + tl.store(p_Akk, b_Ai.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_fwd_intra( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gk: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor | None = None, + safe_gate: bool = False, + disable_recompute: bool = False, +): + B, T, H, K = k.shape + BT = chunk_size + BC = 16 + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + + Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype) + # Akk must be zero-initialized - kernel only writes lower triangular + Akk = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype) + # Separate fp32 buffer for diagonal 16x16 blocks (for precision in solve_tril) + Akkd = torch.zeros(B, T, H, BC, device=k.device, dtype=torch.float32) + + # Step 1: Run token_parallel first to compute diagonal blocks into Akkd (fp32) + # Step 1: compute diagonal blocks into Akk_diag (fp32) + if safe_gate: + grid = (NT, NC, B * H) + BK = triton.next_power_of_2(K) + chunk_kda_fwd_kernel_intra_sub_chunk[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akkd, + scale=scale, + 
cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + USE_GATHER=is_gather_supported, + ) + else: + Aqk, Akkd = chunk_kda_fwd_intra_token_parallel( + q=q, + k=k, + gk=gk, + beta=beta, + Aqk=Aqk, + Akk=Akkd, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT, + sub_chunk_size=BC, + ) + + # Step 2: Fused inter + solve_tril (works for both fixed-len and varlen) + grid = (NT, B * H) + chunk_kda_fwd_kernel_inter_solve_fused[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akkd=Akkd, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + USE_SAFE_GATE=safe_gate, + ) + from sglang.srt.layers.attention.fla.kda import ( + recompute_w_u_fwd as kda_recompute_w_u_fwd, + ) + + w, u, qg, kg = kda_recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=Akk, + q=q if disable_recompute else None, + gk=gk, + cu_seqlens=cu_seqlens, + ) + return w, u, qg, kg, Aqk, Akk diff --git a/python/sglang/srt/layers/attention/fla/chunk_intra_token_parallel.py b/python/sglang/srt/layers/attention/fla/chunk_intra_token_parallel.py new file mode 100644 index 000000000000..ec8bc848c839 --- /dev/null +++ b/python/sglang/srt/layers/attention/fla/chunk_intra_token_parallel.py @@ -0,0 +1,197 @@ +# Adapted from flash-linear-attention project. 
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# Token-parallel implementation of KDA intra chunk kernel + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.fla.op import exp2 +from sglang.srt.layers.attention.fla.utils import autotune_cache_kwargs + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BH": BH}, num_warps=num_warps) + for BH in [1, 2, 4, 8] + for num_warps in [1, 2, 4, 8] + ], + key=["K", "H"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T", "N"]) +def chunk_kda_fwd_kernel_intra_token_parallel( + q, + k, + g, + beta, + Aqk, + Akk, + scale, + cu_seqlens, + N, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BH: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_tg, i_hg = tl.program_id(0), tl.program_id(1) + + if IS_VARLEN: + i_n = 0 + left, right = 0, N + + # Unrolled binary search (max B=2^32) + # We can limit iterations based on expected max batch size if needed + # 20 iterations covers B=1M, usually enough + for _ in range(20): + if left < right: + mid = (left + right) // 2 + if i_tg < tl.load(cu_seqlens + mid + 1).to(tl.int32): + right = mid + else: + left = mid + 1 + i_n = left + + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 1 + ).to(tl.int32) + T = eos - bos + i_t = i_tg - bos + else: + bos = (i_tg // T) * T + i_t = i_tg % T + + if i_t >= T: + return + + i_c = i_t // BT + i_s = (i_t % BT) // BC + i_tc = i_c * BT + i_ts = i_tc + i_s * BC + + q += bos * H * K + k += bos * H * K + g += bos * H * K + Aqk += bos * H * BT + Akk += bos * H * BC + beta += bos * H + + o_h = tl.arange(0, BH) + o_k = tl.arange(0, BK) + m_h = (i_hg * BH + o_h) < H + m_k = o_k < K + + p_q = tl.make_block_ptr( + q + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + p_k = tl.make_block_ptr( + k + i_t * H * K, 
(H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + p_g = tl.make_block_ptr( + g + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + p_beta = tl.make_block_ptr(beta + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,)) + # [BH, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_k = b_k * tl.load(p_beta, boundary_check=(0,)).to(tl.float32)[:, None] + + for j in range(i_ts, min(i_t + 1, min(T, i_ts + BC))): + p_kj = tl.make_block_ptr( + k + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + p_gj = tl.make_block_ptr( + g + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + # [BH, BK] + b_kj = tl.load(p_kj, boundary_check=(0, 1)).to(tl.float32) + b_gj = tl.load(p_gj, boundary_check=(0, 1)).to(tl.float32) + + b_kgj = b_kj * exp2(b_g - b_gj) + + b_kgj = tl.where(m_k[None, :], b_kgj, 0.0) + # [BH] + b_Aqk = tl.sum(b_q * b_kgj, axis=1) * scale + b_Akk = tl.sum(b_k * b_kgj, axis=1) * tl.where(j < i_t, 1.0, 0.0) + + tl.store( + Aqk + i_t * H * BT + (i_hg * BH + o_h) * BT + j % BT, + b_Aqk.to(Aqk.dtype.element_ty), + mask=m_h, + ) + tl.store( + Akk + i_t * H * BC + (i_hg * BH + o_h) * BC + j - i_ts, + b_Akk.to(Akk.dtype.element_ty), + mask=m_h, + ) + + +def chunk_kda_fwd_intra_token_parallel( + q: torch.Tensor, + k: torch.Tensor, + gk: torch.Tensor, + beta: torch.Tensor, + Aqk: torch.Tensor, + Akk: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + sub_chunk_size: int = 16, +) -> None: + """ + Token-parallel implementation: each token gets its own thread block. + Supports both fixed-length and variable-length sequences. + Reduces wasted computation on padding. + + Writes directly to Aqk and Akk tensors (in-place). 
+ + Args: + q: [B, T, H, K] + k: [B, T, H, K] + gk: [B, T, H, K] cumsum of gates + beta: [B, T, H] + Aqk: [B, T, H, BT] output tensor to write to + Akk: [B, T, H, BC] output tensor for diagonal blocks (fp32) + scale: attention scale + chunk_size: BT (default 64) + sub_chunk_size: BC (default 16) + """ + B, T, H, K = q.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BT = chunk_size + BC = sub_chunk_size + + def grid(meta): + return (B * T, triton.cdiv(H, meta["BH"])) + + BK = triton.next_power_of_2(K) + + chunk_kda_fwd_kernel_intra_token_parallel[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + N=N, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + ) + return Aqk, Akk diff --git a/python/sglang/srt/layers/attention/fla/fused_recurrent.py b/python/sglang/srt/layers/attention/fla/fused_recurrent.py index 44e42e2d60ed..f110770c0223 100644 --- a/python/sglang/srt/layers/attention/fla/fused_recurrent.py +++ b/python/sglang/srt/layers/attention/fla/fused_recurrent.py @@ -940,3 +940,6 @@ def fused_recurrent_gated_delta_rule_update( retrieve_parent_token, ) return o + + +fused_recurrent_gdn = fused_recurrent_gated_delta_rule diff --git a/python/sglang/srt/layers/attention/fla/kda.py b/python/sglang/srt/layers/attention/fla/kda.py index 3f17b21cce54..a8d5cb405ea9 100644 --- a/python/sglang/srt/layers/attention/fla/kda.py +++ b/python/sglang/srt/layers/attention/fla/kda.py @@ -9,6 +9,7 @@ import triton.language as tl from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h +from sglang.srt.layers.attention.fla.chunk_intra import chunk_kda_fwd_intra from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum from sglang.srt.layers.attention.fla.fused_norm_gate import layer_norm_gated_fwd from sglang.srt.layers.attention.fla.fused_recurrent import ( @@ -17,7 +18,6 @@ from sglang.srt.layers.attention.fla.index import prepare_chunk_indices from 
sglang.srt.layers.attention.fla.l2norm import l2norm_fwd from sglang.srt.layers.attention.fla.op import exp, log -from sglang.srt.layers.attention.fla.solve_tril import solve_tril from sglang.srt.layers.attention.fla.utils import is_amd BT_LIST_AUTOTUNE = [32, 64, 128] @@ -863,27 +863,19 @@ def chunk_kda_fwd( ): chunk_size = 64 g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) - # the intra Aqk is kept in fp32 - # the computation has very marginal effect on the entire throughput - A, Aqk = chunk_kda_scaled_dot_kkt_fwd( + + # Fused: scaled_dot_kkt + solve_tril + recompute_w_u + w, u, _, kg, Aqk, _ = chunk_kda_fwd_intra( q=q, k=k, + v=v, gk=g, beta=beta, scale=scale, cu_seqlens=cu_seqlens, - output_dtype=torch.float32, - ) - A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) - w, u, _, kg = recompute_w_u_fwd( - k=k, - v=v, - beta=beta, - A=A, - gk=g, - cu_seqlens=cu_seqlens, + chunk_size=chunk_size, ) - del A + h, v_new = chunk_gated_delta_rule_fwd_h( k=kg, w=w, diff --git a/python/sglang/srt/layers/attention/fla/utils.py b/python/sglang/srt/layers/attention/fla/utils.py index af6ca3d6e572..4154a3c52352 100644 --- a/python/sglang/srt/layers/attention/fla/utils.py +++ b/python/sglang/srt/layers/attention/fla/utils.py @@ -3,6 +3,7 @@ import contextlib import functools +import inspect import logging import os import sys @@ -20,6 +21,16 @@ COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_CACHE_RESULTS = os.getenv("FLA_CACHE_RESULTS", "1") == "1" + + +SUPPORTS_AUTOTUNE_CACHE = ( + "cache_results" in inspect.signature(triton.autotune).parameters +) + +autotune_cache_kwargs = ( + {"cache_results": FLA_CACHE_RESULTS} if SUPPORTS_AUTOTUNE_CACHE else {} +) @lru_cache(maxsize=1) @@ -323,3 +334,6 @@ def custom_device_ctx(index: int): def custom_device_ctx(index: int): return torch.cuda.device(index) + + +device_platform = get_available_device() diff --git 
a/python/sglang/srt/layers/attention/fla/wy_fast.py b/python/sglang/srt/layers/attention/fla/wy_fast.py index 757e5621087b..980a475ccc40 100644 --- a/python/sglang/srt/layers/attention/fla/wy_fast.py +++ b/python/sglang/srt/layers/attention/fla/wy_fast.py @@ -115,14 +115,14 @@ def recompute_w_u_fwd( g_cumsum: torch.Tensor, A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor], + chunk_indices: torch.LongTensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: B, T, Hg, K, V = *k.shape, v.shape[-1] H = v.shape[-2] BT = A.shape[-1] - chunk_indices = ( - prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None - ) + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) BK = 64 BV = 64 diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 147863803602..4fe8aec31301 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -17,7 +17,10 @@ import torch from sglang.kernel_api_logging import debug_kernel_api -from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph +from sglang.srt.compilation.piecewise_context_manager import ( + get_forward_context, + is_in_piecewise_cuda_graph, +) from sglang.srt.dllm.config import DllmConfig from sglang.srt.environ import envs from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -147,6 +150,7 @@ def __init__( self.max_context_len = model_runner.model_config.context_len self.skip_prefill = skip_prefill self.is_multimodal = model_runner.model_config.is_multimodal + self.page_size = model_runner.page_size assert not ( model_runner.sliding_window_size is not None @@ -1048,16 +1052,19 @@ def update_cross_attention( fixed_split_size: Optional[int] = None, disable_split_kv: 
Optional[bool] = None, ): + # Cache encoder_lens on CPU to avoid GPU→CPU transfer per call + encoder_lens_cpu = encoder_lens.cpu() if encoder_lens is not None else None for wrapper_id in range(2): if wrapper_id == 0: - # Normal attention paged_kernel_lens = seq_lens kv_start_idx = encoder_lens + kv_lens_cpu = seq_lens_cpu else: - # Cross attention + # Cross-attention: attend to encoder tokens only paged_kernel_lens = encoder_lens kv_start_idx = torch.zeros_like(encoder_lens) seq_lens_sum = encoder_lens.sum().item() + kv_lens_cpu = encoder_lens_cpu self.call_begin_forward( decode_wrappers[wrapper_id], @@ -1067,7 +1074,7 @@ def update_cross_attention( self.kv_indptr[wrapper_id], kv_start_idx, spec_info, - seq_lens_cpu=seq_lens_cpu, + seq_lens_cpu=kv_lens_cpu, ) def call_begin_forward( @@ -1189,6 +1196,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBacken self.q_data_type = model_runner.dtype self.sliding_window_size = model_runner.sliding_window_size self.attn_backend = attn_backend + self.page_size = attn_backend.page_size # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr @@ -1378,8 +1386,13 @@ def call_begin_forward( # Normal extend kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] + # Reserve extra space in kv_indices for a potential piecewise CUDA graph + # dummy request (see below). Worst case: static_num_tokens extra pages. 
+ fwd_ctx = get_forward_context() + pcg_num_tokens = fwd_ctx.num_tokens if fwd_ctx is not None else None + extra_kv = pcg_num_tokens if pcg_num_tokens is not None else 0 kv_indices = torch.empty( - paged_kernel_lens_sum + 256, + paged_kernel_lens_sum + extra_kv + 256, dtype=torch.int32, device=req_pool_indices.device, ) @@ -1394,6 +1407,40 @@ def call_begin_forward( ) qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) qo_indptr = qo_indptr[: bs + 1] + + # Piecewise CUDA graph padding: input_ids are padded to static_num_tokens, + # so q.shape[0] == static_num_tokens but qo_indptr[-1] == actual tokens. + # Append a dummy request for the padding tokens so that + # qo_indptr[-1] == static_num_tokens, satisfying flashinfer's shape check + # without corrupting the causal masks of real requests. + # The dummy request's KV indices all point to slot 0 (a scratch location); + # its attention output is discarded via the [:raw_num_tokens] slice in replay. + bs_eff = bs + # extend_num_tokens is a Python int (== sum of seq_lens - prefix_lens), + # and paged_kernel_lens_sum is also a Python int (== kv_indptr[-1]), + # so this block requires no CPU-GPU synchronisation. 
+ actual_qo_tokens = ( + fwd_ctx.forward_batch.extend_num_tokens if fwd_ctx is not None else None + ) + if ( + pcg_num_tokens is not None + and actual_qo_tokens is not None + and pcg_num_tokens > actual_qo_tokens + ): + pad_tokens = pcg_num_tokens - actual_qo_tokens + num_dummy_pages = (pad_tokens + self.page_size - 1) // self.page_size + kv_start = ( + paged_kernel_lens_sum # equals kv_indptr[-1], no .item() needed + ) + kv_indices[kv_start : kv_start + num_dummy_pages] = 0 + qo_indptr = torch.cat( + [qo_indptr, qo_indptr.new_tensor([pcg_num_tokens])] + ) + kv_indptr = torch.cat( + [kv_indptr, kv_indptr.new_tensor([kv_start + num_dummy_pages])] + ) + bs_eff = bs + 1 + custom_mask = None else: assert isinstance(spec_info, SpecInput) @@ -1405,6 +1452,7 @@ def call_begin_forward( self.req_to_token, ) ) + bs_eff = bs # extend part if use_ragged: @@ -1446,7 +1494,7 @@ def call_begin_forward( qo_indptr, kv_indptr, kv_indices, - self.kv_last_page_len[:bs], + self.kv_last_page_len[:bs_eff], self.num_qo_heads, self.num_kv_heads, self.head_dim, diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py index 46d0d5b3f951..d19be6f5b4dd 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba.py +++ b/python/sglang/srt/layers/attention/mamba/mamba.py @@ -60,7 +60,7 @@ def mamba_v2_sharded_weight_loader( ) -> LoaderFunction: """Create a weight loader for mamba v2. This ensures that the projections are correctly sharded so that they can be split into x, B, C. It also - ensures the the all the groups corresponding to a head shard is placed + ensures that all the groups corresponding to a head shard is placed together with it. 
""" diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 7d1511963191..02ef4e2440cd 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -1254,22 +1254,48 @@ def forward_npu( and not forward_batch.forward_mode.is_draft_extend() ) - cos_sin = self.rotary_emb.cos_sin_cache[positions] - cos, sin = cos_sin.chunk(2, dim=-1) - cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) - sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) - bs = q_lora.shape[0] - if self.alt_stream is not None: - self.alt_stream.wait_stream(torch.npu.current_stream()) - with torch.npu.stream(self.alt_stream): + + if self.rotary_emb.is_neox_style: + if not hasattr(forward_batch, "npu_indexer_sin_cos_cache"): + cos_sin = self.rotary_emb.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) + sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) + forward_batch.npu_indexer_sin_cos_cache = (sin, cos) + else: + sin, cos = forward_batch.npu_indexer_sin_cos_cache + + if self.alt_stream is not None: + self.alt_stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(self.alt_stream): + q_lora = ( + (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora + ) + q = self.wq_b(q_lora)[ + 0 + ] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128] + wq_b_event = self.alt_stream.record_event() + q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128] + q_pe, q_nope = torch.split( + q, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) # [bs, 64, 64 + 64] + q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim) + q_pe = torch_npu.npu_rotary_mul(q_pe, cos, sin).view( + bs, self.n_heads, self.rope_head_dim + ) # [bs, n, d] + q = torch.cat([q_pe, q_nope], dim=-1) + q.record_stream(self.alt_stream) + q_rope_event = 
self.alt_stream.record_event() + else: q_lora = ( (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora ) q = self.wq_b(q_lora)[ 0 ] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128] - wq_b_event = self.alt_stream.record_event() q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128] q_pe, q_nope = torch.split( q, @@ -1281,9 +1307,52 @@ def forward_npu( bs, self.n_heads, self.rope_head_dim ) # [bs, n, d] q = torch.cat([q_pe, q_nope], dim=-1) - q.record_stream(self.alt_stream) - q_rope_event = self.alt_stream.record_event() + + if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): + indexer_weight_stream = get_indexer_weight_stream() + indexer_weight_stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(indexer_weight_stream): + x = x.view(-1, self.hidden_size) + weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + weights.record_stream(indexer_weight_stream) + weights_event = indexer_weight_stream.record_event() + else: + x = x.view(-1, self.hidden_size) + weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + + k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] + k = self.k_norm(k_proj) + if ( + _use_ag_after_qlora + and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED + and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL + ): + k = scattered_to_tp_attn_full(k, forward_batch) + k_pe, k_nope = torch.split( + k, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) # [bs, 64 + 64] + + k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim) + k_pe = torch.ops.npu.npu_rotary_mul(k_pe, cos, sin).view( + bs, 1, self.rope_head_dim + ) # [bs, 1, d] + k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128] + else: + if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): + indexer_weight_stream = get_indexer_weight_stream() + indexer_weight_stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(indexer_weight_stream): + x = x.view(-1, self.hidden_size) + weights = 
self.weights_proj(x.float())[0].to(torch.bfloat16) + weights.record_stream(indexer_weight_stream) + weights_event = indexer_weight_stream.record_event() + else: + x = x.view(-1, self.hidden_size) + weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + q_lora = (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora q = self.wq_b(q_lora)[0] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128] q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128] @@ -1292,43 +1361,26 @@ def forward_npu( [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1, ) # [bs, 64, 64 + 64] - q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim) - q_pe = torch_npu.npu_rotary_mul(q_pe, cos, sin).view( - bs, self.n_heads, self.rope_head_dim - ) # [bs, n, d] - q = torch.cat([q_pe, q_nope], dim=-1) - if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): - indexer_weight_stream = get_indexer_weight_stream() - indexer_weight_stream.wait_stream(torch.npu.current_stream()) - with torch.npu.stream(indexer_weight_stream): - x = x.view(-1, self.hidden_size) - weights = self.weights_proj(x.float())[0].to(torch.bfloat16) - weights.record_stream(indexer_weight_stream) - weights_event = indexer_weight_stream.record_event() - else: - x = x.view(-1, self.hidden_size) - weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] + k = self.k_norm(k_proj) + k_pe, k_nope = torch.split( + k, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) # [bs, 64 + 64] - k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] - k = self.k_norm(k_proj) - if ( - _use_ag_after_qlora - and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED - and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL - ): - k = scattered_to_tp_attn_full(k, forward_batch) - k_pe, k_nope = torch.split( - k, - [self.rope_head_dim, self.head_dim - self.rope_head_dim], - dim=-1, - ) # [bs, 64 + 64] - - k_pe = 
k_pe.view(-1, 1, 1, self.rope_head_dim) - k_pe = torch.ops.npu.npu_rotary_mul(k_pe, cos, sin).view( - bs, 1, self.rope_head_dim - ) # [bs, 1, d] - k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128] + k_pe = k_pe.unsqueeze(1) + + if layer_id == 0: + self.rotary_emb.sin_cos_cache = ( + self.rotary_emb.cos_sin_cache.index_select(0, positions) + ) + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + k_pe = k_pe.squeeze(1) + q = torch.cat([q_pe, q_nope], dim=-1) + k = torch.cat([k_pe, k_nope], dim=-1) if ( is_prefill @@ -1394,7 +1446,7 @@ def forward_npu( past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id) - if self.alt_stream is not None: + if self.rotary_emb.is_neox_style and self.alt_stream is not None: torch.npu.current_stream().wait_event(q_rope_event) if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): torch.npu.current_stream().wait_event(weights_event) diff --git a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py index 35fe6997d79f..bfc62d7f0b19 100644 --- a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py +++ b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Optional, Tuple import tilelang @@ -44,6 +45,23 @@ def fast_round_scale(amax, fp8_max_inv): return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) +@lru_cache(maxsize=8) +def _pick_inner_iter(seq: int, ni: int, cu: int, block_per_cu: int) -> int: + """ + Pick the largest valid inner_iter (power-of-two divisor of ni) that keeps + enough work per CU (seq * ni / inner_iter / cu >= block_per_cu), so we avoid + under-utilization while minimizing the number of partial groups. 
+ """ + + max_it = int(seq * ni / (cu * block_per_cu)) + it = ni + while it >= 2: + if it <= max_it and ni % it == 0: + return it + it //= 2 + return 1 + + @tilelang.jit(pass_configs=pass_configs) def act_quant_kernel( N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False @@ -1037,6 +1055,255 @@ def main( return main +@tilelang.jit(out_idx=[-2, -1], pass_configs=pass_configs) +def sparse_mla_fwd_decode_partial_fp8( + num_heads: int, + d_v: int, + d_tail: int, + topk: int, + *, + sm_scale=None, + block_I=64, + inner_iter=1, + threads=256, +): + assert d_v == 512, f"only support d_v=512" + assert ( + topk % block_I == 0 + ), "otherwise will load some index=0 thus causing wrong kv to be loaded" + + # Softmax scores are in [0, 1]. We scale by fp8_max_val before FP8 cast + # to better utilize FP8 dynamic range, then apply the inverse scale after GEMM. + # This is numerically safe because softmax output is bounded by 1. + fp8_dtype = "float8_e4m3fnuz" if _is_fp8_fnuz else "float8_e4m3fn" + fp8_max_val = 240.0 if _is_fp8_fnuz else 448.0 + s_inv_scale_const = fp8_max_val + s_scale_const = 1.0 / fp8_max_val + + BI = block_I + group_size = 128 + dim_quant_fp8 = d_v + d_tail + rope_offset_fp8 = d_v + n_groups = topk // (BI * inner_iter) + + if sm_scale is None: + sm_scale = (1.0 / (d_v + d_tail)) ** 0.5 * 1.44269504 + else: + sm_scale = sm_scale * 1.44269504 + + h_per_block = 16 + # Match bf16 partial behavior: keep fixed 16-head tiles and use + # sliced T.copy on H0:H1 for tail handling. 
+ assert ( + num_heads <= h_per_block or num_heads % h_per_block == 0 + ), "num_heads must be <=16 or divisible by 16" + head_blocks_per_seq = (num_heads + h_per_block - 1) // h_per_block + + batch = 1 + kv_group = 1 + seq_len = T.symbolic("seq_len") + num_pages = T.symbolic("num_pages") + + q_fp8_shape = [batch, seq_len, num_heads, d_v + d_tail] + kv_fp8_shape = [batch, num_pages, kv_group, dim_quant_fp8] + idx_shape = [batch, seq_len, kv_group, topk] + partial_o_shape = [batch, seq_len, n_groups, num_heads, d_v] + partial_lse_shape = [batch, seq_len, n_groups, num_heads] + + accum_dtype = T.float32 + dtype_bf16 = T.bfloat16 + + @T.prim_func + def main( + q_fp8: T.Tensor(q_fp8_shape, fp8_dtype), + kv_fp8: T.Tensor(kv_fp8_shape, fp8_dtype), + indices: T.Tensor(idx_shape, T.int32), + partial_o: T.Tensor(partial_o_shape, dtype_bf16), + partial_lse: T.Tensor(partial_lse_shape, accum_dtype), + ): + with T.Kernel(seq_len * head_blocks_per_seq, n_groups, threads=threads) as ( + bx, + by, + ): + b_i, g_i = 0, 0 + s_i = bx // head_blocks_per_seq + group_i = by + H0 = (bx % head_blocks_per_seq) * h_per_block + H1 = H0 + h_per_block + + # We intentionally split the K=512 GEMM into 4x128 tiles. + # Although this adds extra intermediate memory traffic, + # it shortens the MFMA accumulation dependency chain and improves performance. 
+ q_tile0 = T.alloc_shared([h_per_block, group_size], fp8_dtype) + q_tile1 = T.alloc_shared([h_per_block, group_size], fp8_dtype) + q_tile2 = T.alloc_shared([h_per_block, group_size], fp8_dtype) + q_tile3 = T.alloc_shared([h_per_block, group_size], fp8_dtype) + kv_tile0 = T.alloc_shared([BI, group_size], fp8_dtype) + kv_tile1 = T.alloc_shared([BI, group_size], fp8_dtype) + kv_tile2 = T.alloc_shared([BI, group_size], fp8_dtype) + kv_tile3 = T.alloc_shared([BI, group_size], fp8_dtype) + q_tail_buf = T.alloc_shared([h_per_block, d_tail], fp8_dtype) + k_tail_shared = T.alloc_shared([BI, d_tail], fp8_dtype) + s_fp8_shared = T.alloc_shared([h_per_block, BI], fp8_dtype) + page_idx_shared = T.alloc_shared([BI], T.int32) + + mask = T.alloc_fragment([BI], T.bool) + acc_s = T.alloc_fragment([h_per_block, BI], accum_dtype) + acc_tile = T.alloc_fragment([h_per_block, BI], accum_dtype) + sv_tile = T.alloc_fragment([h_per_block, group_size], accum_dtype) + sumexp = T.alloc_fragment([h_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([h_per_block], accum_dtype) + alpha = T.alloc_fragment([h_per_block], accum_dtype) + m_i = T.alloc_fragment([h_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([h_per_block], accum_dtype) + inv_denom = T.alloc_fragment([h_per_block], accum_dtype) + + acc_o_tile0 = T.alloc_fragment([h_per_block, group_size], accum_dtype) + acc_o_tile1 = T.alloc_fragment([h_per_block, group_size], accum_dtype) + acc_o_tile2 = T.alloc_fragment([h_per_block, group_size], accum_dtype) + acc_o_tile3 = T.alloc_fragment([h_per_block, group_size], accum_dtype) + + T.fill(acc_o_tile0, 0) + T.fill(acc_o_tile1, 0) + T.fill(acc_o_tile2, 0) + T.fill(acc_o_tile3, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) + + T.copy(q_fp8[b_i, s_i, H0:H1, d_v:], q_tail_buf) + T.copy(q_fp8[b_i, s_i, H0:H1, 0 * group_size : 1 * group_size], q_tile0) + T.copy(q_fp8[b_i, s_i, H0:H1, 1 * group_size : 2 * group_size], q_tile1) + T.copy(q_fp8[b_i, s_i, H0:H1, 2 * group_size : 3 * 
group_size], q_tile2) + T.copy(q_fp8[b_i, s_i, H0:H1, 3 * group_size : 4 * group_size], q_tile3) + + for k_i in T.serial(inner_iter): + topk_block_i = group_i * inner_iter + k_i + + for bi_i in T.Parallel(BI): + idx = indices[b_i, s_i, g_i, topk_block_i * BI + bi_i] + valid = idx >= 0 + page_idx_shared[bi_i] = T.if_then_else(valid, idx, 0) + mask[bi_i] = valid + + for bi_i, j in T.Parallel(BI, group_size): + page = page_idx_shared[bi_i] + kv_tile0[bi_i, j] = kv_fp8[b_i, page, g_i, 0 * group_size + j] + kv_tile1[bi_i, j] = kv_fp8[b_i, page, g_i, 1 * group_size + j] + kv_tile2[bi_i, j] = kv_fp8[b_i, page, g_i, 2 * group_size + j] + kv_tile3[bi_i, j] = kv_fp8[b_i, page, g_i, 3 * group_size + j] + + for bi_i, j in T.Parallel(BI, d_tail): + page = page_idx_shared[bi_i] + k_tail_shared[bi_i, j] = kv_fp8[b_i, page, g_i, rope_offset_fp8 + j] + + for h_i, bi_i in T.Parallel(h_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else( + mask[bi_i], 0, -T.infinity(acc_s.dtype) + ) + + T.gemm(q_tile0, kv_tile0, acc_s, transpose_B=True, clear_accum=False) + T.gemm(q_tile1, kv_tile1, acc_tile, transpose_B=True, clear_accum=True) + for h_i, bi_i in T.Parallel(h_per_block, BI): + acc_s[h_i, bi_i] += acc_tile[h_i, bi_i] + T.gemm(q_tile2, kv_tile2, acc_tile, transpose_B=True, clear_accum=True) + for h_i, bi_i in T.Parallel(h_per_block, BI): + acc_s[h_i, bi_i] += acc_tile[h_i, bi_i] + T.gemm(q_tile3, kv_tile3, acc_tile, transpose_B=True, clear_accum=True) + for h_i, bi_i in T.Parallel(h_per_block, BI): + acc_s[h_i, bi_i] += acc_tile[h_i, bi_i] + T.gemm( + q_tail_buf, + k_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(h_per_block): + alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(h_per_block, BI): + acc_s[h_i, bi_i] = T.exp2( + acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale + ) + T.reduce_sum(acc_s, sumexp_i, dim=1) + 
for h_i in T.Parallel(h_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, j in T.Parallel(h_per_block, group_size): + acc_o_tile0[h_i, j] = acc_o_tile0[h_i, j] * alpha[h_i] + acc_o_tile1[h_i, j] = acc_o_tile1[h_i, j] * alpha[h_i] + acc_o_tile2[h_i, j] = acc_o_tile2[h_i, j] * alpha[h_i] + acc_o_tile3[h_i, j] = acc_o_tile3[h_i, j] * alpha[h_i] + + for h_i, bi_i in T.Parallel(h_per_block, BI): + s_fp8_shared[h_i, bi_i] = T.clamp( + acc_s[h_i, bi_i] * s_inv_scale_const, + -fp8_max_val, + fp8_max_val, + ) + T.gemm(s_fp8_shared, kv_tile0, sv_tile, clear_accum=True) + for h_i, j in T.Parallel(h_per_block, group_size): + acc_o_tile0[h_i, j] = ( + acc_o_tile0[h_i, j] + sv_tile[h_i, j] * s_scale_const + ) + + T.gemm(s_fp8_shared, kv_tile1, sv_tile, clear_accum=True) + for h_i, j in T.Parallel(h_per_block, group_size): + acc_o_tile1[h_i, j] = ( + acc_o_tile1[h_i, j] + sv_tile[h_i, j] * s_scale_const + ) + + T.gemm(s_fp8_shared, kv_tile2, sv_tile, clear_accum=True) + for h_i, j in T.Parallel(h_per_block, group_size): + acc_o_tile2[h_i, j] = ( + acc_o_tile2[h_i, j] + sv_tile[h_i, j] * s_scale_const + ) + + T.gemm(s_fp8_shared, kv_tile3, sv_tile, clear_accum=True) + for h_i, j in T.Parallel(h_per_block, group_size): + acc_o_tile3[h_i, j] = ( + acc_o_tile3[h_i, j] + sv_tile[h_i, j] * s_scale_const + ) + + for h_i in T.Parallel(h_per_block): + denom = T.if_then_else(sumexp[h_i] == 0.0, 1.0, sumexp[h_i]) + inv_denom[h_i] = 1.0 / denom + for h_i, j in T.Parallel(h_per_block, group_size): + acc_o_tile0[h_i, j] = acc_o_tile0[h_i, j] * inv_denom[h_i] + acc_o_tile1[h_i, j] = acc_o_tile1[h_i, j] * inv_denom[h_i] + acc_o_tile2[h_i, j] = acc_o_tile2[h_i, j] * inv_denom[h_i] + acc_o_tile3[h_i, j] = acc_o_tile3[h_i, j] * inv_denom[h_i] + + for h_i in T.Parallel(h_per_block): + sumexp[h_i] = T.if_then_else( + sumexp[h_i] == 0.0, + -(2**30), + T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale, + ) + + T.copy( + acc_o_tile0, + partial_o[b_i, s_i, group_i, H0:H1, 0 * 
group_size : 1 * group_size], + ) + T.copy( + acc_o_tile1, + partial_o[b_i, s_i, group_i, H0:H1, 1 * group_size : 2 * group_size], + ) + T.copy( + acc_o_tile2, + partial_o[b_i, s_i, group_i, H0:H1, 2 * group_size : 3 * group_size], + ) + T.copy( + acc_o_tile3, + partial_o[b_i, s_i, group_i, H0:H1, 3 * group_size : 4 * group_size], + ) + + T.copy(sumexp, partial_lse[b_i, s_i, group_i, H0:H1]) + + return main + + def tilelang_sparse_fwd( q: torch.Tensor, kv: torch.Tensor, @@ -1052,46 +1319,47 @@ def tilelang_sparse_fwd( assert topk == 2048 if _is_hip: - # sparse_mla_fwd_decode_partial splits topk KV blocks into N_GROUPS - # independent tiles per query, then sparse_mla_fwd_decode_combine - # reduces them via online softmax. - - if _is_gfx95_supported: - # gfx950 - block_I, threads = 64, 256 - block_per_cu = 2 + is_fp8_kv = kv.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) + if is_fp8_kv: + if q.dtype != kv.dtype: + q = q.to(kv.dtype) + if _is_gfx95_supported: + block_I, threads, block_per_cu, cu = 64, 256, 2, 256 + else: + block_I, threads, block_per_cu, cu = 64, 256, 1, 304 + ni = topk // block_I + inner_iter = _pick_inner_iter(q.shape[0], ni, cu, block_per_cu) + kernel_partial = sparse_mla_fwd_decode_partial_fp8( + num_heads, + d_v, + tail_dim, + topk, + sm_scale=sm_scale, + block_I=block_I, + inner_iter=inner_iter, + threads=threads, + ) else: - # gfx942 - block_I, threads = 32, 128 - block_per_cu = 1 - - NI = topk // block_I - CU = 304 - - def _inner_iter(seq: int) -> int: - """Largest inner_iter ≤ NI that keeps grid/CU ≄ block_per_cu.""" - max_it = int(seq * NI / (CU * block_per_cu)) - it = NI - while it >= 2: - if it <= max_it and NI % it == 0: - return it - it //= 2 - return 1 - - inner_iter = _inner_iter(q.shape[0]) - n_groups = NI // inner_iter - - kernel_partial = sparse_mla_fwd_decode_partial( - num_heads, - d_v, - tail_dim, - topk, - sm_scale=sm_scale, - block_I=block_I, - inner_iter=inner_iter, - num_stages=1, - threads=threads, + if 
_is_gfx95_supported: + block_I, threads, block_per_cu, cu = 64, 256, 2, 256 + else: + block_I, threads, block_per_cu, cu = 32, 128, 1, 304 + ni = topk // block_I + inner_iter = _pick_inner_iter(q.shape[0], ni, cu, block_per_cu) + kernel_partial = sparse_mla_fwd_decode_partial( + num_heads, + d_v, + tail_dim, + topk, + sm_scale=sm_scale, + block_I=block_I, + inner_iter=inner_iter, + threads=threads, + ) + partial_o_batched, partial_lse_batched = kernel_partial( + q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0) ) + n_groups = ni // inner_iter kernel_combine = sparse_mla_fwd_decode_combine( num_heads, d_v, @@ -1100,10 +1368,7 @@ def _inner_iter(seq: int) -> int: block_I=block_I, threads=threads, ) - partial_o, partial_lse = kernel_partial( - q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0) - ) - out = kernel_combine(partial_o, partial_lse) + out = kernel_combine(partial_o_batched, partial_lse_batched) else: kernel = sparse_attention_fwd_kernel_v2( num_heads, d_v, tail_dim, topk, sm_scale=sm_scale diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index dea5c5348d18..862488e5f918 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -1297,6 +1297,7 @@ def forward_extend( cos_sin_cache, is_neox, llama_4_scaling, + is_prefill=True, ) if k is not None: @@ -1790,6 +1791,7 @@ def _forward_standard_mha( enable_pdl=False, is_causal=causal, return_lse=False, + skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get(), ) # Use FA3 for SM90 (Hopper/H200) @@ -1928,6 +1930,7 @@ def _forward_trtllm( cos_sin_cache: Optional[torch.Tensor] = None, is_neox: Optional[bool] = False, llama_4_scaling: Optional[torch.Tensor] = None, + is_prefill: bool = False, ) -> torch.Tensor: """Forward using TRT-LLM sparse MLA kernel.""" import flashinfer.decode @@ -1989,6 +1992,13 @@ def _forward_trtllm( if 
envs.SGLANG_NSA_FUSE_TOPK.get(): page_table_1 = topk_indices + elif is_prefill: + page_table_1 = transform_index_page_table_prefill( + page_table=metadata.page_table_1, + topk_indices=topk_indices, + extend_lens_cpu=metadata.nsa_extend_seq_lens_list, + page_size=1, + ) else: page_table_1 = transform_index_page_table_decode( page_table=metadata.page_table_1, @@ -2025,6 +2035,7 @@ def _forward_trtllm( sparse_mla_top_k=self.nsa_index_topk, bmm1_scale=bmm1_scale, backend="trtllm-gen", + skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(), ) # Output: [batch, q_len=1, heads, v_dim] -> [batch, heads, v_dim] return out.squeeze(1) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index 09f3f409a1fe..205b5fadf8ce 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -773,6 +773,7 @@ def forward_decode( bmm2_scale=bmm2_scale, window_left=layer.sliding_window_size, sinks=attention_sink, + skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(), out_dtype=self.q_data_type, # model_runner.dtype ) @@ -855,6 +856,7 @@ def forward_extend( bmm2_scale=bmm2_scale, window_left=layer.sliding_window_size, sinks=attention_sink, + skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(), out_dtype=self.q_data_type, # model_runner.dtype q_len_per_req=self.forward_metadata.max_seq_len_q, ) @@ -874,6 +876,7 @@ def forward_extend( cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k, window_left=layer.sliding_window_size, sinks=attention_sink, + skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get(), out_dtype=self.q_data_type, # model_runner.dtype ) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py 
b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index b74124da25ea..54eb273ed33d 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -14,6 +14,7 @@ import triton.language as tl from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph +from sglang.srt.environ import envs from sglang.srt.layers.attention.flashinfer_mla_backend import ( FlashInferMLAAttnBackend, FlashInferMLAMultiStepDraftBackend, @@ -875,6 +876,7 @@ def forward_decode( seq_lens=forward_batch.seq_lens.to(torch.int32), max_seq_len=metadata.max_seq_len_k, bmm1_scale=bmm1_scale, + skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(), ) # Reshape output directly without slicing @@ -1062,6 +1064,7 @@ def forward_extend( seq_lens=metadata.seq_lens_k, max_seq_len=max_seq_len, bmm1_scale=bmm1_scale, + skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(), ) if needs_unpad: @@ -1099,6 +1102,7 @@ def forward_extend( "bmm1_scale": q_scale * k_scale * layer.scaling, "bmm2_scale": v_scale, "cum_seq_lens_q": self.forward_prefill_metadata.cum_seq_lens, + "skip_softmax_threshold_scale_factor": envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get(), } # When chunked prefix cache is enabled, dispatch to different path for ragged attention. 
diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index 7cd278c82eab..e0774c9a407b 100644 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -644,7 +644,7 @@ def launch_reshape_and_cache_flash( key_cache, value_cache, slot_mapping, - swa_slot_mapping if swa_slot_mapping is not None else key, + swa_slot_mapping, k_scale if k_scale is not None else key, v_scale if v_scale is not None else key, key_cache.stride(0), @@ -658,3 +658,736 @@ def launch_reshape_and_cache_flash( HAS_SWA=(swa_slot_mapping is not None), USE_SCALE=(k_scale is not None), ) + + +@triton.jit +def _get_gptj_rotated_x( + x, + x_rotated_mask, + BLOCK_D: tl.constexpr, + BLOCK_D_HALF: tl.constexpr, +): + # GPT-J rotary layout: + # Pair adjacent dimensions and apply: + # [x0, x1, x2, x3] -> [-x1, x0, -x3, x2] + + # Apply sign inversion on odd positions. + x_rotated = tl.where(x_rotated_mask, x, -x) + # Reshape into (D/2, 2) pairs. + x_rotated = tl.reshape(x_rotated, (BLOCK_D_HALF, 2)) + # Swap each pair. + x_rotated = tl.flip(x_rotated, 1) + # Flatten back to original shape. + x_rotated = tl.reshape(x_rotated, (BLOCK_D,)) + return x_rotated + + +@triton.jit +def _get_neox_rotated_x( + x, + x_rotated_mask, + BLOCK_D: tl.constexpr, + BLOCK_D_HALF: tl.constexpr, +): + # GPT-NeoX rotary layout: + # Split head dimension into two halves: + # [x0, x1, x2, x3] -> [-x2, -x3, x0, x1] + + # Keep first half positive, second half negative. + x_rotated = tl.where(x_rotated_mask, x, -x) + # Reshape into (2, D/2). + x_rotated = tl.reshape(x_rotated, (2, BLOCK_D_HALF)) + # Reverse each half. + x_rotated = tl.flip(x_rotated, 1) + # Flatten and reverse full vector. 
+ x_rotated = tl.reshape(x_rotated, (BLOCK_D,)) + x_rotated = tl.flip(x_rotated, 0) + return x_rotated + + +@triton.jit +def _unit_rope( + x_ptrs, + cos, + sin, + d_pe_offs, + IS_NEOX: tl.constexpr, + BLOCK_D_pe: tl.constexpr, + BLOCK_D_HALF_pe: tl.constexpr, +): + # Load one full attention head vector. + x_pe = tl.load(x_ptrs) + + # Stage 1: Build rotated vector according to rotary layout. + if IS_NEOX: + x_rotated_mask = d_pe_offs < BLOCK_D_HALF_pe + x_pe_rotated = _get_neox_rotated_x( + x_pe, x_rotated_mask, BLOCK_D_pe, BLOCK_D_HALF_pe + ) + else: + x_rotated_mask = d_pe_offs % 2 == 0 + x_pe_rotated = _get_gptj_rotated_x( + x_pe, x_rotated_mask, BLOCK_D_pe, BLOCK_D_HALF_pe + ) + + # Stage 2: Apply RoPE transform: + # x' = x*cos + rotate(x)*sin + x_pe = x_pe * cos + x_pe_rotated * sin + + return x_pe + + +@triton.jit +def _load_cos_sin( + cos_sin_ptr, + pos, + d_cos_offs, + stride_t, + stride_d, + freq_dim, +): + base = pos * stride_t + cos = tl.load(cos_sin_ptr + base + d_cos_offs * stride_d) + sin = tl.load(cos_sin_ptr + base + (d_cos_offs + freq_dim) * stride_d) + return cos, sin + + +@triton.jit +def _fused_qk_rope_reshape_and_cache_kernel( + q_ptr, + k_ptr, + v_ptr, + pos_ptr, + cos_sin_ptr, + offs_ptr, + key_cache_ptr, + value_cache_ptr, + slot_mapping_ptr, + swa_slot_mapping_ptr, + q_out_ptr, + k_out_ptr, + zeros_out_ptr, + T, + T_slot, + q_stride_t, + q_stride_h, + q_stride_d, + k_stride_t, + k_stride_h, + k_stride_d, + v_stride_t, + v_stride_h, + v_stride_d, + cos_sin_stride_t, + cos_sin_stride_d, + q_out_stride_t, + q_out_stride_h, + q_out_stride_d, + k_out_stride_t, + k_out_stride_h, + k_out_stride_d, + key_cache_stride_t, + key_cache_stride_h, + key_cache_stride_d, + key_cache_stride_b, + key_cache_stride_x, + value_cache_stride_t, + value_cache_stride_h, + value_cache_stride_d, + value_cache_stride_b, + value_cache_stride_slot_chunk, + value_cache_stride_x, + zeros_out_stride_t, + zeros_out_stride_h, + zeros_out_stride_d, + k_scale_ptr, + 
v_scale_ptr, + QH_PER_KH: tl.constexpr, + QH: tl.constexpr, + KH: tl.constexpr, + REUSE_FREQS_FRONT_PART: tl.constexpr, + IS_NEOX: tl.constexpr, + BLOCK_D_pe: tl.constexpr, + BLOCK_D_HALF_pe: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + X_SIZE: tl.constexpr, + FLASH_LAYOUT: tl.constexpr, + VALUE_SHUFFLE_LAYOUT: tl.constexpr = False, + HAVE_POS: tl.constexpr = False, + HAVE_K_SCALE: tl.constexpr = False, + HAVE_V_SCALE: tl.constexpr = False, + HAVE_ZEROS: tl.constexpr = False, + HAS_SWA: tl.constexpr = False, +): + # ============================================================ + # Stage 0: Static stride assumptions for Triton compiler + # + # These assumptions help Triton optimize pointer arithmetic and + # simplify generated address calculations. + # ============================================================ + + tl.assume(q_stride_t >= 0) + tl.assume(q_stride_h >= 0) + tl.assume(q_stride_d >= 0) + tl.assume(k_stride_t >= 0) + tl.assume(k_stride_h >= 0) + tl.assume(k_stride_d >= 0) + tl.assume(v_stride_t >= 0) + tl.assume(v_stride_h >= 0) + tl.assume(v_stride_d >= 0) + tl.assume(cos_sin_stride_t >= 0) + tl.assume(cos_sin_stride_d >= 0) + tl.assume(q_out_stride_t >= 0) + tl.assume(q_out_stride_h >= 0) + tl.assume(q_out_stride_d >= 0) + tl.assume(k_out_stride_t >= 0) + tl.assume(k_out_stride_h >= 0) + tl.assume(k_out_stride_d >= 0) + tl.assume(key_cache_stride_t >= 0) + tl.assume(key_cache_stride_h >= 0) + tl.assume(key_cache_stride_d >= 0) + tl.assume(key_cache_stride_b >= 0) + tl.assume(key_cache_stride_x >= 0) + tl.assume(value_cache_stride_t >= 0) + tl.assume(value_cache_stride_h >= 0) + tl.assume(value_cache_stride_d >= 0) + tl.assume(value_cache_stride_b >= 0) + tl.assume(value_cache_stride_slot_chunk >= 0) + tl.assume(value_cache_stride_x >= 0) + tl.assume(zeros_out_stride_t >= 0) + tl.assume(zeros_out_stride_h >= 0) + tl.assume(zeros_out_stride_d >= 0) + + # ============================================================ + # Stage 1: Program instance mapping + # 
+ # Each program handles: + # - one (token, q_head) for Q path + # - selected KV ownership for cache write path + # + # pid layout: + # [0, T*QH) -> decode Q path + # [T*QH, extra KV) -> KV-only path + # ============================================================ + + pid = tl.program_id(0) + tl.assume(pid >= 0) + + d_pe_offs = tl.arange(0, BLOCK_D_pe).to(tl.int64) + + # ============================================================ + # Stage 2: Main decode path (Q always active) + # ============================================================ + + if pid < T * QH: + pid_t = pid // QH + pid_hq = pid % QH + + # -------------------------------------------------------- + # Stage 2.1: Compute rotary frequency offsets + # + # RoPE frequencies may be stored as: + # D/2 frequencies (shared front-half) + # D frequencies (full explicit) + # -------------------------------------------------------- + + if REUSE_FREQS_FRONT_PART: + if IS_NEOX: + d_cos_offs = d_pe_offs + d_cos_offs = tl.where( + (d_cos_offs >= BLOCK_D_HALF_pe) & (d_cos_offs < BLOCK_D_pe), + d_cos_offs - BLOCK_D_HALF_pe, + d_cos_offs, + ).to(d_cos_offs.dtype) + # d_cos_mask = d_cos_offs < BLOCK_D_pe + else: + d_cos_offs = d_pe_offs // 2 + # d_cos_mask = d_cos_offs < BLOCK_D_HALF_pe + else: + d_cos_offs = d_pe_offs + # d_cos_mask = d_cos_offs < BLOCK_D_pe + + # -------------------------------------------------------- + # Stage 2.2: Load token position and optional offset + # + # offs_ptr is used by chunked prefill / sliding-window decode. 
+ # -------------------------------------------------------- + pos = tl.load(pos_ptr + pid_t) + if HAVE_POS: + offset = tl.load(offs_ptr + pid_t) + pos = pos + offset + + # -------------------------------------------------------- + # Stage 2.3: Load cosine / sine table + # -------------------------------------------------------- + # cos_offs = pos * cos_stride_t + d_cos_offs * cos_stride_d + # cos = tl.load(cos_ptr + cos_offs) + # sin = tl.load(sin_ptr + cos_offs) + + freq_dim = BLOCK_D_HALF_pe if REUSE_FREQS_FRONT_PART else BLOCK_D_pe + + cos, sin = _load_cos_sin( + cos_sin_ptr, + pos, + d_cos_offs, + cos_sin_stride_t, + cos_sin_stride_d, + freq_dim, + ) + + # -------------------------------------------------------- + # Stage 2.4: Apply RoPE to Q + # -------------------------------------------------------- + q_ptrs = ( + q_ptr + pid_t * q_stride_t + pid_hq * q_stride_h + d_pe_offs * q_stride_d + ) + q_pe = _unit_rope( + q_ptrs, + cos, + sin, + d_pe_offs, + IS_NEOX, + BLOCK_D_pe, + BLOCK_D_HALF_pe, + ) + + # Store rotated Q output. + q_out_ptrs = ( + q_out_ptr + + pid_t * q_out_stride_t + + pid_hq * q_out_stride_h + + d_pe_offs * q_out_stride_d + ) + tl.store(q_out_ptrs, q_pe.to(q_out_ptr.dtype.element_ty)) + + if HAVE_ZEROS: + z = tl.zeros((BLOCK_D_pe,), dtype=zeros_out_ptr.dtype.element_ty) + zeros_out_ptrs = ( + zeros_out_ptr + + pid_t * zeros_out_stride_t + + pid_hq * zeros_out_stride_h + + d_pe_offs * zeros_out_stride_d + ) + tl.store(zeros_out_ptrs, z) + + # ======================================================== + # Stage 3: KV ownership path + # + # Only one Q group leader writes KV: + # pid_hq % QH_PER_KH == 0 + # + # This prevents duplicated KV cache writes. 
+ # ======================================================== + + if pid_hq % QH_PER_KH == 0: + # ---------------------------------------------------- + # Stage 3.1: Resolve cache slot + # ---------------------------------------------------- + pid_slot = tl.load(slot_mapping_ptr + pid_t).to(tl.int64) + if HAS_SWA: + pid_slot = tl.load(swa_slot_mapping_ptr + pid_slot) + + # ------------------------------------------------ + # Stage 3.2: Apply RoPE to K + # ------------------------------------------------ + if pid_slot >= 0: + pid_t_slot = pid_slot // BLOCK_SIZE + pid_b = pid_slot % BLOCK_SIZE + pid_hk = pid_hq // QH_PER_KH + if HAVE_K_SCALE: + k_scale = tl.load(k_scale_ptr) + else: + k_scale = 1 + k_ptrs = ( + k_ptr + + pid_t * k_stride_t + + pid_hk * k_stride_h + + d_pe_offs * k_stride_d + ) + k_pe = _unit_rope( + k_ptrs, + cos, + sin, + d_pe_offs, + IS_NEOX, + BLOCK_D_pe, + BLOCK_D_HALF_pe, + ) + + k_out_ptrs = ( + k_out_ptr + + pid_t * k_out_stride_t + + pid_hk * k_out_stride_h + + d_pe_offs * k_out_stride_d + ) + tl.store(k_out_ptrs, k_pe.to(k_out_ptr.dtype.element_ty)) + + # ------------------------------------------------ + # Stage 3.3: Optional fp8 scaling before cache + # ------------------------------------------------ + + k_scale_rcprl = 1 / k_scale + k_pe = k_pe * k_scale_rcprl + + # ------------------------------------------------ + # Stage 3.4: Write K cache + # + # Two layouts supported: + # FLASH_LAYOUT + # paged KV layout + # ------------------------------------------------ + + if FLASH_LAYOUT: + k_out_ptrs = ( + key_cache_ptr + + pid_t_slot * key_cache_stride_t + + pid_b * key_cache_stride_b + + pid_hk * key_cache_stride_h + + d_pe_offs * key_cache_stride_d + ) + else: + k_pe = tl.reshape(k_pe, (BLOCK_D_pe // X_SIZE, X_SIZE)) + dx_offs = tl.arange(0, BLOCK_D_pe // X_SIZE).to(tl.int64) + x_offs = tl.arange(0, X_SIZE).to(tl.int64) + k_out_ptrs = ( + key_cache_ptr + + pid_t_slot * key_cache_stride_t + + pid_hk * key_cache_stride_h + + dx_offs[:, None] * 
key_cache_stride_d + + pid_b * key_cache_stride_b + + x_offs[None, :] * key_cache_stride_x + ) + + tl.store(k_out_ptrs, k_pe.to(key_cache_ptr.dtype.element_ty)) + + # ------------------------------------------------ + # Stage 3.5: Write V cache + # + # Supports: + # normal layout + # shuffle layout + # ------------------------------------------------ + + v_ptrs = ( + v_ptr + + pid_t * v_stride_t + + pid_hk * v_stride_h + + d_pe_offs * v_stride_d + ) + if HAVE_V_SCALE: + v_scale = tl.load(v_scale_ptr) + else: + v_scale = 1 + v_scale_rcprl = 1 / v_scale + v = tl.load(v_ptrs) * v_scale_rcprl + if VALUE_SHUFFLE_LAYOUT: + slot_chunk = pid_b // X_SIZE + x_off = pid_b % X_SIZE + v_out_ptrs = ( + value_cache_ptr + + pid_t_slot * value_cache_stride_t + + pid_hk * value_cache_stride_h + + slot_chunk * value_cache_stride_slot_chunk + + d_pe_offs.to(tl.int64) * value_cache_stride_d + + x_off * value_cache_stride_x + ) + else: + v_out_ptrs = ( + value_cache_ptr + + pid_t_slot * value_cache_stride_t + + pid_hk * value_cache_stride_h + + d_pe_offs.to(tl.int64) * value_cache_stride_d + + pid_b * value_cache_stride_b + ) + tl.store(v_out_ptrs, v.to(value_cache_ptr.dtype.element_ty)) + # ============================================================ + # Stage 4: Extra KV-only path + # + # Handles tokens that only require cache update: + # T_slot > T + # + # No Q / no RoPE on Q branch. 
+ # ============================================================ + else: + pid = pid - T * QH + T * KH + if pid < T_slot * KH: + pid_t = pid // KH + pid_hk = pid % KH + pid_slot = tl.load(slot_mapping_ptr + pid_t).to(tl.int64) + if HAS_SWA: + pid_slot = tl.load(swa_slot_mapping_ptr + pid_slot) + + if pid_slot >= 0: + pid_t_slot = pid_slot // BLOCK_SIZE + pid_b = pid_slot % BLOCK_SIZE + if HAVE_K_SCALE: + k_scale = tl.load(k_scale_ptr) + else: + k_scale = 1 + k_ptrs = ( + k_ptr + + pid_t * k_stride_t + + pid_hk * k_stride_h + + d_pe_offs * k_stride_d + ) + + k_pe = tl.load(k_ptrs) + + k_out_ptrs = ( + k_out_ptr + + pid_t * k_out_stride_t + + pid_hk * k_out_stride_h + + d_pe_offs * k_out_stride_d + ) + tl.store(k_out_ptrs, k_pe.to(k_out_ptr.dtype.element_ty)) + + k_scale_rcprl = 1 / k_scale + k_pe = k_pe * k_scale_rcprl + + if FLASH_LAYOUT: + k_out_ptrs = ( + key_cache_ptr + + pid_t_slot * key_cache_stride_t + + d_pe_offs * key_cache_stride_d + + pid_b * key_cache_stride_b + + pid_hk * key_cache_stride_h + ) + else: + k_pe = tl.reshape(k_pe, (BLOCK_D_pe // X_SIZE, X_SIZE)) + dx_offs = tl.arange(0, BLOCK_D_pe // X_SIZE).to(tl.int64) + x_offs = tl.arange(0, X_SIZE).to(tl.int64) + k_out_ptrs = ( + key_cache_ptr + + pid_t_slot * key_cache_stride_t + + pid_hk * key_cache_stride_h + + dx_offs[:, None] * key_cache_stride_d + + pid_b * key_cache_stride_b + + x_offs[None, :] * key_cache_stride_x + ) + tl.store(k_out_ptrs, k_pe.to(key_cache_ptr.dtype.element_ty)) + + v_ptrs = ( + v_ptr + + pid_t * v_stride_t + + pid_hk * v_stride_h + + d_pe_offs * v_stride_d + ) + if HAVE_V_SCALE: + v_scale = tl.load(v_scale_ptr) + else: + v_scale = 1 + v_scale_rcprl = 1 / v_scale + v = tl.load(v_ptrs) * v_scale_rcprl + if VALUE_SHUFFLE_LAYOUT: + slot_chunk = pid_b // X_SIZE + x_off = pid_b % X_SIZE + v_out_ptrs = ( + value_cache_ptr + + pid_t_slot * value_cache_stride_t + + pid_hk * value_cache_stride_h + + slot_chunk * value_cache_stride_slot_chunk + + d_pe_offs * value_cache_stride_d + + 
x_off * value_cache_stride_x + ) + else: + v_out_ptrs = ( + value_cache_ptr + + pid_t_slot * value_cache_stride_t + + pid_hk * value_cache_stride_h + + d_pe_offs * value_cache_stride_d + + pid_b * value_cache_stride_b + ) + tl.store(v_out_ptrs, v.to(value_cache_ptr.dtype.element_ty)) + + +def fused_qk_rope_reshape_and_cache( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + pos: torch.Tensor, + cos_sin: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + is_neox: bool, + flash_layout: bool, + apply_scale: bool = True, + offs: torch.Tensor = None, + q_out: torch.Tensor = None, + k_out: torch.Tensor = None, + output_zeros: bool = True, + zeros_out: torch.Tensor = None, + swa_slot_mapping=None, +): + """ + Perform RoPE on q and k and along the last dimension and copy k and v in to key_cache and value_cache inplace + + Key parameters: + - q: shape (T, QH, D). + - k: shape (T_slot, KH, D). + - v: shape (T_slot, KH, D). + - if flash_layout: + - key_cache: shape (T_cache, block_size, KH, D). + - value_cache: shape (T_cache, block_size, KH, D). + - else: + - key_cache: shape (T_cache, KH, D // x, block_size, x). + - value_cache: shape (T_cache, KH, D, block_size). + - slot_mapping: shape (T_slot, ). + + T is the number of decode tokens, T_cahce * block_size is the max number of tokens of kv_cache + QH must be multiple of KH + + Returns: + - q_out: same shape as input q. + - k_out: same shape as input k. + - key_cache: same shape as input key_cache (inplace). + - value_cache: same shape as input value_cache (inplace). + - zeros_out: same shape as input q. 
+ """ + + t, qh, d = q.shape + tk, kh, dk = k.shape + tv, vh, dv = v.shape + if flash_layout: + t_cache, block_size, kh_cache, dk_cache = key_cache.shape + t_cache_v, block_size_v, vh_cache, dv_cache = value_cache.shape + value_shuffle_layout = False + else: + t_cache, kh_cache, dkx_cache, block_size, x_cache = key_cache.shape + if value_cache.ndim == 5: + # value_cache shuffle: (num_blocks, num_kv_heads, block_size // x, head_size, x) + t_cache_v, vh_cache, slot_chunk_v, dv_cache, x_v = value_cache.shape + value_shuffle_layout = True + block_size_v = slot_chunk_v * x_v + assert block_size_v == block_size and x_v == x_cache, ( + f"value_cache shuffle (T,KH,block_size//x,D,x) must match key: " + f"{block_size_v=} {block_size=} {x_v=} {x_cache=}" + ) + else: + t_cache_v, vh_cache, dv_cache, block_size_v = value_cache.shape + value_shuffle_layout = False + (t_slot,) = slot_mapping.shape + + assert ( + t == tk == tv and t_slot <= tk + ), f"Number of tokens should be identical for q, kand v. 
The number of tokens of slot_mapping should no more than that of q, k and v, {t=} {tk=} {tv=} {t_slot=}" + assert ( + block_size == block_size_v + ), f"block size should be identical for key_cache, and value_cache {block_size} {block_size_v}" + assert ( + kh == vh == kh_cache == vh_cache + ), "KV head should be identical for k, v, key_cache, and value_cache" + assert ( + t_cache == t_cache_v + ), "Number of tokens should be identical for key_cache, and value_cache" + if flash_layout: + assert ( + d == dk == dv == dk_cache == dv_cache + ), "D dimension should be identical for q, k, and v" + else: + assert ( + d == dk == dv == dkx_cache * x_cache == dv_cache + ), "D dimension should be identical for q, k, and v" + assert x_cache == triton.next_power_of_2(x_cache), "x_size should be power of 2" + + assert d == triton.next_power_of_2(d), "D dimension should be power of 2" + assert block_size == triton.next_power_of_2( + block_size + ), "block_size should be power of 2" + assert qh % kh == 0, "Q heads must be multiple of H heads" + d_freq = cos_sin.shape[-1] // 2 + assert (d_freq == d // 2) or ( + d_freq == d + ), "cos/sin last dim should be the same or half of the qk last dim" + reuse_freqs_front_part = d_freq == d // 2 + + if q_out is None: + q_out = torch.empty((t, qh, d), dtype=q.dtype, device=q.device) + + if k_out is None: + k_out = torch.empty((tk, kh, dk), dtype=k.dtype, device=q.device) + + if zeros_out is not None: + tz, qhz, dz = zeros_out.shape + assert ( + t == tz and qh == qhz and d == dz + ), f"q and zeros shape mismatch {q.shape=} {zeros_out.shape=}" + output_zeros = True + elif output_zeros: + zeros_out = torch.empty((t, qh, d), dtype=q.dtype, device=q.device) + else: + zeros_out = None + + n_pid = t * qh + (t_slot - t) * kh if t_slot >= t else t * qh + grid = (n_pid, 1, 1) + _fused_qk_rope_reshape_and_cache_kernel[grid]( + q, + k, + v, + pos, + cos_sin, + offs, + key_cache, + value_cache, + slot_mapping, + swa_slot_mapping, + q_out, + k_out, + 
zeros_out, + t, + t_slot, + *q.stride(), + *k.stride(), + *v.stride(), + cos_sin.stride(0), + cos_sin.stride(-1), + *q_out.stride(), + *k_out.stride(), + key_cache.stride(0) if not flash_layout else key_cache.stride(0), + key_cache.stride(1) if not flash_layout else key_cache.stride(2), + key_cache.stride(2) if not flash_layout else key_cache.stride(3), + key_cache.stride(3) if not flash_layout else key_cache.stride(1), + key_cache.stride(4) if not flash_layout else 0, + value_cache.stride(0) if not flash_layout else value_cache.stride(0), + value_cache.stride(1) if not flash_layout else value_cache.stride(2), + ( + value_cache.stride(3) + if (not flash_layout and value_shuffle_layout) + else (value_cache.stride(2) if not flash_layout else value_cache.stride(3)) + ), + ( + 0 + if (not flash_layout and value_shuffle_layout) + else (value_cache.stride(3) if not flash_layout else value_cache.stride(1)) + ), + value_cache.stride(2) if (not flash_layout and value_shuffle_layout) else 0, + value_cache.stride(4) if (not flash_layout and value_shuffle_layout) else 0, + zeros_out.stride(0) if zeros_out is not None else 0, + zeros_out.stride(1) if zeros_out is not None else 0, + zeros_out.stride(2) if zeros_out is not None else 0, + k_scale_ptr=k_scale, + v_scale_ptr=v_scale, + QH_PER_KH=qh // kh, + QH=qh, + KH=kh, + REUSE_FREQS_FRONT_PART=reuse_freqs_front_part, + IS_NEOX=is_neox, + BLOCK_D_pe=d, + BLOCK_D_HALF_pe=d // 2, + BLOCK_SIZE=block_size, + X_SIZE=x_cache if not flash_layout else 0, + FLASH_LAYOUT=flash_layout, + VALUE_SHUFFLE_LAYOUT=value_shuffle_layout, + HAVE_POS=(offs is not None), + HAVE_K_SCALE=(k_scale is not None and apply_scale), + HAVE_V_SCALE=(v_scale is not None and apply_scale), + HAVE_ZEROS=output_zeros, + HAS_SWA=(swa_slot_mapping is not None), + num_warps=1, + ) + + if zeros_out is not None: + return q_out.view(-1, qh * d), k_out, key_cache, value_cache, zeros_out + return q_out.view(-1, qh * d), k_out, key_cache, value_cache diff --git 
a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 087e76baf934..4c3f9e2f2721 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -888,7 +888,9 @@ def _determine_attention_backend(self, passed_backend: Optional[str]) -> str: Priority: server args override > constructor arg > platform default. Platform defaults: - - CUDA: "triton_attn" + - CUDA (Hopper SM90): "fa3" + - CUDA (Blackwell SM100): "fa4" + - CUDA (other): "triton_attn" - Non-CUDA: "sdpa" """ override_backend = get_global_server_args().mm_attention_backend @@ -900,6 +902,8 @@ def _determine_attention_backend(self, passed_backend: Optional[str]) -> str: major, minor = get_device_capability() if major == 9: backend = "fa3" + elif major == 10: + backend = "fa4" else: backend = "triton_attn" elif _is_hip: diff --git a/python/sglang/srt/layers/fused_sampling.py b/python/sglang/srt/layers/fused_sampling.py new file mode 100644 index 000000000000..55852c2f0f34 --- /dev/null +++ b/python/sglang/srt/layers/fused_sampling.py @@ -0,0 +1,371 @@ +"""Fused Triton kernels for the sampling pipeline. + +Fuses temperature scaling + softmax into a single kernel to reduce +kernel launch overhead and global memory traffic during decode. + +Two kernel variants: + - Single-pass: vocab fits in one tile (1 read + 1 write). Used when + next_power_of_2(vocab) <= 32768. + - Multi-pass: 2-pass online softmax with autotune (2 reads + 1 write). + Used for large vocabs (e.g. 128K+). +""" + +import logging + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + +_MAX_SINGLE_PASS_BLOCK = 32768 + +# --------------------------------------------------------------------------- +# Single-pass kernel: entire vocab fits in one BLOCK_SIZE tile. +# Data stays in registers — only 1 global memory read + 1 write. 
+# --------------------------------------------------------------------------- + + +@triton.jit +def _single_pass_temperature_softmax_kernel( + logits_ptr, + temperatures_ptr, + output_ptr, + vocab_size, + logits_stride, + output_stride, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + temp = tl.load(temperatures_ptr + row_idx) + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + + x = tl.load( + logits_ptr + row_idx * logits_stride + offsets, + mask=mask, + other=float("-inf"), + ) + x = (x / temp).to(tl.float32) + + x_max = tl.max(x, axis=0) + exp_x = tl.exp(x - x_max) + prob = exp_x / tl.sum(exp_x, axis=0) + + tl.store(output_ptr + row_idx * output_stride + offsets, prob, mask=mask) + + +@triton.jit +def _single_pass_temperature_softmax_inplace_kernel( + logits_ptr, + temperatures_ptr, + vocab_size, + stride, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + temp = tl.load(temperatures_ptr + row_idx) + + row_start = logits_ptr + row_idx * stride + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + + x = tl.load(row_start + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + + x_max = tl.max(x, axis=0) + exp_x = tl.exp(x - x_max) + prob = exp_x / tl.sum(exp_x, axis=0) + + tl.store(row_start + offsets, prob, mask=mask) + + +# --------------------------------------------------------------------------- +# Multi-pass kernel: vocab too large for one tile. +# 2-pass online softmax with autotune over (BLOCK_SIZE, num_warps). 
+# --------------------------------------------------------------------------- + +_MULTI_PASS_AUTOTUNE_CONFIGS = [ + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=16), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=16), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=4), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=16), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=32), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4), + triton.Config({"BLOCK_SIZE": 16384}, num_warps=16), + triton.Config({"BLOCK_SIZE": 16384}, num_warps=32), + triton.Config({"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4), + triton.Config({"BLOCK_SIZE": 32768}, num_warps=32), + triton.Config({"BLOCK_SIZE": 32768}, num_warps=32, num_stages=4), +] + + +@triton.autotune(configs=_MULTI_PASS_AUTOTUNE_CONFIGS, key=["vocab_size"]) +@triton.jit +def _multi_pass_temperature_softmax_kernel( + logits_ptr, + temperatures_ptr, + output_ptr, + vocab_size, + logits_stride, + output_stride, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + temp = tl.load(temperatures_ptr + row_idx) + + logits_row = logits_ptr + row_idx * logits_stride + output_row = output_ptr + row_idx * output_stride + + # Pass 1: find global max (matches PyTorch's first reduction pass) + global_max = tl.full([], value=float("-inf"), dtype=tl.float32) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(logits_row + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + global_max = tl.maximum(global_max, tl.max(x, axis=0)) + + # Pass 2: compute sum of exp(x - max) (matches PyTorch's second pass) + sum_exp = tl.full([], value=0.0, dtype=tl.float32) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = 
tl.load(logits_row + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + sum_exp += tl.sum(tl.exp(x - global_max), axis=0) + + # Pass 3: normalize (matches PyTorch's exp(x-max)/sum) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(logits_row + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + + prob = tl.exp(x - global_max) / sum_exp + tl.store(output_row + offsets, prob, mask=mask) + + +@triton.jit +def _multi_pass_temperature_softmax_inplace_kernel( + logits_ptr, + temperatures_ptr, + vocab_size, + stride, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + temp = tl.load(temperatures_ptr + row_idx) + + row_start = logits_ptr + row_idx * stride + + # Pass 1: find global max (matches PyTorch's first reduction pass) + global_max = tl.full([], value=float("-inf"), dtype=tl.float32) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(row_start + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + global_max = tl.maximum(global_max, tl.max(x, axis=0)) + + # Pass 2: compute sum of exp(x - max) (matches PyTorch's second pass) + sum_exp = tl.full([], value=0.0, dtype=tl.float32) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(row_start + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + sum_exp += tl.sum(tl.exp(x - global_max), axis=0) + + # Pass 3: normalize (matches PyTorch's exp(x-max)/sum) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(row_start + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + + prob = tl.exp(x - global_max) / sum_exp + tl.store(row_start + offsets, prob, mask=mask) + 
+ +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +_DEFAULT_MULTI_PASS_CONFIG = {"BLOCK_SIZE": 4096, "num_warps": 16} + +# Populated by warmup from the out-of-place kernel's autotune result. +_multi_pass_inplace_config: dict | None = None + + +def _single_pass_num_warps(block_size: int) -> int: + return max(4, min(32, block_size // 256)) + + +def _get_multi_pass_inplace_config() -> dict: + """Return the launch config for the multi-pass in-place kernel.""" + if _multi_pass_inplace_config is not None: + return _multi_pass_inplace_config + return _DEFAULT_MULTI_PASS_CONFIG + + +def _dispatch_kernel( + logits: torch.Tensor, + temperatures_flat: torch.Tensor, + vocab_size: int, + batch_size: int, + output: torch.Tensor = None, +) -> None: + """Dispatch to single-pass or multi-pass kernel. output=None means in-place.""" + grid = (batch_size,) + block_size = triton.next_power_of_2(vocab_size) + inplace = output is None + + if block_size <= _MAX_SINGLE_PASS_BLOCK: + if inplace: + _single_pass_temperature_softmax_inplace_kernel[grid]( + logits, + temperatures_flat, + vocab_size, + logits.stride(0), + BLOCK_SIZE=block_size, + num_warps=_single_pass_num_warps(block_size), + ) + else: + _single_pass_temperature_softmax_kernel[grid]( + logits, + temperatures_flat, + output, + vocab_size, + logits.stride(0), + output.stride(0), + BLOCK_SIZE=block_size, + num_warps=_single_pass_num_warps(block_size), + ) + else: + if inplace: + cfg = _get_multi_pass_inplace_config() + _multi_pass_temperature_softmax_inplace_kernel[grid]( + logits, + temperatures_flat, + vocab_size, + logits.stride(0), + **cfg, + ) + else: + _multi_pass_temperature_softmax_kernel[grid]( + logits, + temperatures_flat, + output, + vocab_size, + logits.stride(0), + output.stride(0), + ) + + +def fused_temperature_softmax( + logits: torch.Tensor, + temperatures: torch.Tensor, +) -> 
torch.Tensor: + """Fused temperature scaling + softmax. Returns float32 probabilities.""" + batch_size, vocab_size = logits.shape + if batch_size == 0: + return torch.empty(0, vocab_size, dtype=torch.float32, device=logits.device) + + if not logits.is_contiguous(): + logits = logits.contiguous() + + output = torch.empty( + batch_size, vocab_size, dtype=torch.float32, device=logits.device + ) + temperatures_flat = temperatures.contiguous().view(-1) + _dispatch_kernel(logits, temperatures_flat, vocab_size, batch_size, output) + return output + + +def fused_temperature_softmax_inplace( + logits: torch.Tensor, + temperatures: torch.Tensor, +) -> None: + """In-place fused temperature scaling + softmax. Overwrites logits with probabilities.""" + batch_size, vocab_size = logits.shape + if batch_size == 0: + return + + if not logits.is_contiguous(): + work = logits.contiguous() + fused_temperature_softmax_inplace(work, temperatures) + logits.copy_(work) + return + + temperatures_flat = temperatures.contiguous().view(-1) + _dispatch_kernel(logits, temperatures_flat, vocab_size, batch_size) + + +def warmup_fused_temperature_softmax( + vocab_size: int, + device: torch.device | int | None = None, + logits_dtype: torch.dtype = torch.float32, +) -> None: + """Pre-compile and autotune kernels at startup so first request has no latency spike. + + For multi-pass kernels the out-of-place variant is autotuned (safe — separate + input/output buffers), and its winning config is reused for the in-place + variant so that no autotune ever runs on a live logits buffer. + + ``logits_dtype`` should match ``next_token_logits`` at inference (usually + ``model_config.dtype``) so Triton specializes the same way as in production. 
+ """ + global _multi_pass_inplace_config + + if device is None: + device = torch.cuda.current_device() + + block_size = triton.next_power_of_2(vocab_size) + is_multi_pass = block_size > _MAX_SINGLE_PASS_BLOCK + label = "multi-pass autotune" if is_multi_pass else "single-pass JIT" + logger.info( + "Warming up fused_temperature_softmax (%s, vocab_size=%d, logits_dtype=%s) ...", + label, + vocab_size, + logits_dtype, + ) + + dummy_logits = torch.randn(1, vocab_size, dtype=logits_dtype, device=device) + dummy_temps = torch.ones(1, 1, dtype=torch.float32, device=device) + + # 1. Out-of-place kernel: autotune runs here (safe, separate buffers). + fused_temperature_softmax(dummy_logits, dummy_temps) + + # 2. Propagate best config to the in-place kernel (no autotune needed). + if is_multi_pass: + best = getattr(_multi_pass_temperature_softmax_kernel, "best_config", None) + if best is not None: + _multi_pass_inplace_config = { + "BLOCK_SIZE": best.kwargs["BLOCK_SIZE"], + "num_warps": best.num_warps, + } + if best.num_stages is not None: + _multi_pass_inplace_config["num_stages"] = best.num_stages + ns = _multi_pass_inplace_config.get("num_stages", "default") + logger.info( + "Multi-pass autotune result: BLOCK_SIZE=%d, num_warps=%d, num_stages=%s", + _multi_pass_inplace_config["BLOCK_SIZE"], + _multi_pass_inplace_config["num_warps"], + ns, + ) + else: + _multi_pass_inplace_config = None + logger.warning( + "Multi-pass fused softmax: autotune did not set best_config; " + "using default launch config for in-place kernel." + ) + + # 3. In-place kernel: JIT compile only (uses the config from step 2). 
+ fused_temperature_softmax_inplace(dummy_logits.clone(), dummy_temps) + torch.cuda.synchronize(device) + + logger.info("fused_temperature_softmax warmup done (vocab_size=%d).", vocab_size) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 924ef64fd33b..ff959ada9a65 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1257,7 +1257,7 @@ def weight_loader( output_dim, start_idx, shard_size ) - # Special case for for AQLM codebooks. + # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 shard_size = loaded_weight.shape[0] diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index ef2225eccd14..1359c8313905 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -795,24 +795,4 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): if get_moe_a2a_backend().is_ascend_fuseep(): return NpuFuseEPMoE - if get_moe_runner_backend().is_flashinfer_trtllm(): - # NEW: Direct FP4 detection (bypasses EP requirements) - # Check for FP4 quantization with TRTLLM flag, regardless of EP - # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod. 
- if quant_config is not None and quant_config.get_name() == "modelopt_fp4": - from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE - - return FlashInferFP4MoE - elif ( - quant_config is None - or quant_config.get_name() == "fp8" - or quant_config.get_name() == "mxfp8" - or quant_config.get_name() == "modelopt_fp8" - or quant_config.get_name() == "compressed_tensors" - ): - # FlashInferFusedMoE supports bf16, fp8, mxfp8 and compressed_tensors - return FusedMoE - - if get_moe_runner_backend().is_flashinfer_cutlass(): - return FusedMoE return FusedMoE diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py index 4410f07f327e..b1bb618ce5ce 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py @@ -23,6 +23,13 @@ def get_scalar_type(num_bits: int, has_zp: bool): return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 +def _get_fp4_scalar_type(): + from sglang.srt.layers.quantization.utils import get_scalar_types + + _, scalar_types = get_scalar_types() + return scalar_types.float4_e2m1f + + @register_custom_op(out_shape="hidden_states") def fused_marlin_moe( hidden_states: torch.Tensor, @@ -46,6 +53,8 @@ def fused_marlin_moe( is_k_full: bool = True, inplace: bool = False, routed_scaling_factor: Optional[float] = None, + w1_global_scale: Optional[torch.Tensor] = None, + w2_global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -76,6 +85,13 @@ def fused_marlin_moe( """ from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size + # Detect FP4 Marlin mode (when global scales are provided) + _is_fp4_marlin = w1_global_scale is not None + if _is_fp4_marlin: + assert ( + w2_global_scale is not None + ), "Both w1_global_scale and w2_global_scale must be 
provided for FP4 Marlin mode" + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[1] == w2.shape[2] // ( @@ -85,12 +101,14 @@ def fused_marlin_moe( assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] - assert ( - hidden_states.dtype == w1_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" - assert ( - hidden_states.dtype == w2_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" + # For FP4 Marlin, scales are in special float8_e4m3fn format (not input dtype) + if not _is_fp4_marlin: + assert ( + hidden_states.dtype == w1_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" + assert ( + hidden_states.dtype == w2_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" assert num_bits in [4, 8] M, K = hidden_states.shape @@ -121,8 +139,13 @@ def fused_marlin_moe( max_workspace_size, dtype=torch.int, device=device, requires_grad=False ) - scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) - scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + # FP4 Marlin uses float4_e2m1f scalar type (not uint4b8/uint8b128) + if _is_fp4_marlin: + scalar_type1 = _get_fp4_scalar_type() + scalar_type2 = _get_fp4_scalar_type() + else: + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -150,7 +173,7 @@ def fused_marlin_moe( w1, None, # 
b_bias_or_none w1_scale, - None, # global_scale_or_none + w1_global_scale, # None for INT4/INT8, tensor for FP4 Marlin w1_zeros, g_idx1, sort_indices1, @@ -184,7 +207,7 @@ def fused_marlin_moe( w2, None, # b_bias_or_none w2_scale, - None, # global_scale_or_none + w2_global_scale, # None for INT4/INT8, tensor for FP4 Marlin w2_zeros, g_idx2, sort_indices2, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 43ad055a2716..f3fc5a544f99 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -40,7 +40,6 @@ from sglang.srt.layers.moe.token_dispatcher.flashinfer import FlashinferDispatcher from sglang.srt.layers.moe.token_dispatcher.standard import ( StandardDispatcher, - StandardDispatchOutput, ) from sglang.srt.layers.moe.topk import ( BypassedTopKOutput, @@ -66,16 +65,11 @@ cpu_has_amx_support, get_bool_env_var, is_cpu, - is_flashinfer_available, is_hip, - next_power_of_2, round_up, ) from sglang.srt.utils.custom_op import register_custom_op -if is_flashinfer_available(): - from flashinfer import fp4_quantize - _is_hip = is_hip() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() @@ -1146,267 +1140,6 @@ def clear_overlap_args(self) -> None: self.meta_overlap_args = None -class FlashInferFusedMoE(FusedMoE): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - assert TopKOutputChecker.format_is_bypassed( - topk_output - ), "Only bypassed topk output is supported for flashinfer trtllm moe" - - if is_in_piecewise_cuda_graph(): - return flashinfer_bf16_moe_forward_piecewise_cuda_graph_impl( - hidden_states, - topk_output.router_logits, - topk_output.topk_config.top_k, - topk_output.topk_config.topk_group, - topk_output.topk_config.num_expert_group, - topk_output.topk_config.correction_bias, - 
topk_output.topk_config.renormalize, - self.layer_id, - ) - else: - return self.forward_impl(hidden_states, topk_output) - - def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - assert ( - self.moe_runner_config.activation == "silu" - ), "Only silu is supported for flashinfer trtllm moe" - assert self.quant_method is not None - assert ( - topk_output.topk_config.renormalize - ), "Renormalize is required for flashinfer trtllm moe" - assert ( - self.num_fused_shared_experts == 0 - ), "Fused shared experts are not supported for flashinfer trtllm moe" - assert ( - self.moe_runner_config.is_gated - ), "Only gated MoEs are supported for flashinfer trtllm moe" - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - correction_bias = topk_config.correction_bias - routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - - if isinstance(self.quant_method, UnquantizedFusedMoEMethod): - # lazy import - try: - from flashinfer.fused_moe import trtllm_bf16_moe - except ImportError as e: - raise ImportError( - "Can't import trtllm_bf16_moe from flashinfer. " - "Please check flashinfer version to use bf16 with flashinfer_trtllm backend." - ) from e - - # Allocate output inside symmetric memory context - with use_symmetric_memory( - get_tp_group(), disabled=not is_allocation_symmetric() - ): - # TODO: Now trtllm_bf16_moe doesn't support inplace output, - # we can move this out when it support that. 
- symm_output = torch.empty( - hidden_states.shape[0], - hidden_states.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - # Move kernel call outside context manager to avoid graph breaks - # during torch.compile for piecewise cuda graph - moe_result = trtllm_bf16_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hidden_states, - gemm1_weights=self.w13_weight, - gemm2_weights=self.w2_weight, - num_experts=self.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=self.intermediate_size_per_partition, - local_expert_offset=self.moe_ep_rank * self.num_local_experts, - local_num_experts=self.num_local_experts, - routing_method_type=self.routing_method_type, - tune_max_num_tokens=next_power_of_2(hidden_states.shape[0]), - ) - # Copy result to symmetric memory output - symm_output.copy_(moe_result) - final_hidden_states = symm_output - - else: - - final_hidden_states = self.quant_method.apply( - layer=self, - dispatch_output=StandardDispatchOutput( - hidden_states=hidden_states, - hidden_states_scale=None, - topk_output=topk_output, - ), - ).hidden_states - - # NOTE for symmetric memory tagging: - # We do not create the context in this function. - # Instead, we create the context and tagging inside each FusedMoEMethodBase - # This can allow fine-grained tagging. 
- - if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - - return final_hidden_states - - -class FlashInferFP4MoE(FusedMoE): - """FP4 TRTLLM MoE implementation using FlashInfer.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # --------------------------------------------------------------------- - # Helper: quantize hidden states to FP4 each forward pass - # --------------------------------------------------------------------- - def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor): - """ - Quantize hidden states using global scale factor from quantization method. - - Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading. - Only block scales are computed at runtime for efficiency. - - Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32) - """ - - # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8) - # Only the block scales are computed at runtime - hs_fp4_bytes, hs_sf_bytes = fp4_quantize( - hidden_states, - self.w13_input_scale_quant, - 16, # sf_vec_size - False, # use_ue8m0 - False, # is_sf_swizzled_layout - ) - - seq_len, hidden_size = hidden_states.shape - hs_fp4 = hs_fp4_bytes.reshape(seq_len, hidden_size // 2) - # TRT-LLM expects hidden state scales shaped as [seq_len, hidden_size // 16] - hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape( - seq_len, hidden_size // 16 - ) - - return hs_fp4, hs_sf - - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - assert TopKOutputChecker.format_is_bypassed( - topk_output - ), "Only bypassed topk output is supported for flashinfer fp4 moe" - - if is_in_piecewise_cuda_graph(): - return flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl( - hidden_states, - topk_output.router_logits, - topk_output.topk_config.top_k, - topk_output.topk_config.topk_group, - 
topk_output.topk_config.num_expert_group, - topk_output.topk_config.correction_bias, - self.layer_id, - ) - else: - return self.forward_impl(hidden_states, topk_output) - - def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - """Forward pass using FP4 TRTLLM kernel. - - Args: - hidden_states: Input tensor - topk_output: TopKOutput object with Bypassed format - """ - from flashinfer.fused_moe import trtllm_fp4_block_scale_moe - - assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) - - assert ( - self.moe_runner_config.is_gated - ), "Only gated MoEs are supported for flashinfer fp4 moe" - - assert TopKOutputChecker.format_is_bypassed(topk_output) - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - - hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states) - routing_method_type = self.routing_method_type - assert ( - routing_method_type is not None - ), "flashinfer trtllm moe nvfp4 backend has not been adapted for the current moe layer, you can set routing_method_type (See definition of RoutingMethodType please) for the moe layer explicitly for a quick adaptation." 
- - # DeepSeekV3 style routing requires float32 router logits, - # see this PR for details: https://github.com/flashinfer-ai/flashinfer/commit/d84e1d560da0a27961c19ca788d96c19cb9dcfb6 - if routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) - - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(hidden_states.dtype) - ) - - with use_symmetric_memory( - get_tp_group(), disabled=not is_allocation_symmetric() - ): - num_tokens = hs_fp4.shape[0] - hidden_size = ( - hs_fp4.shape[-1] * 2 - if hs_fp4.dtype == torch.uint8 - else hs_fp4.shape[-1] - ) - symm_output = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device - ) - result = trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hs_fp4, - hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape( - *hs_scale_linear.shape[:-1], -1 - ), - gemm1_weights=self.gemm1_weights_fp4_shuffled.data, - gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=self.gemm2_weights_fp4_shuffled.data, - gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), - gemm2_bias=None, - output1_scale_scalar=self.g1_scale_c.data, - output1_scale_gate_scalar=self.g1_alphas.data, - output2_scale_scalar=self.g2_alphas.data, - num_experts=self.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=self.intermediate_size_per_partition, - local_expert_offset=self.moe_ep_rank * self.num_local_experts, - local_num_experts=self.num_local_experts, - routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, - # Respect the routing method configured for this layer (e.g., Renormalize for Qwen3), - # instead 
of always assuming DeepSeekV3. - routing_method_type=( - self.routing_method_type - if self.routing_method_type is not None - else RoutingMethodType.Default - ), - do_finalize=True, - tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), - output=symm_output, - )[0] - - return result - - @register_custom_op(out_shape="hidden_states") def moe_forward_piecewise_cuda_graph_impl( hidden_states: torch.Tensor, @@ -1449,55 +1182,3 @@ def fused_moe_bypassed_piecewise_cuda_graph_impl( forward_context = get_forward_context() moe_layer = forward_context.moe_layers[layer_id] return moe_layer.forward_impl(hidden_states, topk_output) - - -@register_custom_op(out_shape="hidden_states") -def flashinfer_bf16_moe_forward_piecewise_cuda_graph_impl( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - topk_group: Optional[int], - num_expert_group: Optional[int], - correction_bias: Optional[torch.Tensor], - renormalize: bool, - layer_id: int, -) -> torch.Tensor: - topk_output = BypassedTopKOutput( - hidden_states=hidden_states, - router_logits=router_logits, - topk_config=TopKConfig( - top_k=top_k, - topk_group=topk_group, - num_expert_group=num_expert_group, - correction_bias=correction_bias, - renormalize=renormalize, - ), - ) - forward_context = get_forward_context() - moe_layer = forward_context.moe_layers[layer_id] - return moe_layer.forward_impl(hidden_states, topk_output) - - -@register_custom_op(out_shape="hidden_states") -def flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - topk_group: Optional[int], - num_expert_group: Optional[int], - correction_bias: Optional[torch.Tensor], - layer_id: int, -) -> torch.Tensor: - topk_output = BypassedTopKOutput( - hidden_states=hidden_states, - router_logits=router_logits, - topk_config=TopKConfig( - top_k=top_k, - topk_group=topk_group, - num_expert_group=num_expert_group, - correction_bias=correction_bias, - ), - ) - forward_context 
= get_forward_context() - moe_layer = forward_context.moe_layers[layer_id] - return moe_layer.forward_impl(hidden_states, topk_output) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 68decf875ba7..3ccdfd66fa48 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -31,7 +31,6 @@ from sglang.srt.utils.common import ( is_cuda_alike, is_flashinfer_available, - is_sm120_supported, next_power_of_2, ) @@ -41,7 +40,7 @@ StandardDispatchOutput, ) -if is_flashinfer_available() and is_sm120_supported(): +if is_flashinfer_available(): from flashinfer import fp4_quantize elif is_cuda_alike(): from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as fp4_quantize @@ -362,7 +361,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( symm_output = torch.empty( hidden_states.shape[0], hidden_states.shape[1], - dtype=torch.bfloat16, + dtype=hidden_states.dtype, device=hidden_states.device, ) @@ -442,9 +441,11 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( tune_max_num_tokens=next_power_of_2(a_q.shape[0]), fp8_quantization_type=int(fp8_quantization_type), ) + # TODO: Once https://github.com/flashinfer-ai/flashinfer/issues/2703 is fixed, pass output to moe kernel and remove this copy. symm_output.copy_(output) output = symm_output else: + assert TopKOutputChecker.format_is_bypassed(topk_output) assert quant_info.w13_input_scale is not None assert quant_info.output1_scales_scalar is not None assert quant_info.output1_scales_gate_scalar is not None @@ -564,7 +565,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( """FlashInfer TRTLLM FP4 MoE forward pass. This function handles the FP4 TRTLLM MoE path that was previously in - FlashInferFP4MoE.forward_impl and ModelOptNvFp4FusedMoEMethod.apply. + ModelOptNvFp4FusedMoEMethod.apply. 
""" from flashinfer.fused_moe import trtllm_fp4_block_scale_moe @@ -638,7 +639,6 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( local_expert_offset=quant_info.local_expert_offset, local_num_experts=quant_info.local_num_experts, routed_scaling_factor=runner_config.routed_scaling_factor, - tile_tokens_dim=None, routing_method_type=( routing_method_type if routing_method_type is not None diff --git a/python/sglang/srt/layers/moe/moe_runner/marlin.py b/python/sglang/srt/layers/moe/moe_runner/marlin.py index 45104dd27805..429b28697d23 100644 --- a/python/sglang/srt/layers/moe/moe_runner/marlin.py +++ b/python/sglang/srt/layers/moe/moe_runner/marlin.py @@ -69,8 +69,13 @@ class MarlinMoeQuantInfo(MoeQuantInfo): w13_qzeros: Optional[torch.Tensor] = None w2_qzeros: Optional[torch.Tensor] = None - # Optional + # FP4 Marlin specific (Optional) + w13_global_scale: Optional[torch.Tensor] = None + w2_global_scale: Optional[torch.Tensor] = None + + # EP support (Optional) expert_map: Optional[torch.Tensor] = None + global_num_experts: int = -1 @register_fused_func("none", "marlin") @@ -106,6 +111,7 @@ def fused_experts_none_to_marlin( gating_output=topk_output.router_logits, topk_weights=topk_output.topk_weights, topk_ids=topk_output.topk_ids, + global_num_experts=quant_info.global_num_experts, expert_map=quant_info.expert_map, g_idx1=quant_info.w13_g_idx, g_idx2=quant_info.w2_g_idx, @@ -118,6 +124,8 @@ def fused_experts_none_to_marlin( is_k_full=quant_info.is_k_full, inplace=runner_config.inplace, routed_scaling_factor=runner_config.routed_scaling_factor, + w1_global_scale=quant_info.w13_global_scale, + w2_global_scale=quant_info.w2_global_scale, ).to(hidden_states.dtype) return StandardCombineInput( diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index b77c19f83a69..ef78ca9f3e93 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ 
b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -30,7 +30,12 @@ get_moe_runner_backend, should_use_flashinfer_cutlass_moe_fp4_allgather, ) -from sglang.srt.utils.common import get_bool_env_var, is_hip, is_sm120_supported +from sglang.srt.utils.common import ( + get_bool_env_var, + get_device, + is_hip, + is_sm120_supported, +) _is_hip = is_hip() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip @@ -149,15 +154,16 @@ def dispatch( and TopKOutputChecker.format_is_standard(topk_output) ): if self.local_expert_mapping is None: + device = get_device() self.local_expert_mapping = torch.full( - (self.num_experts,), -1, dtype=torch.int32, device="cuda" + (self.num_experts,), -1, dtype=torch.int32, device=device ) self.local_expert_mapping[ self.moe_ep_rank * self.num_local_routed_experts : (self.moe_ep_rank + 1) * self.num_local_routed_experts ] = torch.arange( - 0, self.num_local_routed_experts, dtype=torch.int32, device="cuda" + 0, self.num_local_routed_experts, dtype=torch.int32, device=device ) if self.num_local_shared_experts > 0: diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 65da83c0ec44..0d5fa7ddbce4 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -26,6 +26,7 @@ class MoeA2ABackend(Enum): MORI = "mori" ASCEND_FUSEEP = "ascend_fuseep" FLASHINFER = "flashinfer" + CUSTOMIZED = "customized" @classmethod def _missing_(cls, value): @@ -57,6 +58,9 @@ def is_ascend_fuseep(self): def is_mori(self): return self == MoeA2ABackend.MORI + def is_customized(self): + return self == MoeA2ABackend.CUSTOMIZED + class MoeRunnerBackend(Enum): diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 1072dac7b3ca..872daa191b18 100644 --- 
a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -16,6 +16,10 @@ CompressedTensorsLinearScheme, ) from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.modelopt_quant import ( enable_flashinfer_fp4_gemm, fp4_gemm, @@ -34,7 +38,7 @@ def __init__(self): @classmethod def get_min_capability(cls) -> int: - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 def create_weights( self, @@ -47,6 +51,7 @@ def create_weights( ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes + layer.params_dtype = params_dtype layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition @@ -91,6 +96,20 @@ def create_weights( layer.register_parameter("input_global_scale", input_global_scale) def process_weights_after_loading(self, layer) -> None: + if should_use_fp4_marlin_fallback(): + # Marlin FP4 fallback: consolidate global scale then repack weights + global_scale = layer.weight_global_scale.max().to(torch.float32) + layer.weight_global_scale = Parameter(global_scale, requires_grad=False) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight_packed", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_global_scale", + ) + layer.use_marlin_fallback = True + return + + layer.use_marlin_fallback = False global_input_scale = layer.input_global_scale.max().to(torch.float32) layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) @@ -136,6 +155,18 @@ def apply_weights( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if 
layer.use_marlin_fallback: + return torch.ops.sglang.apply_fp4_marlin_linear( + input=x, + weight=layer.weight_packed, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_global_scale, + workspace=layer.marlin_workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + output_dtype = x.dtype w_n, _ = layer.weight_packed.shape output_shape = [x.shape[0], w_n] @@ -150,7 +181,10 @@ def apply_weights( w = layer.weight_packed w_blockscale = layer.weight_scale - if enable_flashinfer_fp4_gemm: + if ( + enable_flashinfer_fp4_gemm + and not get_fp4_gemm_runner_backend().is_cutlass() + ): w = layer.weight_packed.T w_blockscale = layer.weight_scale.T diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py index 5898a078dbba..7824a3bcce60 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py @@ -17,6 +17,10 @@ CompressedTensorsMoEScheme, ) from sglang.srt.layers.quantization.fp8_utils import is_blackwell_supported +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.utils import ( prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, @@ -38,19 +42,27 @@ class CompressedTensorsW4A4Nvfp4MoE(CompressedTensorsMoEScheme): def __init__(self): - if not is_blackwell_supported(): + self.group_size = 16 + + if should_use_fp4_marlin_fallback(): + logger.warning_once( + "GPU is not Blackwell (SM100+). Using Marlin FP4 fallback kernel " + "for MoE layers. Weights remain compressed in FP4 format." 
+ ) + self.use_marlin_fallback = True + self.use_flashinfer_trtllm = False + elif not is_blackwell_supported(): raise ValueError( "Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above." + " quantization. Please use SM75+ (Turing or newer)." ) - self.group_size = 16 - self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() + else: + self.use_marlin_fallback = False + self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() @classmethod def get_min_capability(cls) -> int: - # Requires sm100(blackwell) architecture - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 def create_weights( self, @@ -64,6 +76,7 @@ def create_weights( from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported layer.params_dtype = params_dtype + layer.intermediate_size_per_partition = intermediate_size_per_partition w13_weight = torch.nn.Parameter( torch.empty( @@ -175,6 +188,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) delattr(layer, "w2_weight_packed") + if self.use_marlin_fallback: + # CompressedTensors checkpoint: global_scale is stored as the inverse. + # Actual dequant scale = 1 / stored_value. We create w*_weight_scale_2 + # with the actual scale before calling prepare_moe_fp4_layer_for_marlin(). 
+ layer.w13_weight_scale_2 = torch.nn.Parameter( + (1.0 / layer.w13_weight_global_scale).to(layer.params_dtype), + requires_grad=False, + ) # [E, 2] + layer.w2_weight_scale_2 = torch.nn.Parameter( + (1.0 / layer.w2_weight_global_scale).to(layer.params_dtype), + requires_grad=False, + ) # [E] + prepare_moe_fp4_layer_for_marlin(layer) + return + if self.use_flashinfer_trtllm: w, s = reorder_w1w3_to_w3w1( layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2 @@ -303,7 +331,10 @@ def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + if self.use_marlin_fallback: + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) + else: + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) def apply_weights( self, @@ -313,6 +344,33 @@ def apply_weights( from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + if self.use_marlin_fallback: + from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo + + expert_map = None + global_num_experts = -1 + if hasattr(layer, "dispatcher") and hasattr( + layer.dispatcher, "local_expert_mapping" + ): + expert_map = layer.dispatcher.local_expert_mapping + if expert_map is not None: + global_num_experts = self.moe_runner_config.num_experts + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale, + w2_scales=layer.w2_weight_scale, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + w13_global_scale=layer.w13_weight_scale_2, + w2_global_scale=layer.w2_weight_scale_2, + expert_map=expert_map, + global_num_experts=global_num_experts, + ) + return self.runner.run(dispatch_output, quant_info) + x = dispatch_output.hidden_states topk_output = dispatch_output.topk_output diff --git a/python/sglang/srt/layers/quantization/fp4_utils.py 
b/python/sglang/srt/layers/quantization/fp4_utils.py index 3e913e137f02..b1bd80dd09d5 100644 --- a/python/sglang/srt/layers/quantization/fp4_utils.py +++ b/python/sglang/srt/layers/quantization/fp4_utils.py @@ -4,7 +4,6 @@ from enum import Enum from typing import TYPE_CHECKING -from sglang.srt.environ import envs from sglang.srt.utils.common import is_sm120_supported if TYPE_CHECKING: @@ -17,6 +16,7 @@ class Fp4GemmRunnerBackend(Enum): """Enum for FP4 GEMM runner backend selection.""" AUTO = "auto" + CUTLASS = "cutlass" FLASHINFER_CUDNN = "flashinfer_cudnn" FLASHINFER_CUTLASS = "flashinfer_cutlass" FLASHINFER_TRTLLM = "flashinfer_trtllm" @@ -24,6 +24,9 @@ class Fp4GemmRunnerBackend(Enum): def is_auto(self) -> bool: return self == Fp4GemmRunnerBackend.AUTO + def is_cutlass(self) -> bool: + return self == Fp4GemmRunnerBackend.CUTLASS + def is_flashinfer_cudnn(self) -> bool: return self == Fp4GemmRunnerBackend.FLASHINFER_CUDNN @@ -33,6 +36,9 @@ def is_flashinfer_cutlass(self) -> bool: def is_flashinfer_trtllm(self) -> bool: return self == Fp4GemmRunnerBackend.FLASHINFER_TRTLLM + def is_flashinfer(self) -> bool: + return self.value.startswith("flashinfer_") + def get_flashinfer_backend(self) -> str: """Get the backend string to pass to FlashInfer's mm_fp4 API. @@ -56,26 +62,6 @@ def initialize_fp4_gemm_config(server_args: ServerArgs) -> None: global FP4_GEMM_RUNNER_BACKEND backend = server_args.fp4_gemm_runner_backend - - # Handle deprecated env var for backward compatibility - # TODO: Remove this in a future version - if envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.is_set(): - env_backend = envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.get() - if backend == "auto": - logger.warning( - "SGLANG_FLASHINFER_FP4_GEMM_BACKEND is deprecated. " - f"Please use '--fp4-gemm-backend={env_backend}' instead." 
- ) - if not env_backend.startswith("flashinfer_"): - env_backend = "flashinfer_" + env_backend - backend = env_backend - else: - logger.warning( - f"FP4 GEMM backend set to '{backend}' via --fp4-gemm-backend overrides " - "environment variable SGLANG_FLASHINFER_FP4_GEMM_BACKEND. " - "Using server argument value." - ) - if backend == "auto": if is_sm120_supported(): # flashinfer_cutlass produces NaN in dense MLP layers with diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 7182e3d57ba6..9d0d15716050 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -495,20 +495,44 @@ def _process_mxfp8_linear_weight_scale(self, layer: Module) -> None: return if get_fp8_gemm_runner_backend().is_flashinfer_trtllm(): + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + weight = layer.weight.data + scale_u8 = layer.weight_scale_inv.data + n, k = weight.shape + epilogue_tile_m = 128 + + copy_or_rebind_param( + layer, + "weight", + shuffle_matrix_a( + weight.contiguous().view(torch.uint8), epilogue_tile_m + ).view(torch.float8_e4m3fn), + ) + copy_or_rebind_param( + layer, + "weight_scale_inv", + shuffle_matrix_sf_a( + scale_u8.contiguous().view(torch.uint8).reshape(n, k // 32), + epilogue_tile_m, + num_elts_per_sf=32, + ) + .reshape_as(scale_u8) + .contiguous(), + ) + elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass(): from flashinfer import block_scale_interleave scale_u8 = layer.weight_scale_inv.data - new_swizzled = block_scale_interleave(scale_u8.contiguous()).contiguous() + copy_or_rebind_param( + layer, + "weight_scale_inv", + block_scale_interleave(scale_u8.contiguous()).contiguous(), + ) else: # Triton path consumes canonical 2D UE8M0 scales directly. 
return - copy_or_rebind_param(layer, "weight_scale_inv_swizzled", new_swizzled) - layer._weight_scale_inv_swizzled_src_version = layer.weight_scale_inv._version - layer._weight_scale_inv_swizzled_src_data_ptr = ( - layer.weight_scale_inv.data_ptr() - ) - def _quantize_mxfp8_weights(self, layer: Module) -> None: weight = layer.weight.data qweight, weight_scale = mxfp8_group_quantize(weight) @@ -657,22 +681,18 @@ def apply( ) if self.use_mxfp8: - if get_fp8_gemm_runner_backend().is_flashinfer_trtllm(): - weight_scale = layer.weight_scale_inv_swizzled - else: - weight_scale = layer.weight_scale_inv if isinstance(x, tuple): return self.w8a8_mxfp8_linear( input=x[0], weight=layer.weight, - weight_scale=weight_scale, + weight_scale=layer.weight_scale_inv, input_scale=x[1], bias=bias, ) return self.w8a8_mxfp8_linear( input=x, weight=layer.weight, - weight_scale=weight_scale, + weight_scale=layer.weight_scale_inv, input_scale=None, bias=bias, ) @@ -1349,7 +1369,10 @@ def process_weights_after_loading(self, layer: Module) -> None: self.process_weights_hip_scale_padding(layer) # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled - if get_moe_runner_backend().is_flashinfer_trtllm(): + if ( + get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ): from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( align_fp8_moe_weights_for_flashinfer_trtllm, ) @@ -1599,7 +1622,8 @@ def apply( local_num_experts=num_local_experts, intermediate_size=layer.w2_weight.shape[2], routing_method_type=int( - getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3) + getattr(layer, "routing_method_type", None) + or RoutingMethodType.DeepSeekV3 ), block_quant=self.block_quant, use_mxfp8=getattr(self.quant_config, "use_mxfp8", False), diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 24a1691d225d..016c836fe303 100644 --- 
a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -47,6 +47,8 @@ _is_hip = is_hip() _is_cuda = is_cuda() _is_cpu = is_cpu() +_is_sm100_supported = is_sm100_supported() +_is_sm120_supported = is_sm120_supported() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: @@ -1299,7 +1301,7 @@ def mxfp8_block_scaled_matmul_triton( SM120: 1, SM100: 4. """ if num_stages is None: - num_stages = 1 if is_sm120_supported() else (4 if is_sm100_supported() else 1) + num_stages = 1 if _is_sm120_supported else (4 if _is_sm100_supported else 1) M, K = a.shape N, K_b = b.shape assert K == K_b diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 3f26b736a6a9..49b5aeb3075b 100755 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -7,7 +7,6 @@ import torch -from sglang.srt.environ import envs from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil @@ -45,12 +44,15 @@ is_sm120_supported, offloader, ) +from sglang.srt.utils.custom_op import register_custom_op logger = logging.getLogger(__name__) _is_hip = is_hip() _is_cuda = is_cuda() _is_fp8_fnuz = is_fp8_fnuz() +_is_sm100_supported = is_sm100_supported() +_is_sm120_supported = is_sm120_supported() _is_gfx95_supported = is_gfx95_supported() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip @@ -213,6 +215,7 @@ def _check_cutlass_block_fp8_hardware_support() -> bool: if is_blackwell_supported() and is_flashinfer_available(): + from flashinfer import SfLayout from flashinfer import mm_mxfp8 as _raw_flashinfer_mm_mxfp8 from flashinfer import mxfp8_quantize as _raw_flashinfer_mxfp8_quantize from flashinfer.gemm import gemm_fp8_nt_groupwise as 
_raw_gemm_fp8_nt_groupwise @@ -302,12 +305,13 @@ def flashinfer_mxfp8_quantize( input, is_sf_swizzled_layout=is_sf_swizzled_layout, alignment=alignment, + sf_swizzle_layout=SfLayout.layout_128x4, ) @register_custom_op( op_name="flashinfer_mm_mxfp8", mutates_args=[], - fake_impl=lambda q_input, weight_t, x_scale_u8, weight_scale_t, out_dtype, backend="auto": ( + fake_impl=lambda q_input, weight_t, x_scale_u8, weight_scale_t, out_dtype, use_8x4_sf_layout=False, backend="auto": ( q_input.new_empty((q_input.shape[0], weight_t.shape[1]), dtype=out_dtype) ), ) @@ -317,6 +321,7 @@ def flashinfer_mm_mxfp8( x_scale_u8: torch.Tensor, weight_scale_t: torch.Tensor, out_dtype: torch.dtype, + use_8x4_sf_layout: bool = False, backend: str = "auto", ) -> torch.Tensor: return _raw_flashinfer_mm_mxfp8( @@ -325,6 +330,7 @@ def flashinfer_mm_mxfp8( x_scale_u8, weight_scale_t, out_dtype=out_dtype, + use_8x4_sf_layout=use_8x4_sf_layout, backend=backend, ) @@ -356,11 +362,13 @@ def dispatch_w8a8_mxfp8_linear() -> Callable: """Dispatch MXFP8 linear kernel by --fp8-gemm-backend. For MXFP8, Triton remains the default path. We only route to FlashInfer - when backend is explicitly set to flashinfer_trtllm. + when backend is explicitly set to flashinfer_cutlass or flashinfer_trtllm. """ backend = get_fp8_gemm_runner_backend() if backend.is_flashinfer_trtllm(): return flashinfer_mxfp8_blockscaled_linear + elif backend.is_flashinfer_cutlass(): + return flashinfer_mxfp8_blockscaled_linear return triton_mxfp8_blockscaled_linear @@ -453,25 +461,6 @@ def initialize_fp8_gemm_config(server_args: ServerArgs) -> None: global FP8_GEMM_RUNNER_BACKEND backend = server_args.fp8_gemm_runner_backend - - # TODO(brayden): Remove env-based overrides in v0.5.7, they will be fully removed in v0.5.7. - # Only check environment variables when the server args is not set, server args should take priority. 
- if backend == "auto": - if envs.SGLANG_ENABLE_FLASHINFER_FP8_GEMM.get(): - backend = "flashinfer_trtllm" - elif envs.SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.get(): - backend = "cutlass" - else: - if ( - envs.SGLANG_ENABLE_FLASHINFER_FP8_GEMM.get() - or envs.SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.get() - ): - logger.warning( - f"FP8 GEMM backend set to '{backend}' via --fp8-gemm-backend overrides " - "environment variables SGLANG_ENABLE_FLASHINFER_FP8_GEMM and " - "SGLANG_SUPPORT_CUTLASS_BLOCK_FP8. Using server argument value." - ) - if backend == "auto" and is_sm120_supported(): # TODO(brayden): Verify if CUTLASS can be set by default once SwapAB is supported backend = "triton" @@ -875,7 +864,40 @@ def _pack_mxfp8_scales(scale_u8: torch.Tensor) -> torch.Tensor: return packed.view(1, scale_m, scale_k, 2, 256) -def triton_mxfp8_blockscaled_linear( +@register_custom_op( + op_name="triton_mxfp8_block_scaled_matmul", + mutates_args=[], + fake_impl=lambda a, a_scale, b, b_scale, output_dtype, block_m=128, block_n=256, block_k=128, num_stages=None: ( # noqa: E501 + a.new_empty((a.shape[0], b.shape[0]), dtype=output_dtype) + ), +) +def triton_mxfp8_block_scaled_matmul( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + output_dtype: torch.dtype, + *, + block_m: int = 128, + block_n: int = 256, + block_k: int = 128, + num_stages: Optional[int] = None, +) -> torch.Tensor: + """Opaque custom op wrapper to prevent Dynamo tracing Triton grid math.""" + return mxfp8_block_scaled_matmul_triton( + a, + a_scale, + b, + b_scale, + output_dtype=output_dtype, + block_m=block_m, + block_n=block_n, + block_k=block_k, + num_stages=num_stages, + ) + + +def _raw_triton_mxfp8_blockscaled_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, @@ -883,7 +905,7 @@ def triton_mxfp8_blockscaled_linear( bias: Optional[torch.Tensor] = None, output_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - if not (_is_cuda and (is_sm100_supported() 
or is_sm120_supported())): + if not (_is_cuda and (_is_sm100_supported or _is_sm120_supported)): raise RuntimeError("MXFP8 dense linear requires Blackwell GPUs (SM100/SM120).") input_2d = input.view(-1, input.shape[-1]).contiguous() @@ -935,8 +957,8 @@ def triton_mxfp8_blockscaled_linear( a_scale_packed = _pack_mxfp8_scales(x_scale_u8) b_scale_packed = _pack_mxfp8_scales(weight_scale) - num_stages = 1 if is_sm120_supported() else (4 if is_sm100_supported() else 1) - output = mxfp8_block_scaled_matmul_triton( + num_stages = 1 if _is_sm120_supported else (4 if _is_sm100_supported else 1) + output = triton_mxfp8_block_scaled_matmul( q_input, a_scale_packed, weight.contiguous(), @@ -953,6 +975,35 @@ def triton_mxfp8_blockscaled_linear( return output.to(dtype=output_dtype).view(*output_shape) +@register_custom_op( + op_name="triton_mxfp8_blockscaled_linear", + mutates_args=[], + fake_impl=lambda input, weight, weight_scale, input_scale=None, bias=None, output_dtype=None: ( + input.new_empty( + (*input.shape[:-1], weight.shape[0]), + dtype=(output_dtype if output_dtype is not None else input.dtype), + ) + ), +) +def triton_mxfp8_blockscaled_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Opaque custom-op wrapper to prevent Dynamo guards on MXFP8 padding branches.""" + return _raw_triton_mxfp8_blockscaled_linear( + input=input, + weight=weight, + weight_scale=weight_scale, + input_scale=input_scale, + bias=bias, + output_dtype=output_dtype, + ) + + def flashinfer_mxfp8_blockscaled_linear( input: torch.Tensor, weight: torch.Tensor, @@ -980,6 +1031,7 @@ def flashinfer_mxfp8_blockscaled_linear( ) else: q_input = input_2d + x_scale_u8 = input_scale.contiguous() if output_dtype is None: if input_2d.dtype in (torch.float16, torch.bfloat16, torch.float32): @@ -989,19 +1041,34 @@ def 
flashinfer_mxfp8_blockscaled_linear( # Ensure transposed tensors are contiguous for FlashInfer's internal runner. weight_t = weight.contiguous().t() - weight_scale_t = ( - weight_scale.contiguous().t() - if weight_scale.ndim == 2 - else weight_scale.contiguous() - ) - output = flashinfer_mm_mxfp8( - q_input, - weight_t, - x_scale_u8, - weight_scale_t, - out_dtype=output_dtype, - backend="auto", - ) + + if get_fp8_gemm_runner_backend().is_flashinfer_trtllm(): + + weight_scale_t = weight_scale.contiguous().view(-1) + output = flashinfer_mm_mxfp8( + q_input, + weight_t, + x_scale_u8, + weight_scale_t, + out_dtype=output_dtype, + use_8x4_sf_layout=False, + backend="trtllm", + ) + elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass(): + weight_scale_t = ( + weight_scale.contiguous().t() + if weight_scale.ndim == 2 + else weight_scale.contiguous() + ) + output = flashinfer_mm_mxfp8( + q_input, + weight_t, + x_scale_u8, + weight_scale_t, + out_dtype=output_dtype, + use_8x4_sf_layout=False, + backend="cutlass", + ) if bias is not None: output += bias diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py index d2761fc8e88f..04b72ca0d68f 100644 --- a/python/sglang/srt/layers/quantization/marlin_utils.py +++ b/python/sglang/srt/layers/quantization/marlin_utils.py @@ -261,7 +261,7 @@ def marlin_make_workspace( device: torch.device, max_blocks_per_sm: int = 1 ) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace - # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + # size. The num of threadblocks is sms_count * max_blocks_per_sm. 
sms = torch.cuda.get_device_properties(device).multi_processor_count return torch.zeros( sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py new file mode 100644 index 000000000000..5a9fb3cef84b --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py + +"""NVFP4 Marlin fallback: run FP4-quantized models on non-Blackwell GPUs via Marlin kernel.""" + +import logging +from typing import Optional + +import torch + +from sglang.srt.layers.quantization.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, + marlin_make_workspace, + marlin_permute_bias, + marlin_permute_scales, + should_use_atomic_add_reduce, +) +from sglang.srt.layers.quantization.utils import get_scalar_types +from sglang.srt.utils import direct_register_custom_op, get_device_capability, is_cuda + +_is_cuda = is_cuda() +if _is_cuda: + from sglang.jit_kernel.gptq_marlin import gptq_marlin_gemm + from sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack + +ScalarType, scalar_types = get_scalar_types() + +logger = logging.getLogger(__name__) + +# NVFP4 always uses group_size=16 +FP4_MARLIN_GROUP_SIZE = 16 + + +def is_fp4_marlin_supported() -> bool: + """Check if the current GPU supports FP4 Marlin fallback (CUDA SM >= 75).""" + if not _is_cuda: + return False + if torch.version.hip is not None: + return False + major, minor = get_device_capability() + if major is None or minor is None: + return False + return (major * 10 + minor) >= 75 + + +def should_use_fp4_marlin_fallback() -> bool: + """True if non-Blackwell (or forced) AND Marlin kernel available (SM >= 75).""" + from sglang.srt.environ import envs + from 
sglang.srt.layers.quantization.fp8_utils import is_blackwell_supported + + force = envs.SGLANG_FORCE_NVFP4_MARLIN.get() + return (force or not is_blackwell_supported()) and is_fp4_marlin_supported() + + +def nvfp4_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor: + """Convert NVFP4 scales from FP8-S1E4M3 to FP8-S0E5M3 format for Marlin. + + The int16 <<1 may wrap for large scales (e.g. 448*128=57344), but the BIT + PATTERN is preserved correctly — the kernel reads raw bytes, not int16 values. + """ + marlin_scales = marlin_scales.to(torch.half) + + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes scales >= 0, but encountered negative scales. " + "Accuracy may be degraded. The scales are converted from FP8-S1E4M3 " + "to a special FP8-S0E5M3 format to speed up dequantization." + ) + + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def nvfp4_marlin_process_global_scale(global_scale: torch.Tensor) -> torch.Tensor: + """Pre-adjust global scale with FP4/FP16/BF16 exponent bias for Marlin kernel.""" + assert global_scale.dtype in [ + torch.half, + torch.bfloat16, + ], f"global_scale dtype must be half or bfloat16, got {global_scale.dtype}" + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp4_exponent - 1) + return global_scale * (2.0 ** (exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_global_scale: Optional[torch.Tensor], + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + 
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + """Apply FP4-quantized linear via Marlin kernel (non-Blackwell fallback).""" + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype, + ) + + output = gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_global_scale.reshape(-1), + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) + + if bias is not None: + output.add_(bias) + + return output.reshape(out_shape) + + +def fake_apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_global_scale: Optional[torch.Tensor], + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + out_shape = input.shape[:-1] + (size_n,) + return torch.empty(out_shape, dtype=input.dtype, device=input.device) + + +direct_register_custom_op( + op_name="apply_fp4_marlin_linear", + op_func=apply_fp4_marlin_linear, + mutates_args=[], + fake_impl=fake_apply_fp4_marlin_linear, +) + + +def prepare_fp4_layer_for_marlin( + layer: torch.nn.Module, + weight_attr: str = "weight", + weight_scale_attr: str = "weight_scale", + weight_global_scale_attr: str = "weight_global_scale", +) -> None: + """Repack NVFP4 linear layer weights into Marlin format in-place.""" + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads." + ) + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + weight = getattr(layer, weight_attr) + assert weight.shape == (part_size_n, part_size_k // 2), ( + f"Expected {weight_attr} shape ({part_size_n}, {part_size_k // 2}), " + f"got {weight.shape}" + ) + + device = weight.device + + # WORKSPACE + layer.marlin_workspace = marlin_make_workspace(device) + + # WEIGHT: repack from NVFP4 native layout to Marlin tile layout + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = weight.data.view(torch.int32).T.contiguous() + del weight + marlin_qweight = gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) + del qweight + setattr(layer, weight_attr, torch.nn.Parameter(marlin_qweight, requires_grad=False)) + + # WEIGHT SCALES: transpose, permute, convert to FP8-S0E5M3 + weight_scale = getattr(layer, weight_scale_attr) + weight_scale = weight_scale.data.T.contiguous().to(param_dtype) + weight_scale = marlin_permute_scales( + s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=FP4_MARLIN_GROUP_SIZE, + ) + weight_scale = nvfp4_marlin_process_scales(weight_scale) + setattr( + layer, weight_scale_attr, torch.nn.Parameter(weight_scale, requires_grad=False) + ) + + # GLOBAL SCALE: Pre-adjust exponent bias for Marlin kernel. 
+ weight_global_scale = getattr(layer, weight_global_scale_attr) + weight_global_scale = weight_global_scale.to(param_dtype) + weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale) + setattr( + layer, + weight_global_scale_attr, + torch.nn.Parameter(weight_global_scale, requires_grad=False), + ) + + # BIAS (if present): Permute for Marlin's fast access pattern + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n,) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + """Repack NVFP4 MoE weights into Marlin format in-place (per-expert).""" + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel for MoE layers. This may " + "degrade performance for compute-heavy workloads." 
+ ) + + e = layer.num_local_experts + k = layer.w13_weight.shape[2] * 2 # hidden_size (packed: K//2 per uint8) + n = layer.intermediate_size_per_partition + param_dtype = layer.params_dtype + num_shards = 2 if layer.moe_runner_config.is_gated else 1 + + device = layer.w13_weight.device + perm = torch.empty(0, dtype=torch.int, device=device) + + # (size_n, size_k) for each projection + sizes = {"w13": (n * num_shards, k), "w2": (k, n)} + + # --- WEIGHT REPACKING --- + for name in ["w13_weight", "w2_weight"]: + prefix = name.split("_")[0] # "w13" or "w2" + size_n, size_k = sizes[prefix] + weight = getattr(layer, name) + + assert weight.shape == (e, size_n, size_k // 2), ( + f"Expected {name} shape ({e}, {size_n}, {size_k // 2}), " + f"got {weight.shape}" + ) + + repacked = [] + for i in range(e): + qweight = weight.data[i].view(torch.int32).T.contiguous() + repacked.append( + gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + ) + + del weight + setattr( + layer, name, torch.nn.Parameter(torch.stack(repacked), requires_grad=False) + ) + + # --- WEIGHT SCALE PROCESSING --- + for prefix in ["w13", "w2"]: + size_n, size_k = sizes[prefix] + scales = getattr(layer, prefix + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, prefix + "_weight_scale_2").to(param_dtype) + + processed = [] + for i in range(e): + s = marlin_permute_scales( + s=scales.data[i].T, + size_k=size_k, + size_n=size_n, + group_size=FP4_MARLIN_GROUP_SIZE, + ) + processed.append(nvfp4_marlin_process_scales(s)) + + del scales + setattr( + layer, + prefix + "_weight_scale", + torch.nn.Parameter(torch.stack(processed), requires_grad=False), + ) + + if global_scale.dim() > 1: + global_scale = global_scale.max(dim=-1).values + global_scale = nvfp4_marlin_process_global_scale(global_scale) + setattr( + layer, + prefix + "_weight_scale_2", + torch.nn.Parameter(global_scale, requires_grad=False), + ) diff --git 
a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index dc4fa71fcd20..23d8f3f7b3c8 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -40,6 +40,11 @@ is_blackwell_supported, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + prepare_moe_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import ( convert_to_channelwise, @@ -86,13 +91,19 @@ enable_flashinfer_fp4_gemm = True except ImportError: - if is_cuda(): - from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm as cutlass_fp4_gemm enable_flashinfer_fp4_gemm = False reorder_rows_for_gated_act_gemm = None shuffle_matrix_a = None shuffle_matrix_sf_a = None +if is_cuda(): + try: + from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm as cutlass_fp4_gemm + except ImportError: + cutlass_fp4_gemm = None +else: + cutlass_fp4_gemm = None + try: from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe from flashinfer.fused_moe.core import ActivationType @@ -134,7 +145,15 @@ def fp4_gemm( out_features: int, ) -> torch.Tensor: fp4_backend = get_fp4_gemm_runner_backend() - if enable_flashinfer_fp4_gemm: + if fp4_backend.is_cutlass() and cutlass_fp4_gemm is not None: + # flashinfer.fp4_quantize returns scale factors as uint8 (e4m3fn bits + # stored in uint8 memory). The JIT kernel requires float8_e4m3fn dtype. 
+ if input_sf.dtype != torch.float8_e4m3fn: + input_sf = input_sf.view(torch.float8_e4m3fn) + if weight_sf.dtype != torch.float8_e4m3fn: + weight_sf = weight_sf.view(torch.float8_e4m3fn) + return cutlass_fp4_gemm(input, weight, input_sf, weight_sf, alpha, out_dtype) + elif enable_flashinfer_fp4_gemm: # Use the remapping logic to convert SGLang backend names to FlashInfer API names backend = fp4_backend.get_flashinfer_backend() return flashinfer_fp4_gemm( @@ -1128,7 +1147,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 @staticmethod def common_group_size(cfg: dict) -> int: @@ -1259,7 +1278,7 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str): layer, prefix, Linear=ModelOptFp4LinearMethod, - Moe=ModelOptNvFp4FusedMoEMethod, # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling + Moe=ModelOptNvFp4FusedMoEMethod, ) @@ -1302,6 +1321,7 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes + layer.params_dtype = params_dtype layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition @@ -1356,6 +1376,20 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if should_use_fp4_marlin_fallback(): + # Marlin FP4 fallback: consolidate global scale then repack weights + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.weight_scale_2_marlin = Parameter(weight_scale_2, requires_grad=False) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + layer.use_marlin_fallback = True + return + + layer.use_marlin_fallback = False 
input_scale_2 = layer.input_scale.max().to(torch.float32) weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) @@ -1456,6 +1490,18 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if layer.use_marlin_fallback: + return torch.ops.sglang.apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + output_dtype = x.dtype x_m, _ = x.shape @@ -1478,7 +1524,10 @@ def apply( w = layer.weight w_scale_interleaved = layer.weight_scale_interleaved - if enable_flashinfer_fp4_gemm: + if ( + enable_flashinfer_fp4_gemm + and not get_fp4_gemm_runner_backend().is_cutlass() + ): w = layer.weight.T w_scale_interleaved = layer.weight_scale_interleaved.T @@ -1509,15 +1558,24 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: ModelOptFp4Config): self.quant_config = quant_config - if not is_blackwell_supported(): + + if should_use_fp4_marlin_fallback(): + logger.warning_once( + "GPU is not Blackwell (SM100+). Using Marlin FP4 fallback kernel " + "for MoE layers. Weights remain compressed in FP4 format." + ) + self.use_marlin_fallback = True + self.enable_flashinfer_trtllm_moe = False + elif not is_blackwell_supported(): raise ValueError( "Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above." + " quantization. Please use SM75+ (Turing or newer)." 
+ ) + else: + self.use_marlin_fallback = False + self.enable_flashinfer_trtllm_moe = ( + get_moe_runner_backend().is_flashinfer_trtllm() ) - self.enable_flashinfer_trtllm_moe = ( - get_moe_runner_backend().is_flashinfer_trtllm() - ) self._cache_permute_indices = {} @property @@ -1671,6 +1729,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: Only supports pre-quantized checkpoints with FP8 weights and scales. """ + if self.use_marlin_fallback: + prepare_moe_fp4_layer_for_marlin(layer) + return # GEMM 1 scale processing if layer.moe_runner_config.is_gated: @@ -1883,6 +1944,9 @@ def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config + if self.use_marlin_fallback: + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) + return if get_moe_runner_backend().is_flashinfer_trtllm(): self.runner = MoeRunner( MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config @@ -1905,6 +1969,34 @@ def apply( ), f"{activation=} missing from {ACT_STR_TO_TYPE_MAP.keys()=}" moe_runner_config = self.moe_runner_config + # Marlin FP4 fallback path for non-Blackwell GPUs (SM75-SM89) + if self.use_marlin_fallback: + from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo + + expert_map = None + global_num_experts = -1 + if hasattr(layer, "dispatcher") and hasattr( + layer.dispatcher, "local_expert_mapping" + ): + expert_map = layer.dispatcher.local_expert_mapping + if expert_map is not None: + global_num_experts = moe_runner_config.num_experts + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale, + w2_scales=layer.w2_weight_scale, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + w13_global_scale=layer.w13_weight_scale_2, + w2_global_scale=layer.w2_weight_scale_2, + expert_map=expert_map, + global_num_experts=global_num_experts, + ) + return 
self.runner.run(dispatch_output, quant_info) + # FlashInfer TRTLLM FP4 path - layer has shuffled weights only when # backend is flashinfer_trtllm if hasattr(layer, "gemm1_weights_fp4_shuffled"): diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py index 84acecccc415..86a7bfbebfbf 100644 --- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -58,13 +58,13 @@ def _rmsnorm_forward_oot( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias - if not x.is_contiguous(): x = x.contiguous() if residual is not None: if post_residual_addition is not None: residual = residual + post_residual_addition + from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias + out, residual_out = add_rmsnorm_bias( x, residual, diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 9fcb88e541f7..94f9a1375c14 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -22,7 +22,7 @@ LinearMethodBase, QuantizeMethodBase, ) -from sglang.srt.layers.utils import MultiPlatformOp +from sglang.srt.layers.utils import MultiPlatformOp, copy_or_rebind_param from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, @@ -233,14 +233,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # because aiter CK kernels don't support all GEMM dimensions _should_use_aiter_moe = _use_aiter and get_moe_runner_backend().is_auto() if _should_use_aiter_moe: - layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), - requires_grad=False, + copy_or_rebind_param( + layer, "w13_weight", 
shuffle_weight(layer.w13_weight.data, (16, 16)) ) torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), - requires_grad=False, + copy_or_rebind_param( + layer, "w2_weight", shuffle_weight(layer.w2_weight.data, (16, 16)) ) torch.cuda.empty_cache() @@ -317,9 +315,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: for weight_name in ["w13_weight", "w2_weight"]: weight = getattr(layer, weight_name) weight.data = weight.data.transpose(1, 2) - weight.data = npu_format_cast( - weight.data, - ) + weight.data = npu_format_cast(weight.data) return diff --git a/python/sglang/srt/layers/rocm_linear_utils.py b/python/sglang/srt/layers/rocm_linear_utils.py index 6c8a6a367e54..ae31553c95e9 100644 --- a/python/sglang/srt/layers/rocm_linear_utils.py +++ b/python/sglang/srt/layers/rocm_linear_utils.py @@ -1,10 +1,7 @@ import torch from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat -from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 -from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic - -from sglang.srt.utils import BumpAllocator +from aiter.tuned_gemm import tgemm __all__ = ["fused_qk_rope_cat", "fused_qk_rope_cat_and_cache_mla"] @@ -12,26 +9,9 @@ def aiter_dsv3_router_gemm( hidden_states: torch.Tensor, weight: torch.Tensor, - gemm_output_zero_allocator: BumpAllocator = None, ): - M = hidden_states.shape[0] - N = weight.shape[0] - y = None - - if M <= 256: - # TODO (cagri): convert to bfloat16 as part of another kernel to save time - # for now it is also coupled with zero allocator. 
- if gemm_output_zero_allocator != None: - y = gemm_output_zero_allocator.allocate(M * N).view(M, N) - else: - y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device) - - if y is not None: - logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype) - else: - logits = gemm_a16w16(hidden_states, weight) - - return logits + """Use aiter tuned GEMM dispatcher (tgemm.mm) to automatically select the GEMM kernel.""" + return tgemm.mm(hidden_states, weight, otype=hidden_states.dtype) def get_dsv3_gemm_output_zero_allocator_size( diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index 1d17a1ca7f90..99a3f11ca05f 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -39,6 +39,12 @@ if _is_npu: import torch_npu + from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa + +if _is_hip: + from sglang.srt.layers.attention.utils import ( + fused_qk_rope_reshape_and_cache, + ) class RotaryEmbedding(MultiPlatformOp): @@ -202,9 +208,14 @@ def forward_native( if offsets is not None: positions = positions + offsets + positions = positions.flatten() num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) + + if hasattr(self, "sin_cos_cache"): + cos_sin = self.sin_cos_cache + else: + cos_sin = self.cos_sin_cache.index_select(0, positions) cos, sin = cos_sin.chunk(2, dim=-1) query_shape = query.shape @@ -236,14 +247,35 @@ def forward_npu( assert ( fused_set_kv_buffer_arg is None ), "fused_set_kv_buffer_arg is not supported for npu implementation" - if query.dtype == torch.bfloat16 and self.cos_sin_cache.dtype == torch.float: - return self.forward_native(positions, query, key, offsets) + if ( + query.dtype == torch.bfloat16 + and self.cos_sin_cache.dtype == torch.float + or key.ndim == 3 + ): + if hasattr(self, "sin_cos_cache"): + cos_sin = self.sin_cos_cache + else: + cos_sin 
= self.cos_sin_cache.index_select(0, positions) + + if query.shape[0] * query.shape[1] < 65535: + return fused_rope_qk_mqa( + query, + key, + cos_sin, + self.rotary_dim, + self.is_neox_style, + ) + else: + return self.forward_native(positions, query, key, offsets) if self.is_neox_style: rotary_mode = "half" else: rotary_mode = "interleave" mrope_section = [0, 0, 0] + # The npu_mrope kernel only supports 1D or 2D tensors for query and key. + # Therefore, when their dimensions exceed 2D, we flatten query and key to 2D tensors before computation + # and restore their original shapes afterward. query_shape = query.shape key_shape = key.shape query = query.reshape(query.shape[0], -1) @@ -296,7 +328,7 @@ def forward_cuda( query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, - fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + fused_set_kv_buffer_arg: Optional[Union[FusedSetKVBufferArg, dict]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if not self.use_fallback_kernel: batch_size = positions.size(0) @@ -314,18 +346,48 @@ def forward_cuda( fused_args=fused_set_kv_buffer_arg, ) else: - assert ( - fused_set_kv_buffer_arg is None - ), "save kv cache is not supported for fallback_rotary_embedding."
- self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - self.fallback_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - ) + + if fused_set_kv_buffer_arg is not None and _is_hip: + extra_args = fused_set_kv_buffer_arg + + k_cache_shape = fused_set_kv_buffer_arg["key_cache"].shape + qk_head_dim = k_cache_shape[-1] + tp_k_head_num = k_cache_shape[-2] + + key = key.view(-1, tp_k_head_num, qk_head_dim) + + tokens = key.shape[0] + + query = query.view(tokens, -1, qk_head_dim) + + query, key, k_cache, v_cache = fused_qk_rope_reshape_and_cache( + q=query, + k=key, + pos=positions, + cos_sin=self.cos_sin_cache, + is_neox=self.is_neox_style, + flash_layout=True, + offs=None, + q_out=query, + k_out=key, + output_zeros=False, + **extra_args, + ) + else: + assert ( + fused_set_kv_buffer_arg is None + ), "save kv cache is not supported for fallback_rotary_embedding." + self.cos_sin_cache = self.cos_sin_cache.to( + query.device, dtype=query.dtype + ) + self.fallback_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def extra_repr(self) -> str: diff --git a/python/sglang/srt/layers/rotary_embedding/factory.py b/python/sglang/srt/layers/rotary_embedding/factory.py index e95e9543f7f6..27e28577c96e 100644 --- a/python/sglang/srt/layers/rotary_embedding/factory.py +++ b/python/sglang/srt/layers/rotary_embedding/factory.py @@ -171,6 +171,9 @@ def get_rope( dtype, mrope_section=rope_scaling["mrope_section"], mrope_interleaved=rope_scaling.get("mrope_interleaved", False), + mrope_interleaved_glm=rope_scaling.get( + "mrope_interleaved_glm", False + ), ) elif rope_scaling.get("use_fope", False): rotary_emb = FourierRotaryEmbedding( diff --git a/python/sglang/srt/layers/rotary_embedding/mrope.py b/python/sglang/srt/layers/rotary_embedding/mrope.py index 237528fd1d47..9c93ad1ffd21 100644 --- 
a/python/sglang/srt/layers/rotary_embedding/mrope.py +++ b/python/sglang/srt/layers/rotary_embedding/mrope.py @@ -52,12 +52,14 @@ def __init__( dtype: torch.dtype, mrope_section: Optional[List[int]] = None, mrope_interleaved: bool = False, + mrope_interleaved_glm: bool = False, ) -> None: super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) self.mrope_section = mrope_section self.mrope_interleaved = mrope_interleaved + self.mrope_interleaved_glm = mrope_interleaved_glm if self.mrope_section: expected_sum = rotary_dim // 2 actual_sum = sum(self.mrope_section) @@ -86,6 +88,37 @@ def __init__( f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" ) + # MRoPE axis_map interleaving pattern depends on mrope_section sizes. + # The algorithm cycles through axes [0(T), 1(H), 2(W)] round-robin, + # skipping any axis that has exhausted its allocated pairs. + # + # For GLM-V (mrope_section=[8,12,12]): + # T(8) < H(12) = W(12), so T exhausts first at pair 24. + # Result: [0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 1,1,2, 1,1,2, 2,2] + # After T runs out, only H and W fill the remaining slots. + # + # For Qwen3-VL (mrope_section=[24,20,20]): + # T(24) > H(20) = W(20), so H and W exhaust first near the tail. + # Result: [0,1,2, 0,1,2, ...repeated evenly..., 0,1, 0,1, 0,0] + # After H/W run out, T fills the remaining slots. 
+ + if self.mrope_interleaved_glm: + num_pairs = rotary_dim // 2 + axis_map = torch.empty(num_pairs, dtype=torch.long) + assert sum(self.mrope_section) == num_pairs + counts = [0, 0, 0] + current_ax = 0 + + for i in range(num_pairs): + current_ax = i % 3 + while counts[current_ax] >= self.mrope_section[current_ax]: + current_ax = (current_ax + 1) % 3 + + axis_map[i] = current_ax + counts[current_ax] += 1 + self.register_buffer("axis_map", axis_map, persistent=False) + else: + self.axis_map = None if get_global_server_args().rl_on_policy_target is not None: self._forward_method = self.forward_native @@ -214,7 +247,9 @@ def forward_triton( self.head_size, self.rotary_dim, self.mrope_interleaved, + self.mrope_interleaved_glm, self.is_neox_style, + self.axis_map, ) return query, key diff --git a/python/sglang/srt/layers/rotary_embedding/triton_kernels.py b/python/sglang/srt/layers/rotary_embedding/triton_kernels.py index 9a3d21bf83bb..0a8dc2c33c7b 100644 --- a/python/sglang/srt/layers/rotary_embedding/triton_kernels.py +++ b/python/sglang/srt/layers/rotary_embedding/triton_kernels.py @@ -29,7 +29,9 @@ def _triton_mrope_forward_fused( mrope_section_h: tl.constexpr, mrope_section_w: tl.constexpr, is_interleaved: tl.constexpr, + is_interleaved_glm: tl.constexpr, is_neox_style: tl.constexpr, + axis_map_ptr, ): pid = tl.program_id(0) q_ptr = q_ptr + pid * q_stride @@ -46,9 +48,15 @@ def _triton_mrope_forward_fused( w_sin = w_cos + half_rd cos_offsets = tl.arange(0, pad_hd // 2) if is_interleaved: - h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) - w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) - t_mask = ~(h_mask | w_mask) + if is_interleaved_glm: + axes = tl.load(axis_map_ptr + cos_offsets, mask=cos_offsets < (pad_hd // 2)) + t_mask = axes == 0 + h_mask = axes == 1 + w_mask = axes == 2 + else: + h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) + w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 
* mrope_section_w) + t_mask = ~(h_mask | w_mask) else: t_end = mrope_section_t h_end = t_end + mrope_section_h @@ -109,7 +117,9 @@ def triton_mrope_fused( head_size: int, rotary_dim: int, mrope_interleaved: bool, + mrope_interleaved_glm: bool, is_neox_style: bool, + axis_map: torch.Tensor, ) -> None: num_tokens, n_q_dim = q.shape n_k_dim = k.shape[1] @@ -137,7 +147,9 @@ def triton_mrope_fused( mrope_section[1], mrope_section[2], mrope_interleaved, + mrope_interleaved_glm, is_neox_style, + axis_map, ) diff --git a/python/sglang/srt/layers/rotary_embedding/utils.py b/python/sglang/srt/layers/rotary_embedding/utils.py index 8d9defabe1f0..882c68241427 100644 --- a/python/sglang/srt/layers/rotary_embedding/utils.py +++ b/python/sglang/srt/layers/rotary_embedding/utils.py @@ -7,9 +7,11 @@ import torch -from sglang.srt.utils import get_compiler_backend, is_npu +from sglang.srt.utils import cpu_has_amx_support, get_compiler_backend, is_cpu, is_npu _is_npu = is_npu() +_is_cpu = is_cpu() +_is_cpu_amx_available = cpu_has_amx_support() if _is_npu: import torch_npu @@ -128,5 +130,7 @@ def apply_rotary_pos_emb_npu( if _is_npu: apply_rotary_pos_emb = apply_rotary_pos_emb_npu +elif _is_cpu and _is_cpu_amx_available: + apply_rotary_pos_emb = torch.ops.sgl_kernel.apply_rotary_pos_emb_cpu else: apply_rotary_pos_emb = apply_rotary_pos_emb_native diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index a4c7c7db037e..4196787820f4 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -18,6 +18,7 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.utils.common import crash_on_warnings, get_bool_env_var, is_cuda, is_npu +_use_fused_sampling = False if is_cuda(): from flashinfer.sampling import ( min_p_sampling_from_probs, @@ -27,6 +28,15 @@ top_k_renorm_prob, top_p_renorm_prob, ) + + from sglang.srt.layers.fused_sampling import fused_temperature_softmax_inplace + + _use_fused_sampling = 
True + +# Batch size threshold for fused Triton kernel vs PyTorch softmax. +# Below this threshold, PyTorch's native div+softmax is faster. +# At and above this threshold, the fused Triton kernel wins. +_FUSED_SAMPLING_BATCH_THRESHOLD = 128 if is_npu(): import torch_npu @@ -152,11 +162,20 @@ def forward( logprobs = logprobs_via_logsoftmax_kernel else: # Standard path: do softmax and sample from probs. - logits.div_(sampling_info.temperatures) - - # In-place op to save memory - logits[:] = torch.softmax(logits, dim=-1) - probs = logits + # Use fused Triton kernel for large batches where it excels; + # fall back to PyTorch for small batches where launch overhead dominates. + if ( + _use_fused_sampling + and logits.shape[0] >= _FUSED_SAMPLING_BATCH_THRESHOLD + ): + fused_temperature_softmax_inplace( + logits, sampling_info.temperatures + ) + probs = logits + else: + logits.div_(sampling_info.temperatures) + logits[:] = torch.softmax(logits, dim=-1) + probs = logits batch_next_token_ids = self._sample_from_probs( probs, sampling_info, positions, simple_sampling_case @@ -327,13 +346,15 @@ def _attach_logprobs_to_output( ( logits_output.next_token_top_logprobs_val, logits_output.next_token_top_logprobs_idx, - ) = get_top_logprobs(logprobs, top_logprobs_nums) + ) = get_top_logprobs(logprobs, top_logprobs_nums, no_copy_to_cpu=True) if any(x is not None for x in token_ids_logprobs): ( logits_output.next_token_token_ids_logprobs_val, logits_output.next_token_token_ids_logprobs_idx, - ) = get_token_ids_logprobs(logprobs, token_ids_logprobs) + ) = get_token_ids_logprobs( + logprobs, token_ids_logprobs, no_copy_to_cpu=True + ) logits_output.next_token_logprobs = logprobs[ torch.arange(len(batch_next_token_ids), device=sampling_info.device), @@ -397,7 +418,7 @@ def compute_logprobs_only( ( logits_output.next_token_top_logprobs_val, logits_output.next_token_top_logprobs_idx, - ) = get_top_logprobs(logprobs, top_logprobs_nums) + ) = get_top_logprobs(logprobs, top_logprobs_nums, 
no_copy_to_cpu=True) # Handle token_ids logprobs if requested if needs_token_ids_logprobs: diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 21ad10447b46..9f2ad1352805 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional, Union import torch import torch.nn.functional as F @@ -629,7 +629,6 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor return lora_output def forward(self, input_: torch.Tensor, skip_all_reduce=False): - # duplicate the logic in RowParallelLinear if self.base_layer.input_is_parallel: input_parallel = input_ else: @@ -638,8 +637,14 @@ def forward(self, input_: torch.Tensor, skip_all_reduce=False): input_, num_partitions=self.base_layer.tp_size ) input_parallel = splitted_input[tp_rank].contiguous() + + bias_ = ( + None + if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add) + else self.base_layer.bias + ) output_parallel = self.base_layer.quant_method.apply( - self.base_layer, input_parallel + self.base_layer, input_parallel, bias=bias_ ) should_reduce = ( @@ -668,17 +673,8 @@ def forward(self, input_: torch.Tensor, skip_all_reduce=False): else: output_ = output_parallel - if not self.base_layer.skip_bias_add: - output = ( - output_ + self.base_layer.bias - if self.base_layer.bias is not None - else output_ - ) - output_bias = None - else: - output = output_ - output_bias = self.base_layer.bias - return output, output_bias + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + return output_, output_bias def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): shard_size = self.base_layer.input_size_per_partition @@ -711,11 +707,17 @@ def __init__( # initializes FusedMoE with its own moe_runner for base path super().__init__(base_layer, lora_backend) + self.experts_shared_outer_loras: bool = False + self.quant_method = 
base_layer.quant_method + self.tp_size = getattr(base_layer, "moe_tp_size", 1) self.tp_rank = getattr(base_layer, "moe_tp_rank", 0) self.intermediate_size_per_partition = getattr( base_layer, "intermediate_size_per_partition", None ) + self._uses_interleaved_gate_up = ( + getattr(base_layer.moe_runner_config, "gemm1_alpha", None) is not None + ) # initialize triton_lora moe runner for batches with lora enabled from sglang.srt.layers.moe.moe_runner.runner import MoeRunner @@ -782,6 +784,7 @@ def _get_lora_info(self): adapter_enabled=adapter_enabled, max_lora_rank=max_lora_rank, num_experts=self.base_layer.num_experts, + experts_shared_outer_loras=self.experts_shared_outer_loras, tp_size=self.tp_size, tp_rank=self.tp_rank, hidden_size=getattr(self.base_layer, "hidden_size", 0), @@ -839,34 +842,82 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): return B def slice_moe_lora_a_weights( - self, A: torch.Tensor, tp_rank: int, target_module: str - ) -> torch.Tensor: + self, + A: Union[torch.Tensor, Dict[int, torch.Tensor]], + tp_rank: int, + target_module: str, + ): """Slice LoRA A weights for MoE with TP. 
+ Accepts: + - 2D tensor [rank, hidden] (single expert) + - 3D tensor [num_experts_or_1, rank, hidden] + - dict {expert_id: 2D tensor} + Per-expert weight shapes: gate_up_proj_moe A: [rank, hidden_size] — input is full hidden_states, no slice down_proj_moe A: [rank, intermediate_size] — input is sharded intermediate """ if self.tp_size <= 1: return A - if target_module == "down_proj_moe": - shard_size = self.intermediate_size_per_partition - start = tp_rank * shard_size - end = start + shard_size - return A[:, start:end].contiguous() - return A + if target_module != "down_proj_moe": + return A + if isinstance(A, dict): + return { + eid: self._slice_moe_a(w, tp_rank, target_module) + for eid, w in A.items() + } + return self._slice_moe_a(A, tp_rank, target_module) + + def _slice_moe_a( + self, A: torch.Tensor, tp_rank: int, target_module: str + ) -> torch.Tensor: + shard_size = self.intermediate_size_per_partition + start = tp_rank * shard_size + end = start + shard_size + return A[..., start:end].contiguous() def slice_moe_lora_b_weights( - self, B: torch.Tensor, tp_rank: int, target_module: str - ) -> torch.Tensor: + self, + B: Union[torch.Tensor, Dict[int, torch.Tensor]], + tp_rank: int, + target_module: str, + ): """Slice LoRA B weights for MoE with TP. 
+ Accepts: + - 2D tensor [output_dim, rank] (single expert) + - 3D tensor [num_experts_or_1, output_dim, rank] + - dict {expert_id: 2D tensor} + Per-expert weight shapes: gate_up_proj_moe B: [intermediate_size*2, rank] — output matches sharded base w13 down_proj_moe B: [hidden_size, rank] — output is all-reduced, no slice """ - if self.tp_size <= 1: + needs_processing = (self.tp_size > 1) or ( + target_module == "gate_up_proj_moe" and self._uses_interleaved_gate_up + ) + if not needs_processing: + return B + if target_module != "gate_up_proj_moe": return B + if isinstance(B, dict): + return { + eid: self._slice_moe_b_2d(w, tp_rank, target_module) + for eid, w in B.items() + } + if isinstance(B, torch.Tensor) and B.dim() == 3: + return torch.stack( + [ + self._slice_moe_b_2d(B[i], tp_rank, target_module) + for i in range(B.shape[0]) + ] + ) + return self._slice_moe_b_2d(B, tp_rank, target_module) + + def _slice_moe_b_2d( + self, B: torch.Tensor, tp_rank: int, target_module: str + ) -> torch.Tensor: if target_module == "gate_up_proj_moe": shard_size = self.intermediate_size_per_partition start = tp_rank * shard_size @@ -874,6 +925,8 @@ def slice_moe_lora_b_weights( full_inter = B.shape[0] // 2 gate_b = B[start:end, :] up_b = B[full_inter + start : full_inter + end, :] + if self._uses_interleaved_gate_up: + return torch.stack([gate_b, up_b], dim=1).reshape(-1, B.shape[-1]) return torch.cat([gate_b, up_b], dim=0).contiguous() return B diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 8ccb674f9195..fbae373cbfce 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -128,8 +128,8 @@ def _process_weight(self, name: str, loaded_weight: torch.Tensor): # added/extra token emb self.added_tokens_embeddings[name] = loaded_weight.cpu() assert loaded_weight.shape[0] == self.config.lora_added_tokens_size, ( - f"LoRA adapter {self.uid} has extra_vocab_size {self.config.extra_vocab_size} specified in the config, " - f"but 
the loaded weight has {loaded_weight.shape[0]} extra vocab size" + f"LoRA adapter {self.uid} has lora_added_tokens_size {self.config.lora_added_tokens_size} specified in the config, " + f"but the loaded weight '{name}' has shape {loaded_weight.shape[0]} in first dimension" ) def _normalize_weights(self): @@ -137,6 +137,8 @@ def _normalize_weights(self): for layer in self.layers: weight_names = list(layer.weights.keys()) self.normalize_qkv_proj(weight_names, layer.weights) + self._rename_expert_w_to_proj(layer.weights) + weight_names = list(layer.weights.keys()) self.normalize_gate_up_proj(weight_names, layer.weights) def normalize_qkv_proj( @@ -192,6 +194,23 @@ def normalize_qkv_proj( weights[qkv_name] = weights[qkv_name].repeat(3, 1) # else: no-op as LoRA B weight is already stacked. + def _rename_expert_w_to_proj(self, weights: Dict[str, torch.Tensor]): + """Rename w1 -> gate_proj, w3 -> up_proj, w2 -> down_proj so that + normalize_gate_up_proj can stack them into gate_up_proj.""" + renames = {} + for name in list(weights.keys()): + new_name = name + if ".w1." in name: + new_name = name.replace(".w1.", ".gate_proj.") + elif ".w3." in name: + new_name = name.replace(".w3.", ".up_proj.") + elif ".w2." in name: + new_name = name.replace(".w2.", ".down_proj.") + if new_name != name: + renames[name] = new_name + for old_name, new_name in renames.items(): + weights[new_name] = weights.pop(old_name) + def normalize_gate_up_proj( self, weight_names: List[str], weights: Dict[str, torch.Tensor] ): @@ -206,8 +225,9 @@ def normalize_gate_up_proj( f"Received backend: {self.lora_backend.name}. Please verify your backend configuration " f"or consider implementing custom initialization logic for other backends." 
) + cat_dim = weights[weight_name].dim() - 2 weights[gate_up_name] = torch.cat( - (weights[weight_name], weights[up_name]), 0 + (weights[weight_name], weights[up_name]), cat_dim ) weights.pop(weight_name) if up_name in weights: @@ -216,7 +236,10 @@ def normalize_gate_up_proj( # If gate_up_proj is already stacked, we normalize it following the SGL convention gate_up_name = weight_name if "lora_A" in weight_name: - weights[gate_up_name] = weights[gate_up_name].repeat(2, 1) + ndim = weights[gate_up_name].dim() + repeat_dims = [1] * ndim + repeat_dims[ndim - 2] = 2 + weights[gate_up_name] = weights[gate_up_name].repeat(*repeat_dims) # else: no-op as LoRA B weight is already stacked. def pin_weights_in_cpu(self): diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py index 939a9331111b..917feef155cb 100644 --- a/python/sglang/srt/lora/lora_config.py +++ b/python/sglang/srt/lora/lora_config.py @@ -13,11 +13,14 @@ # ============================================================================== import json +import logging import os from typing import Dict, Optional from huggingface_hub import snapshot_download +logger = logging.getLogger(__name__) + class LoRAConfig: def __init__( @@ -25,6 +28,7 @@ def __init__( path: Optional[str] = None, config_dict: Optional[Dict] = None, added_tokens_config: Optional[Dict] = None, + base_vocab_size: Optional[int] = None, ) -> None: self.path = path @@ -38,17 +42,41 @@ def __init__( self.target_modules = self.hf_config["target_modules"] self.r = self.hf_config["r"] self.lora_alpha = self.hf_config["lora_alpha"] + + # Filter fake added tokens: tokens with ID < base_vocab_size are already + # part of the base vocabulary and should not be treated as added tokens. + # This commonly happens when added_tokens.json is copied from the base + # model's tokenizer. 
+ if self.added_tokens_config and base_vocab_size is not None: + self.added_tokens_config = { + token: token_id + for token, token_id in self.added_tokens_config.items() + if token_id >= base_vocab_size + } + self.lora_added_tokens_size = ( len(self.added_tokens_config) if self.added_tokens_config is not None else 0 ) + if self.lora_added_tokens_size > 0: + raise ValueError( + f"LoRA adapter has {self.lora_added_tokens_size} added tokens, " + f"but added tokens are not supported yet. " + f"Added tokens: {self.added_tokens_config}" + ) + @classmethod def from_dict( cls, config_dict: Dict, added_tokens_config: Optional[Dict] = None, + base_vocab_size: Optional[int] = None, ) -> "LoRAConfig": - return cls(config_dict=config_dict, added_tokens_config=added_tokens_config) + return cls( + config_dict=config_dict, + added_tokens_config=added_tokens_config, + base_vocab_size=base_vocab_size, + ) def get_lora_config(self, dummy=False): if dummy: @@ -82,9 +110,5 @@ def get_added_tokens_config(self): with open(added_tokens_path, "r") as f: return json.load(f) except json.JSONDecodeError as e: - # Log warning but don't crash if JSON is malformed - import logging - - logger = logging.getLogger(__name__) logger.warning(f"Failed to parse added_tokens.json: {e}") return None diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 73f6bc23544e..c704dad27c3c 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -36,6 +36,7 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool from sglang.srt.lora.utils import ( LoRAType, + auto_detect_lora_target_modules, get_normalized_target_modules, get_target_module_name, ) @@ -77,8 +78,10 @@ def __init__( server_args.enable_lora_overlap_loading ) - # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy + self._experts_shared_outer_override: Optional[bool] = ( + server_args.experts_shared_outer_loras + ) # LoRA 
backend for running sgemm kernels logger.info(f"Using {lora_backend} as backend of LoRA kernels.") @@ -133,7 +136,10 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: try: # load configs - new_adapter = LoRAConfig(lora_ref.lora_path) + new_adapter = LoRAConfig( + lora_ref.lora_path, + base_vocab_size=self.base_hf_config.vocab_size, + ) self.validate_new_adapter(new_adapter, lora_ref) self.configs[lora_ref.lora_id] = new_adapter @@ -302,23 +308,33 @@ def update_lora_info(self): if isinstance(module, FusedMoEWithLoRA) and all( x in self.target_modules for x in ["gate_up_proj", "down_proj"] ): + gate_up_key = ( + "gate_up_proj_moe" + if "gate_up_proj_moe" in self.memory_pool.A_buffer + else "gate_up_proj" + ) + down_key = ( + "down_proj_moe" + if "down_proj_moe" in self.memory_pool.A_buffer + else "down_proj" + ) gate_up_a = self.memory_pool.get_tensor( - target_module="gate_up_proj_moe", + target_module=gate_up_key, layer_id=layer_id, lora_type=LoRAType.LORA_A, ) gate_up_b = self.memory_pool.get_tensor( - target_module="gate_up_proj_moe", + target_module=gate_up_key, layer_id=layer_id, lora_type=LoRAType.LORA_B, ) down_a = self.memory_pool.get_tensor( - target_module="down_proj_moe", + target_module=down_key, layer_id=layer_id, lora_type=LoRAType.LORA_A, ) down_b = self.memory_pool.get_tensor( - target_module="down_proj_moe", + target_module=down_key, layer_id=layer_id, lora_type=LoRAType.LORA_B, ) @@ -386,6 +402,16 @@ def init_state( target_modules=target_modules, ) + if self._experts_shared_outer_override is not None: + self.experts_shared_outer_loras = self._experts_shared_outer_override + else: + self.experts_shared_outer_loras = self._detect_shared_outer_loras() + if self.experts_shared_outer_loras: + logger.info( + "Shared outer LoRA mode enabled: gate_up lora_A and " + "down lora_B will be shared across experts (expert_dim=1)." 
+ ) + self.init_lora_modules() self.init_memory_pool() self.update_lora_info() @@ -411,6 +437,26 @@ def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None): f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}" ) + def _detect_shared_outer_loras(self) -> bool: + """Auto-detect shared outer LoRA format from loaded adapter weights. + + MoE adapters with shared outer experts store 3D tensors where + dim[0]=1 indicates weights shared across all experts, while + dim[0]=num_experts indicates per-expert weights. + Returns True if gate_up lora_A has expert_dim=1 (shared). + """ + for adapter in self.loras.values(): + for layer in adapter.layers: + for name, weight in layer.weights.items(): + if ( + "gate_up_proj" in name + and "lora_A" in name + and weight.dim() == 3 + ): + return weight.shape[0] == 1 + break + return False + def init_lora_shapes( self, max_lora_rank: Optional[int] = None, @@ -424,9 +470,6 @@ def init_lora_shapes( for lora_id, config in self.configs.items(): # Handle PEFT shorthand strings like "all-linear" or "all". - # These cannot be resolved to concrete module names without - # inspecting the base model, so we require the user to specify - # --lora-target-modules explicitly when such shorthands are used. if isinstance(config.target_modules, str): if config.target_modules in ("all-linear", "all"): if target_modules is not None: @@ -434,14 +477,20 @@ def init_lora_shapes( # per-adapter inference for this adapter. continue else: - lora_name = self.lora_refs[lora_id].lora_name - raise ValueError( - f"LoRA adapter '{lora_name}' uses " - f"target_modules='{config.target_modules}' which cannot " - "be resolved automatically. Please explicitly specify " - "--lora-target-modules during server startup. You can " - "specify 'all' to enable all supported module types." + # Resolve by scanning the base model for all + # LoRA-compatible linear modules. 
+ adapter_target_modules = auto_detect_lora_target_modules( + self.base_model + ) + logger.info( + "LoRA adapter '%s' uses target_modules='%s'. " + "Resolved to %s by inspecting the base model.", + self.lora_refs[lora_id].lora_name, + config.target_modules, + sorted(adapter_target_modules), ) + self.target_modules.update(adapter_target_modules) + continue else: raise ValueError( f"SGLang does not recognize target_modules=" @@ -556,7 +605,11 @@ def load_lora_adapter_from_tensors( ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend." try: - new_adapter = LoRAConfig.from_dict(config_dict, added_tokens_config) + new_adapter = LoRAConfig.from_dict( + config_dict, + added_tokens_config, + base_vocab_size=self.base_hf_config.vocab_size, + ) self.validate_new_adapter(new_adapter, lora_ref) self.configs[lora_ref.lora_id] = new_adapter @@ -585,6 +638,7 @@ def init_memory_pool(self): base_model=self.base_model, eviction_policy=self.eviction_policy, lora_added_tokens_size=self.lora_added_tokens_size, + experts_shared_outer_loras=self.experts_shared_outer_loras, ) # Initializing memory pool with base model @@ -672,16 +726,17 @@ def init_lora_modules(self): # The module should be converted if it is included in target_names if module_name.split(".")[-1] in self.target_modules: layer_id = get_layer_id(module_name) + if layer_id is None: + continue self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module ) continue - # Temporarily workaround for FusedMoE layer if isinstance(module, FusedMoE) and all( x in self.target_modules for x in ["gate_up_proj", "down_proj"] ): layer_id = get_layer_id(module_name) - self.lora_modules[layer_id][module_name] = self.set_lora_module( - module_name, module - ) + lora_module = self.set_lora_module(module_name, module) + lora_module.experts_shared_outer_loras = self.experts_shared_outer_loras + self.lora_modules[layer_id][module_name] = 
lora_module diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 76ac964f69ac..3060e2fadd19 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -71,17 +71,22 @@ class LoRAInfo: """LoRA weights and dispatch info for MoE computation.""" - # LoRA weights: [num_loras, num_experts, dim1, dim2] + # LoRA weights: [num_loras, num_experts_or_1, dim1, dim2] + # When experts_shared_outer_loras=True: + # gate_up_lora_a: [num_loras, 1, max_rank, hidden_dim] (shared) + # down_lora_b: [num_loras, 1, hidden_dim, max_rank] (shared) gate_up_lora_a_weights: ( torch.Tensor - ) # [num_loras, num_experts, max_rank, hidden_dim] + ) # [num_loras, num_experts_or_1, max_rank, hidden_dim] gate_up_lora_b_weights: ( torch.Tensor ) # [num_loras, num_experts, gate_up_dim, max_rank] down_lora_a_weights: ( torch.Tensor ) # [num_loras, num_experts, max_rank, intermediate_dim] - down_lora_b_weights: torch.Tensor # [num_loras, num_experts, hidden_dim, max_rank] + down_lora_b_weights: ( + torch.Tensor + ) # [num_loras, num_experts_or_1, hidden_dim, max_rank] # Indice pointers of each segment in shape (num_segments + 1, ) seg_indptr: torch.Tensor @@ -95,6 +100,7 @@ class LoRAInfo: max_lora_rank: int # Maximum LoRA rank across all adapters num_experts: int + experts_shared_outer_loras: bool = False fully_sharded: bool = False tp_size: int = 1 @@ -469,16 +475,11 @@ def _add_lora_gate_up_delta( r = lora_info.max_lora_rank gate_up_a = lora_info.gate_up_lora_a_weights + if lora_info.experts_shared_outer_loras: + gate_up_a = gate_up_a.expand(-1, lora_info.num_experts, -1, -1) gate_up_b = lora_info.gate_up_lora_b_weights inter_size = gate_up_b.shape[2] // 2 - # Split packed gate_up weights into separate gate and up slices. - # gate_up_lora_a has shape [max_loras, num_experts, 2*r, hidden_dim] - # where the first r rows are gate_lora_a and the next r are up_lora_a. 
- # gate_up_lora_b has shape [max_loras, num_experts, 2*inter_size, r] - # where the first inter_size rows are gate_lora_b and the rest up_lora_b. - # Using num_slices=2 lets the kernel handle gate and up independently, - # keeping the rank dimension at r so shrink and expand both match. lora_a_stacked = [gate_up_a[:, :, :r, :], gate_up_a[:, :, r : 2 * r, :]] lora_b_stacked = [ gate_up_b[:, :, :inter_size, :], @@ -542,8 +543,12 @@ def _add_lora_down_delta( if lora_info.max_lora_rank == 0: return + down_lora_b = lora_info.down_lora_b_weights + if lora_info.experts_shared_outer_loras: + down_lora_b = down_lora_b.expand(-1, lora_info.num_experts, -1, -1) + lora_a_stacked = [lora_info.down_lora_a_weights] - lora_b_stacked = [lora_info.down_lora_b_weights] + lora_b_stacked = [down_lora_b] if lora_info.fully_sharded and lora_info.tp_size > 1: shard_size = lora_info.hidden_size // lora_info.tp_size diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index ca3310a9d289..3746692529f2 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -60,6 +60,7 @@ def __init__( base_model: torch.nn.Module, eviction_policy: str, lora_added_tokens_size: int, + experts_shared_outer_loras: bool = False, ): self.base_hf_config: AutoConfig = base_hf_config self.num_layer: int = base_hf_config.num_hidden_layers @@ -70,6 +71,7 @@ def __init__( self.lora_added_tokens_size: int = lora_added_tokens_size self.max_lora_rank: int = max_lora_rank self.target_modules: Set[str] = target_modules + self.experts_shared_outer_loras: bool = experts_shared_outer_loras # Initialize eviction policy self.eviction_policy = get_eviction_policy(eviction_policy) @@ -140,6 +142,18 @@ def is_moe_module(self, module_name: str) -> bool: """Check if module is part of MoE experts.""" return "moe" in module_name + @staticmethod + def _get_num_experts(base_model: torch.nn.Module) -> int: + cfg = base_model.config + if hasattr(cfg, "get_text_config"): + 
cfg = cfg.get_text_config() + return ( + getattr(cfg, "num_experts", None) + or getattr(cfg, "num_local_experts", None) + or getattr(cfg, "n_routed_experts", None) + or 1 + ) + def _get_standard_shape( self, module_name: str, @@ -178,10 +192,13 @@ def get_lora_A_shape( input_dim = divide(input_dim, self.tp_size) if self.is_moe_module(module_name): - num_experts = base_model.config.num_experts + num_experts = self._get_num_experts(base_model) + expert_dim = num_experts + if self.experts_shared_outer_loras and module_name == "gate_up_proj_moe": + expert_dim = 1 return ( self.max_loras_per_batch, - num_experts, + expert_dim, max_lora_dim * c, input_dim, ) @@ -228,8 +245,11 @@ def get_lora_B_shape( # Check if MoE module and return appropriate shape if self.is_moe_module(module_name): - num_experts = base_model.config.num_experts - return (self.max_loras_per_batch, num_experts, output_dim, max_lora_dim) + num_experts = self._get_num_experts(base_model) + expert_dim = num_experts + if self.experts_shared_outer_loras and module_name == "down_proj_moe": + expert_dim = 1 + return (self.max_loras_per_batch, expert_dim, output_dim, max_lora_dim) else: return (self.max_loras_per_batch, output_dim, max_lora_dim) @@ -264,37 +284,38 @@ def init_buffer( target_modules: Set[str], get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]], ): - # Check if model has both shared experts and MoE experts + cfg = base_model.config + if hasattr(cfg, "get_text_config"): + cfg = cfg.get_text_config() has_shared_experts = ( - hasattr(base_model.config, "shared_expert_intermediate_size") - and base_model.config.shared_expert_intermediate_size > 0 - ) - has_moe = getattr(base_model.config, "num_experts", 1) > 1 + hasattr(cfg, "shared_expert_intermediate_size") + and cfg.shared_expert_intermediate_size > 0 + ) or (getattr(cfg, "n_shared_experts", 0) or 0) > 0 + has_moe = self._get_num_experts(base_model) > 1 # Shape functions automatically handle both 3D (standard) and 4D (MoE) 
target_modules = target_modules - set(EMBEDDING_NAMES) for module_name in target_modules: # Special handling for ambiguous target modules that can be in different contexts ambiguous_modules = {"gate_up_proj", "down_proj"} - if module_name in ambiguous_modules and has_shared_experts and has_moe: - # Allocate separate buffers for shared and MoE contexts - # Shared expert version (3D) - shared_key = module_name - buffer[shared_key] = [ - torch.empty( - get_lora_shape_fn( - module_name, base_model, self.max_lora_rank, idx - ), - dtype=self.dtype, - device=device, - ) - for idx in range(self.num_layer) - ] + if module_name in ambiguous_modules and has_moe: + # Allocate shared expert version (3D) only when model has shared experts + if has_shared_experts: + buffer[module_name] = [ + torch.zeros( + get_lora_shape_fn( + module_name, base_model, self.max_lora_rank, idx + ), + dtype=self.dtype, + device=device, + ) + for idx in range(self.num_layer) + ] # MoE expert version (4D) moe_key = f"{module_name}_moe" buffer[moe_key] = [ - torch.empty( + torch.zeros( get_lora_shape_fn( moe_key, base_model, self.max_lora_rank, idx ), @@ -306,7 +327,7 @@ def init_buffer( else: # Standard allocation for unambiguous modules buffer[module_name] = [ - torch.empty( + torch.zeros( get_lora_shape_fn( module_name, base_model, @@ -326,7 +347,7 @@ def init_embedding_buffer( ): target_modules = target_modules & set(EMBEDDING_NAMES) for module_name in target_modules: - buffer[module_name] = torch.empty( + buffer[module_name] = torch.zeros( get_lora_shape_fn( module_name, base_model, @@ -338,7 +359,7 @@ def init_embedding_buffer( ) if self.lora_added_tokens_size > 0: - self.new_embeddings_buffer["input_embeddings"] = torch.empty( + self.new_embeddings_buffer["input_embeddings"] = torch.zeros( ( self.max_loras_per_batch, self.lora_added_tokens_size, @@ -521,8 +542,8 @@ def load_lora_weight_tensor( expert_match = re.search(r"experts\.(\d+)\.", name) if expert_match: + # Per-expert MoE weight — 2D 
tensors, one per expert target_module = target_module + "_moe" - # MoE weight - multiple tensors per module (one per expert) if temp_A_buffer[target_module] is None: temp_A_buffer[target_module] = {} temp_B_buffer[target_module] = {} @@ -532,8 +553,15 @@ def load_lora_weight_tensor( temp_A_buffer[target_module][expert_id] = weights else: temp_B_buffer[target_module][expert_id] = weights + elif "experts" in name and weights.dim() == 3: + # Shared outer MoE weight — 3D tensor [expert_dim, rank, hidden] + target_module = target_module + "_moe" + if "lora_A" in name: + temp_A_buffer[target_module] = weights + else: + temp_B_buffer[target_module] = weights else: - # Standard weight - single tensor per module + # Standard weight — single tensor per module if "lora_A" in name: temp_A_buffer[target_module] = weights else: @@ -549,20 +577,18 @@ def load_lora_weight_tensor( if isinstance(module, FusedMoEWithLoRA): moe_target_modules = ["gate_up_proj_moe", "down_proj_moe"] for target_module in moe_target_modules: - if temp_A_buffer[target_module] is None: - continue - - for expert_id in temp_A_buffer[target_module].keys(): - temp_A_buffer[target_module][expert_id] = ( + if temp_A_buffer.get(target_module) is not None: + temp_A_buffer[target_module] = ( module.slice_moe_lora_a_weights( - temp_A_buffer[target_module][expert_id], + temp_A_buffer[target_module], self.tp_rank, target_module, ) ) - temp_B_buffer[target_module][expert_id] = ( + if temp_B_buffer.get(target_module) is not None: + temp_B_buffer[target_module] = ( module.slice_moe_lora_b_weights( - temp_B_buffer[target_module][expert_id], + temp_B_buffer[target_module], self.tp_rank, target_module, ) @@ -587,22 +613,42 @@ def load_lora_weight_tensor( temp_B_buffer[target_module], self.tp_rank ) - # Load weights into buffers (handles both 3D standard and 4D MoE) for name, weights in temp_A_buffer.items(): c = get_stacked_multiply(name) target_buffer = self.A_buffer[name][layer_id] if name in ["gate_up_proj_moe", 
"down_proj_moe"]: - # MoE: multiple tensors per module (one per expert) - for expert_id, expert_weight in weights.items(): - # Buffer shape: [num_loras, num_experts, max_rank, hidden_dim] - buffer_view = target_buffer[ - buffer_id, expert_id, : lora_rank * c, : - ] - load_lora_weight_tensor(buffer_view, expert_weight) + if self.experts_shared_outer_loras and name == "gate_up_proj_moe": + if isinstance(weights, torch.Tensor) and weights.dim() == 3: + buffer_view = target_buffer[ + buffer_id, 0, : lora_rank * c, : + ] + load_lora_weight_tensor(buffer_view, weights[0]) + elif isinstance(weights, dict) and len(weights) > 0: + rep = next(iter(weights.values())) + buffer_view = target_buffer[ + buffer_id, 0, : lora_rank * c, : + ] + load_lora_weight_tensor(buffer_view, rep) + else: + raise ValueError( + f"Unexpected weight format for shared outer gate_up_proj_moe lora_A: " + f"type={type(weights)}, " + f"shape={weights.shape if isinstance(weights, torch.Tensor) else 'N/A'}" + ) + elif isinstance(weights, torch.Tensor) and weights.dim() == 3: + for eid in range(weights.shape[0]): + buffer_view = target_buffer[ + buffer_id, eid, : lora_rank * c, : + ] + load_lora_weight_tensor(buffer_view, weights[eid]) + elif isinstance(weights, dict): + for expert_id, expert_weight in weights.items(): + buffer_view = target_buffer[ + buffer_id, expert_id, : lora_rank * c, : + ] + load_lora_weight_tensor(buffer_view, expert_weight) else: - # Standard: single tensor per module - c = get_stacked_multiply(name) buffer_view = target_buffer[buffer_id, : lora_rank * c, :] load_lora_weight_tensor(buffer_view, weights) @@ -610,18 +656,42 @@ def load_lora_weight_tensor( target_buffer = self.B_buffer[name][layer_id] if name in ["gate_up_proj_moe", "down_proj_moe"]: - # MoE: multiple tensors per module (one per expert) - for expert_id, expert_weight in weights.items(): - # Buffer shape: [num_loras, num_experts, intermediate_dim, max_rank] - buffer_view = target_buffer[buffer_id, expert_id, :, 
:lora_rank] - - weight_to_load = expert_weight - if weight_to_load is not None: - weight_to_load = weight_to_load * lora_adapter.scaling - - load_lora_weight_tensor(buffer_view, weight_to_load) + if self.experts_shared_outer_loras and name == "down_proj_moe": + if isinstance(weights, torch.Tensor) and weights.dim() == 3: + buffer_view = target_buffer[buffer_id, 0, :, :lora_rank] + w = weights[0] + if w is not None: + w = w * lora_adapter.scaling + load_lora_weight_tensor(buffer_view, w) + elif isinstance(weights, dict) and len(weights) > 0: + rep = next(iter(weights.values())) + buffer_view = target_buffer[buffer_id, 0, :, :lora_rank] + if rep is not None: + rep = rep * lora_adapter.scaling + load_lora_weight_tensor(buffer_view, rep) + else: + raise ValueError( + f"Unexpected weight format for shared outer down_proj_moe lora_B: " + f"type={type(weights)}, " + f"shape={weights.shape if isinstance(weights, torch.Tensor) else 'N/A'}" + ) + elif isinstance(weights, torch.Tensor) and weights.dim() == 3: + for eid in range(weights.shape[0]): + buffer_view = target_buffer[buffer_id, eid, :, :lora_rank] + w = weights[eid] + if w is not None: + w = w * lora_adapter.scaling + load_lora_weight_tensor(buffer_view, w) + elif isinstance(weights, dict): + for expert_id, expert_weight in weights.items(): + buffer_view = target_buffer[ + buffer_id, expert_id, :, :lora_rank + ] + w = expert_weight + if w is not None: + w = w * lora_adapter.scaling + load_lora_weight_tensor(buffer_view, w) else: - # Standard: single tensor per module buffer_view = target_buffer[buffer_id, :, :lora_rank] load_lora_weight_tensor(buffer_view, weights) diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py index 357d3280548c..b796cdd0efa4 100644 --- a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py @@ -87,6 +87,7 @@ def _sgemm_lora_b_kernel( ) # Iterate to compute the block in output 
matrix + n_mask = n_offset[None, :] < N partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K)): x_tile = tl.load( @@ -96,7 +97,7 @@ def _sgemm_lora_b_kernel( ) w_tile = tl.load( w_ptrs, - mask=(k_offset[:, None] < K - k * BLOCK_K), + mask=(k_offset[:, None] < K - k * BLOCK_K) & n_mask, other=0.0, ) partial_sum += tl.dot(x_tile, w_tile) @@ -110,8 +111,8 @@ def _sgemm_lora_b_kernel( output_ptr = (output + seg_start * output_stride_0) + ( s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 ) - output_mask = s_offset[:, None] < seg_len - partial_sum += tl.load(output_ptr, mask=output_mask) + output_mask = (s_offset[:, None] < seg_len) & n_mask + partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0) tl.store(output_ptr, partial_sum, mask=output_mask) diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 45987d736d3c..a5d56c479502 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -88,9 +88,17 @@ def get_hidden_dim( elif module_name == "down_proj": return config.intermediate_size, config.hidden_size elif module_name == "gate_up_proj_moe": - return config.hidden_size, config.moe_intermediate_size * 2 + moe_inter = ( + getattr(config, "moe_intermediate_size", None) + or config.intermediate_size + ) + return config.hidden_size, moe_inter * 2 elif module_name == "down_proj_moe": - return config.moe_intermediate_size, config.hidden_size + moe_inter = ( + getattr(config, "moe_intermediate_size", None) + or config.intermediate_size + ) + return moe_inter, config.hidden_size elif module_name == "embed_tokens": # For embedding: input is vocab_size (as embedding lookup), output is hidden_size # if contain extra tokens will be added; otherwise is 0. @@ -113,13 +121,17 @@ def get_normalized_target_modules( Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj"). 
Also handles PEFT shorthand strings like "all-linear" or "all" by returning - {"all"} as a sentinel value (the caller should check for "all" and fall - back to the CLI --lora-target-modules to determine the concrete module set). + {"all"} as a sentinel value. Callers that need a concrete module set + should use :func:`auto_detect_lora_target_modules` to resolve the shorthand + against the loaded base model. """ - # Handle PEFT shorthand strings — these cannot be resolved to concrete - # module names without inspecting the base model, so we return {"all"} - # and let the caller fall back to the CLI --lora-target-modules. + # Handle PEFT shorthand strings — return {"all"} as sentinel. + # Callers can resolve to concrete names via auto_detect_lora_target_modules(). if isinstance(target_modules, str): + if target_modules not in ["all", "all-linear"]: + raise ValueError( + "Only 'all' or 'all-linear' can be used as the string for target module" + ) return {"all"} params_mapping = { @@ -175,6 +187,45 @@ def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> s EMBEDDING_NAMES = ["embed_tokens", "lm_head"] ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj", "down_proj_moe"] +# Normalized module names that the LoRA system fully supports +# (i.e. get_hidden_dim, init_buffers, and init_lora_modules can handle them). +_KNOWN_LORA_TARGET_MODULES = frozenset( + { + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + } +) + + +def auto_detect_lora_target_modules(model: "torch.nn.Module") -> set: + """Discover LoRA-compatible modules by inspecting the base model. + + Walks the model graph and returns the set of *normalized* target-module + names that (a) actually exist in the model and (b) the LoRA memory pool + can handle. This is used to resolve PEFT shorthands like ``"all-linear"`` + without requiring the user to enumerate modules on the CLI. 
+ """ + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + + raw_names: set = set() + for name, module in model.named_modules(): + if isinstance(module, FusedMoE): + raw_names.add("gate_up_proj") + raw_names.add("down_proj") + elif isinstance(module, ParallelLMHead): + raw_names.add("lm_head") + elif isinstance(module, LinearBase): + raw_names.add(name.split(".")[-1]) + + normalized = get_normalized_target_modules(raw_names) + return normalized & _KNOWN_LORA_TARGET_MODULES + def get_lm_head_lora_b_shard_size(output_dim: int, shard_indices=None) -> int: """Get the LoRA B output dimension for lm_head, accounting for TP. diff --git a/python/sglang/srt/managers/async_mm_data_processor.py b/python/sglang/srt/managers/async_mm_data_processor.py deleted file mode 100644 index 85e8580cb769..000000000000 --- a/python/sglang/srt/managers/async_mm_data_processor.py +++ /dev/null @@ -1,122 +0,0 @@ -import asyncio -import logging -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from typing import Any, Dict, List, Optional, Union - -logger = logging.getLogger(__name__) - - -class AsyncMMDataProcessor: - """ - Async wrapper for a multimodal processor. - - Behavior: - - If the underlying processor exposes `process_mm_data_async`, call/await it directly. - - Otherwise, fall back to running a synchronous `process_mm_data` in a thread pool. - - Optionally guard per-call concurrency via an asyncio.Semaphore. - - Optionally enforce per-call timeout via asyncio.wait_for. 
- """ - - def __init__( - self, - mm_processor: Any, - *, - max_concurrent_calls: Optional[int] = None, - timeout_s: Optional[float] = None, - ) -> None: - """ - Args: - mm_processor: An object exposing either - - async def process_mm_data_async(...): -> Dict[str, Any] - or - - def process_mm_data(...): -> Dict[str, Any] - max_concurrent_calls: Optional concurrency cap for per-call execution. - timeout_s: Optional timeout (seconds) for each `process()` call. - """ - self.mm_processor = mm_processor - self.timeout_s = timeout_s - - # Concurrency guard (None -> unlimited) - self.semaphore = ( - asyncio.Semaphore(max_concurrent_calls) if max_concurrent_calls else None - ) - - # Detect async path; if missing, prepare a fallback executor for sync path - self._proc_async = getattr(mm_processor, "process_mm_data_async", None) - self.is_async = asyncio.iscoroutinefunction(self._proc_async) - self.fallback_exec: Optional[ThreadPoolExecutor] = ( - ThreadPoolExecutor(max_workers=max_concurrent_calls) - if not self.is_async - else None - ) - - async def process( - self, - *, - image_data: Optional[List[Union[str, bytes]]] = None, - audio_data: Optional[List[Union[str, bytes]]] = None, - input_text_or_ids: Union[str, List[int], None] = None, - request_obj: Any, - **kwargs: Any, - ) -> Dict[str, Any]: - """ - Public entrypoint: process a single multimodal request without blocking the event loop. - """ - - async def _invoke() -> Dict[str, Any]: - if self.is_async: - # Native async implementation - return await self._proc_async( - image_data=image_data, - audio_data=audio_data, - input_text=input_text_or_ids, - request_obj=request_obj, - **kwargs, - ) - - # Synchronous fallback - sync_fn = getattr(self.mm_processor, "process_mm_data", None) - if not callable(sync_fn): - raise RuntimeError( - "mm_processor has neither 'process_mm_data_async' nor 'process_mm_data'." 
- ) - loop = asyncio.get_running_loop() - fn = partial( - sync_fn, - image_data=image_data, - audio_data=audio_data, - input_text=input_text_or_ids, - request_obj=request_obj, - **kwargs, - ) - return await loop.run_in_executor(self.fallback_exec, fn) - - # Apply optional concurrency guard - if self.semaphore is not None: - async with self.semaphore: - if self.timeout_s is not None: - return await asyncio.wait_for(_invoke(), timeout=self.timeout_s) - return await _invoke() - - # No concurrency guard - if self.timeout_s is not None: - return await asyncio.wait_for(_invoke(), timeout=self.timeout_s) - return await _invoke() - - def shutdown(self) -> None: - """Gracefully shutdown resources owned by this wrapper.""" - try: - if self.fallback_exec: - self.fallback_exec.shutdown(wait=False) - except Exception: - logger.exception( - "Error while shutting down fallback executor in AsyncMMDataProcessor" - ) - - def __del__(self): - # Best-effort shutdown - try: - self.shutdown() - except Exception: - pass diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 17972005ffed..ce27113845c7 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -21,14 +21,13 @@ from typing import Dict, List, Optional, Tuple, Union import psutil -import pybase64 import setproctitle import zmq +from sglang.srt.constants import HEALTH_CHECK_RID_PREFIX from sglang.srt.environ import envs from sglang.srt.managers.io_struct import ( BatchEmbeddingOutput, - BatchMultimodalDecodeReq, BatchStrOutput, BatchTokenIDOutput, FreezeGCReq, @@ -88,16 +87,9 @@ def __init__( # Init running status self.init_running_status(server_args) - if server_args.enable_metrics: - start_cpu_monitor_thread("detokenizer") - # Init dispatcher self.init_request_dispatcher() - @staticmethod - def is_health_check_request(rid: Optional[str]) -> bool: - return isinstance(rid, str) and 
rid.startswith("HEALTH_CHECK") - def init_ipc_channels(self, port_args: PortArgs): context = zmq.Context(2) self.recv_from_scheduler = get_zmq_socket( @@ -120,9 +112,8 @@ def init_tokenizer(self, server_args: ServerArgs): def init_running_status(self, server_args: ServerArgs): self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) - self.is_dummy = False - self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss" self.disable_tokenizer_batch_decode = server_args.disable_tokenizer_batch_decode + self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss" self.soft_watchdog = Watchdog.create( debug_name="DetokenizerManager", @@ -131,12 +122,14 @@ def init_running_status(self, server_args: ServerArgs): test_stuck_time=envs.SGLANG_TEST_STUCK_DETOKENIZER.get(), ) + if server_args.enable_metrics: + start_cpu_monitor_thread("detokenizer") + def init_request_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( [ (BatchEmbeddingOutput, self.handle_batch_embedding_out), (BatchTokenIDOutput, self.handle_batch_token_id_out), - (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req), (FreezeGCReq, self.handle_freeze_gc_req), ] ) @@ -190,8 +183,6 @@ def _grouped_batch_decode( ) -> List[str]: """Batch decode with grouping by (skip_special_tokens, spaces_between_special_tokens).""" - assert self.tokenizer is not None - # fast path first_skip, first_space = skip_list[0], space_list[0] if all(s == first_skip for s in skip_list) and all( @@ -236,9 +227,7 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): surr_offset=0, read_offset=recv_obj.read_offsets[i], ) - if not self.is_health_check_request(rid): - # for health check requests, we do not store the decode status - self.decode_status[rid] = s + self.decode_status[rid] = s else: s = self.decode_status[rid] s.decode_ids.extend(recv_obj.decode_ids[i]) @@ -254,22 +243,16 @@ def _decode_batch_token_id_output(self, recv_obj: 
BatchTokenIDOutput): # Decode token ids to strings if not self.disable_tokenizer_batch_decode: - if not self.is_dummy: - # Run normal batch decode - surr_texts = self._grouped_batch_decode( - surr_ids, - recv_obj.skip_special_tokens, - recv_obj.spaces_between_special_tokens, - ) - read_texts = self._grouped_batch_decode( - read_ids, - recv_obj.skip_special_tokens, - recv_obj.spaces_between_special_tokens, - ) - else: - # If it is dummy weights, just return dummy strings to prevent potential detokenization edge cases - surr_texts = ["dog" for _ in surr_ids] - read_texts = ["cat" for _ in read_ids] + surr_texts = self._grouped_batch_decode( + surr_ids, + recv_obj.skip_special_tokens, + recv_obj.spaces_between_special_tokens, + ) + read_texts = self._grouped_batch_decode( + read_ids, + recv_obj.skip_special_tokens, + recv_obj.spaces_between_special_tokens, + ) else: # Do not use batch decode to prevent some detokenization edge cases (e.g., gpt-oss). surr_texts = [ @@ -297,25 +280,17 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): output_strs = [] for i in range(bs): rid = recv_obj.rids[i] - if self.is_health_check_request(rid): - s = DecodeStatus( - decoded_text=recv_obj.decoded_texts[i], - decode_ids=recv_obj.decode_ids[i], - surr_offset=0, - read_offset=recv_obj.read_offsets[i], + try: + s = self.decode_status[rid] + except KeyError: + raise RuntimeError( + f"Decode status not found for request {rid}. " + "It may be due to the request being evicted from the decode status due to memory pressure. " + "Please increase the maximum number of requests by setting " + "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " + f"The current value is {DETOKENIZER_MAX_STATES}. " + "For more details, see: https://github.com/sgl-project/sglang/issues/2812" ) - else: - try: - s = self.decode_status[rid] - except KeyError: - raise RuntimeError( - f"Decode status not found for request {rid}. 
" - "It may be due to the request being evicted from the decode status due to memory pressure. " - "Please increase the maximum number of requests by setting " - "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " - f"The current value is {DETOKENIZER_MAX_STATES}. " - "For more details, see: https://github.com/sgl-project/sglang/issues/2812" - ) new_text = read_texts[i][len(surr_texts[i]) :] if recv_obj.finished_reasons[i] is None: # Streaming chunk: update the decode status @@ -335,6 +310,7 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) + # Incrementally send text. incremental_output = output_str[s.sent_offset :] s.sent_offset = len(output_str) @@ -342,21 +318,6 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): return output_strs - def _extract_routed_experts( - self, recv_obj: BatchTokenIDOutput - ) -> list[str | None] | None: - routed_experts = None - if recv_obj.routed_experts is not None: - routed_experts = [ - ( - pybase64.b64encode(routed_experts.numpy().tobytes()).decode("utf-8") - if routed_experts is not None - else None - ) - for routed_experts in recv_obj.routed_experts - ] - return routed_experts - def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): # If handling idle batch, set output_strs to []. 
output_strs = ( @@ -364,8 +325,6 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): if len(recv_obj.rids) > 0 else [] ) - routed_experts = self._extract_routed_experts(recv_obj) - return BatchStrOutput( rids=recv_obj.rids, http_worker_ipcs=recv_obj.http_worker_ipcs, @@ -393,7 +352,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, output_token_entropy_val=recv_obj.output_token_entropy_val, output_hidden_states=recv_obj.output_hidden_states, - routed_experts=routed_experts, + routed_experts=recv_obj.routed_experts, customized_info=recv_obj.customized_info, placeholder_tokens_idx=None, placeholder_tokens_val=None, @@ -404,14 +363,15 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): time_stats=recv_obj.time_stats, ) - def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): - raise NotImplementedError() - def handle_freeze_gc_req(self, recv_req: FreezeGCReq): freeze_gc("Detokenizer Manager") return None +def is_health_check_request(rid: Optional[str]) -> bool: + return isinstance(rid, str) and rid.startswith(HEALTH_CHECK_RID_PREFIX) + + class LimitedCapacityDict(OrderedDict): def __init__(self, capacity: int, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py index 92ef22f404cf..89740f73682e 100644 --- a/python/sglang/srt/managers/hisparse_coordinator.py +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -56,7 +56,7 @@ def __init__( override_kv_cache_dim=self.mem_pool_device.kv_cache_dim, ) - max_num_reqs = req_to_token_pool.size + max_num_reqs = req_to_token_pool.req_to_token.shape[0] max_context_len = req_to_token_pool.max_context_len # to have an extra page for new tokens @@ -123,7 +123,7 @@ def set_decode_producer_stream(self, stream) -> None: self.decode_producer_stream = stream def 
admit_request_into_staging(self, req: Req) -> None: - req.staging = True + req.hisparse_staging = True logical_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(req.fill_ids) ] @@ -161,6 +161,53 @@ def admit_request_into_staging(self, req: Req) -> None: self.ack_staging_queue.append(HiSparseAct(start_event, finish_event, req)) + def admit_request_direct(self, req: Req) -> None: + """Direct-to-host path: KV data already resides in host pool via RDMA. + + Skips staging DMA entirely. Only allocates a small device buffer + (4KB) for decode-time swap-in, then marks the request as ready. + Host indices were already written to req_to_host_pool. + + Metadata fixups after alloc_device_buffer(): + - alloc_device_buffer() sets device_buffer_tokens = [0, 1, ..., buf_size-1], + which tells the swap-in kernel that those tokens are cached in the device + buffer. In the staging path this is correct (prefill filled the buffer), + but here the buffer is empty. + """ + self.alloc_device_buffer(req) + + if req.kv_allocated_len <= self.device_buffer_size: + # Short sequences (seq_len <= device_buffer_size): the kernel fast path + # returns device_buffer_locs directly without any host loading, so we + # must preload all tokens from host pool into the device buffer + # TODO(hzh0425): Optimize this. + self._preload_to_device_buffer(req) + else: + # Long sequence: reset device_buffer_tokens to -1 so the kernel + # sees all slots as empty → every top-k lookup is a miss → host load. 
+ self.req_device_buffer_tokens[ + :, req.req_pool_idx, : self.device_buffer_size + ] = -1 + + req.staging = False + self._skip_first_backup[req.req_pool_idx] = True + logger.debug("HiSparse: admitting request %s directly", req.rid) + + def _preload_to_device_buffer(self, req: Req) -> None: + """Preload all tokens from host pool into the device buffer.""" + n = req.kv_allocated_len + host_indices = self.req_to_host_pool[req.req_pool_idx, :n] + device_locs = self.req_to_device_buffer[req.req_pool_idx, :n] + + for layer_id in range(self.mem_pool_device.layer_num): + self.mem_pool_host.load_to_device_per_layer( + self.mem_pool_device, + host_indices, + device_locs, + layer_id, + io_backend="kernel", + ) + def alloc_device_buffer(self, req: Req) -> None: allocated_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : req.kv_allocated_len @@ -224,7 +271,7 @@ def collect_ready_reqs(self) -> List[Req]: _, _, req = self.ack_staging_queue.pop(0) # prepare device buffer and update req self.alloc_device_buffer(req) - req.staging = False + req.hisparse_staging = False self._skip_first_backup[req.req_pool_idx] = True finish_count -= 1 ready_reqs.append(req) @@ -494,7 +541,7 @@ def abort_staging_request(self, req: Req) -> None: """Remove a request from the staging queue and free its host resources. Must be called when aborting a request that has been admitted into staging - but has not yet completed (i.e. req.staging is True). + but has not yet completed (i.e. req.hisparse_staging is True). 
""" # Remove from staging queue self.ack_staging_queue = [ @@ -510,10 +557,10 @@ def abort_staging_request(self, req: Req) -> None: self.mem_pool_host.free(host_indices) self.req_to_host_pool[req.req_pool_idx, :] = -1 self._skip_first_backup[req.req_pool_idx] = False - req.staging = False + req.hisparse_staging = False def retract_req(self, req: Req) -> None: - if req.staging: + if req.hisparse_staging: self.abort_staging_request(req) else: self.request_finished(req) @@ -556,14 +603,14 @@ def swap_in_selected_pages( layer_id: int, ) -> torch.Tensor: """Swap selected top-k tokens into device memory and return their indices.""" - # The CUDA kernel expects req_pool_indices as int64 and seq_lens as int32. + # The CUDA kernel expects req_pool_indices as int64 and seq_lens as int32 or int64. if req_pool_indices.dtype != torch.int64: raise ValueError( f"req_pool_indices dtype {req_pool_indices.dtype} is not int64 as expected" ) - if seq_lens.dtype != torch.int32: + if seq_lens.dtype not in (torch.int32, torch.int64): raise ValueError( - f"seq_lens dtype {seq_lens.dtype} is not int32 as expected" + f"seq_lens dtype {seq_lens.dtype} is not int32 or int64 as expected" ) if top_k_result.dtype != torch.int32: raise ValueError( diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 139a588aeebf..471c0168562a 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -21,6 +21,7 @@ import copy import uuid from abc import ABC +from collections import Counter from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union @@ -58,6 +59,15 @@ def regenerate_rid(self): self.rid = uuid.uuid4().hex return self.rid + def _validate_rid_uniqueness(self): + """Validate that request IDs within a batch are unique.""" + if isinstance(self.rid, list) and len(set(self.rid)) != len(self.rid): + counts = Counter(self.rid) + duplicates 
= [rid for rid, count in counts.items() if count > 1] + raise ValueError( + f"Duplicate request IDs detected within the request: {duplicates}" + ) + @dataclass class BaseBatchReq(ABC): @@ -276,6 +286,8 @@ def normalize_batch_and_arguments(self): else: self._normalize_batch_inputs() + self._validate_rid_uniqueness() + def _validate_inputs(self): """Validate that the input configuration is valid.""" if ( @@ -731,6 +743,8 @@ class TokenizedGenerateReqInput(BaseReq): # Whether to return entropy return_entropy: bool = False + token_type_ids: Optional[List[int]] = None + need_wait_for_mm_inputs: bool = False num_items_assigned: Optional[Dict[Modality, List[int]]] = None @@ -853,6 +867,8 @@ def normalize_batch_and_arguments(self): self._normalize_lora_paths(self.batch_size) + self._validate_rid_uniqueness() + def _normalize_lora_paths(self, num): """Normalize LoRA paths for batch processing.""" if self.lora_path is not None: @@ -1008,38 +1024,6 @@ class BatchTokenIDOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): time_stats: Optional[List[SchedulerReqTimeStats]] = None -@dataclass -class BatchMultimodalDecodeReq(BaseBatchReq): - decoded_ids: List[int] - input_token_logprobs_val: List[float] - input_token_logprobs_idx: List[int] - output_token_logprobs_val: List[float] - output_token_logprobs_idx: List[int] - read_offsets: List[int] - skip_special_tokens: List[bool] - spaces_between_special_tokens: List[bool] - image_resolutions: List[List[int]] - resize_image_resolutions: List[List[int]] - - finished_reasons: List[BaseFinishReason] - - # Token counts - prompt_tokens: List[int] - completion_tokens: List[int] - cached_tokens: List[int] - - # The information of placeholder tokens (e.g., image token) - # idx is the index of the token in the prompt after expansion. - # val is the length of padded tokens after expansion. 
- placeholder_tokens_idx: List[Optional[List[int]]] - placeholder_tokens_val: List[Optional[List[int]]] - - return_bytes: List[bool] - - # The trainer step id. Used to know which step's weights are used for sampling. - token_steps: List[List[int]] = None - - @dataclass class BatchStrOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): # The finish reason @@ -1102,36 +1086,6 @@ class BatchStrOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): time_stats: Optional[List[SchedulerReqTimeStats]] = None -@dataclass -class BatchMultimodalOutput(BaseBatchReq): - # The finish reason - finished_reasons: List[dict] - decoded_ids: List[List[int]] - # The outputs - outputs: Union[List[str | bytes], List[List[Dict]]] - - # probability values for input tokens and output tokens - input_token_logprobs_val: List[List[float]] - input_token_logprobs_idx: List[List[int]] - output_token_logprobs_val: List[List[float]] - output_token_logprobs_idx: List[List[int]] - - # Token counts - prompt_tokens: List[int] - completion_tokens: List[int] - cached_tokens: List[int] - - placeholder_tokens_idx: List[Optional[List[int]]] - placeholder_tokens_val: List[Optional[List[int]]] - - return_bytes: List[bool] - # Detailed breakdown of cached tokens by source (device/host/storage) - cached_tokens_details: Optional[List[Optional[Dict[str, Any]]]] = None - - # For observability - time_stats: Optional[List[SchedulerReqTimeStats]] = None - - @dataclass class BatchEmbeddingOutput(BaseBatchReq): # The finish reason @@ -1229,21 +1183,6 @@ class DetachHiCacheStorageReqOutput(BaseReq): message: str = "" -@dataclass -class PinPrefixReqInput(BaseReq): - """Pin a prefix by token_ids to resist eviction.""" - - token_ids: List[int] = field(default_factory=list) - ttl_seconds: int = 300 # TTL in seconds, default 5 minutes - - -@dataclass -class PinPrefixReqOutput(BaseReq): - success: bool - nodes_pinned: int = 0 - message: str = "" - - @dataclass class PauseGenerationReqInput(BaseReq): """ diff --git 
a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 7cc37176a40c..f80d4a9f00b9 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -327,46 +327,26 @@ def pad_input_tokens( input_ids_tensor = torch.as_tensor(input_ids) - # Check if MM splitting is enabled - if envs.SGLANG_ENABLE_MM_SPLITTING.get(): - items_by_modality = defaultdict(list) - for item in mm_inputs.mm_items: - items_by_modality[item.modality].append(item) - - token_id_map = { - Modality.IMAGE: mm_inputs.im_token_id, - Modality.MULTI_IMAGES: mm_inputs.im_token_id, - Modality.AUDIO: mm_inputs.audio_token_id, - Modality.VIDEO: mm_inputs.video_token_id, - } - - for modality, items in items_by_modality.items(): - token_id = token_id_map.get(modality) - - if not items or token_id is None: - continue + # Replace multimodal tokens using per-item offsets + items_by_modality = defaultdict(list) + for item in mm_inputs.mm_items: + items_by_modality[item.modality].append(item) + + token_id_map = { + Modality.IMAGE: mm_inputs.im_token_id, + Modality.AUDIO: mm_inputs.audio_token_id, + Modality.VIDEO: mm_inputs.video_token_id, + } - for i, item in enumerate(items): - for offset in items[i].offsets: - input_ids_tensor[offset[0] : offset[1] + 1] = item.pad_value - else: - # Create mapping of token_ids to pad_values for each modality - token_to_pad_mapping = {} - for item in mm_inputs.mm_items: - if item.is_image() and mm_inputs.im_token_id is not None: - token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value - elif item.is_audio() and mm_inputs.audio_token_id is not None: - token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value - elif item.is_video() and mm_inputs.video_token_id is not None: - token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value - else: - raise ValueError( - f"No multimodal token id provided for {item.modality}" - ) + for modality, items in items_by_modality.items(): + token_id = 
token_id_map.get(modality) + + if not items or token_id is None: + continue - # Apply replacements for all tokens at once - for token_id, pad_value in token_to_pad_mapping.items(): - input_ids_tensor[input_ids_tensor == token_id] = pad_value + for i, item in enumerate(items): + for offset in items[i].offsets: + input_ids_tensor[offset[0] : offset[1] + 1] = item.pad_value ret_input_ids = input_ids_tensor.tolist() return ret_input_ids @@ -994,6 +974,8 @@ def embed_mm_inputs( multimodal_model.separate_deepstack_embeds(embedding) ) deepstack_embeddings += [deepstack_embedding] + else: + deepstack_embeddings += [None] modalities += [modality] embeddings += [embedding] masks += [mask] @@ -1341,23 +1323,29 @@ def tensor_hash(tensor_list) -> int: tensor = tensor_list if isinstance(tensor_list, list): tensor_list = flatten_nested_list(tensor_list) - tensor_list = [ + tensors = [ x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list ] - tensor = torch.concat(tensor_list) + # GPU path: concat + triton hash (unchanged) + if any(isinstance(t, torch.Tensor) and t.is_cuda for t in tensors): + tensor = torch.concat(tensors) + return gpu_tensor_hash(tensor.cuda()) + # CPU path: hash each tensor incrementally without concat + hasher = hashlib.sha256() + for t in tensors: + t = t.detach().contiguous() + hasher.update(memoryview(t.view(torch.uint8).numpy())) + hash_bytes = hasher.digest()[:8] + return int.from_bytes(hash_bytes, byteorder="big", signed=False) + + # Single tensor if tensor.is_cuda: return gpu_tensor_hash(tensor.cuda()) tensor = tensor.detach().contiguous() - - if tensor.dtype == torch.bfloat16: - # memoryview() doesn't support PyTorch's BFloat16 dtype - tensor = tensor.float() - - assert isinstance(tensor, torch.Tensor) - tensor_cpu = tensor.cpu() - - mv = memoryview(tensor_cpu.numpy()) - return data_hash(mv.tobytes()) + hasher = hashlib.sha256() + hasher.update(memoryview(tensor.view(torch.uint8).numpy())) + hash_bytes = hasher.digest()[:8] + return 
int.from_bytes(hash_bytes, byteorder="big", signed=False) def hash_feature(f): @@ -1367,8 +1355,10 @@ def hash_feature(f): return data_hash(tuple(flatten_nested_list(f))) elif isinstance(f, np.ndarray): arr = np.ascontiguousarray(f) - arr_bytes = arr.tobytes() - return data_hash(arr_bytes) + hasher = hashlib.sha256() + hasher.update(memoryview(arr)) + hash_bytes = hasher.digest()[:8] + return int.from_bytes(hash_bytes, byteorder="big", signed=False) elif isinstance(f, torch.Tensor): return tensor_hash([f]) elif isinstance(f, CudaIpcTensorTransportProxy): @@ -1466,6 +1456,54 @@ def _slice_model_data( return sliced +def _try_simple_split(item, num_items, expanded_mm_items): + """Try to split a bundled item by matching feature dim-0 to offset count. + Returns True if split succeeded, False otherwise.""" + feature = item.feature if item.feature is not None else item.precomputed_embeddings + if feature is None: + return False + + if isinstance(feature, (torch.Tensor, np.ndarray)): + feature_count = feature.shape[0] + elif isinstance(feature, (list, tuple)): + feature_count = len(feature) + else: + return False + + if feature_count != num_items: + return False + + for i in range(num_items): + new_item = copy.copy(item) + if item.feature is not None: + if isinstance(item.feature, (list, tuple)): + new_item.feature = [item.feature[i]] + else: + new_item.feature = item.feature[i : i + 1] + if item.precomputed_embeddings is not None: + if isinstance(item.precomputed_embeddings, (list, tuple)): + new_item.precomputed_embeddings = [item.precomputed_embeddings[i]] + else: + new_item.precomputed_embeddings = item.precomputed_embeddings[i : i + 1] + new_item.offsets = [item.offsets[i]] + new_data = {} + for k, v in item.model_specific_data.items(): + if isinstance(v, (list, tuple)) and len(v) == num_items: + new_data[k] = [v[i]] + elif ( + isinstance(v, (torch.Tensor, np.ndarray)) + and len(v.shape) > 0 + and v.shape[0] == num_items + ): + new_data[k] = v[i : i + 1] + else: + 
new_data[k] = v + new_item.model_specific_data = new_data + new_item.hash = None + expanded_mm_items.append(new_item) + return True + + def get_new_expanded_mm_items(original_mm_items): expanded_mm_items = [] for item in original_mm_items: @@ -1478,7 +1516,9 @@ def get_new_expanded_mm_items(original_mm_items): image_grid_thw = item.model_specific_data.get("image_grid_thw") grid_len = _get_length(image_grid_thw) if image_grid_thw is None or grid_len != num_items: - expanded_mm_items.append(item) + # No grid info — fall back to simple split by feature dim-0 + if not _try_simple_split(item, num_items, expanded_mm_items): + expanded_mm_items.append(item) continue patches_per_item = [] @@ -1501,7 +1541,7 @@ def get_new_expanded_mm_items(original_mm_items): total_feature_len = feature_len for i in range(num_items): start, end = slice_indices[i], slice_indices[i + 1] - new_item = copy.deepcopy(item) + new_item = copy.copy(item) if item.feature is not None: new_item.feature = _slice_value(item.feature, start, end) if item.precomputed_embeddings is not None: @@ -1523,7 +1563,8 @@ def get_new_expanded_mm_items(original_mm_items): elif item.is_video(): video_grid_thw = item.model_specific_data.get("video_grid_thw") if video_grid_thw is None: - expanded_mm_items.append(item) + if not _try_simple_split(item, num_items, expanded_mm_items): + expanded_mm_items.append(item) continue # video_grid_thw shape: [num_videos, 3] where each row is [T, H, W] @@ -1592,7 +1633,7 @@ def get_new_expanded_mm_items(original_mm_items): frame_start_indices[video_idx + 1], ) - new_item = copy.deepcopy(item) + new_item = copy.copy(item) if item.feature is not None: new_item.feature = _slice_value(item.feature, start, end) if item.precomputed_embeddings is not None: @@ -1613,7 +1654,8 @@ def get_new_expanded_mm_items(original_mm_items): new_item.hash = None expanded_mm_items.append(new_item) else: - expanded_mm_items.append(item) + if not _try_simple_split(item, num_items, expanded_mm_items): + 
expanded_mm_items.append(item) else: expanded_mm_items.append(item) @@ -1627,45 +1669,28 @@ class ShmPointerMMData: """ def __init__(self, tensor: torch.Tensor): - self.cpu_tensor = tensor.cpu().contiguous() - self.shape = self.cpu_tensor.shape - self.dtype = self.cpu_tensor.dtype - - nbytes = self.cpu_tensor.numel() * self.cpu_tensor.element_size() - - self.shm = shared_memory.SharedMemory(create=True, size=nbytes) - + if not tensor.is_cpu: + tensor = tensor.cpu() + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + self.shape = tensor.shape + self.dtype = tensor.dtype + nbytes = tensor.numel() * tensor.element_size() + shm = shared_memory.SharedMemory(create=True, size=nbytes) try: - shm_view = np.ndarray((nbytes,), dtype=np.uint8, buffer=self.shm.buf) - - shm_view[:] = self.cpu_tensor.view(torch.uint8).numpy().flatten() - finally: - self.shm.close() + dst = torch.frombuffer(shm.buf, dtype=torch.uint8) + dst.copy_(tensor.view(torch.uint8).reshape(-1)) + except BaseException: + shm.close() + shm.unlink() + raise + self.shm_name = shm.name + shm.close() + self._shm_handle = None def __getstate__(self): - if not hasattr(self, "shm") or self.shm is None: - tensor = getattr(self, "cpu_tensor", None) - if tensor is None: - tensor = getattr(self, "tensor", None) - if tensor is None: - raise RuntimeError( - "ShmPointerMMData cannot recreate shared memory without tensor" - ) - - cpu_tensor = tensor.cpu().contiguous() - self.shape = cpu_tensor.shape - self.dtype = cpu_tensor.dtype - - nbytes = cpu_tensor.numel() * cpu_tensor.element_size() - self.shm = shared_memory.SharedMemory(create=True, size=nbytes) - try: - shm_view = np.ndarray((nbytes,), dtype=np.uint8, buffer=self.shm.buf) - shm_view[:] = cpu_tensor.view(torch.uint8).numpy().flatten() - finally: - self.shm.close() - return { - "shm_name": self.shm.name, + "shm_name": self.shm_name, "shape": self.shape, "dtype": self.dtype, } @@ -1675,17 +1700,29 @@ def __setstate__(self, state): self.shape = 
state["shape"] self.dtype = state["dtype"] self.shm = None + self._shm_handle = shared_memory.SharedMemory(name=self.shm_name) + # Zero-copy view into shared memory (no clone, no unlink) + self.tensor = torch.frombuffer(self._shm_handle.buf, dtype=self.dtype).reshape( + self.shape + ) - shm_handle = shared_memory.SharedMemory(name=self.shm_name) - try: - self.tensor = ( - torch.frombuffer(shm_handle.buf, dtype=self.dtype) - .reshape(self.shape) - .clone() - ) - finally: - shm_handle.close() - shm_handle.unlink() + def materialize(self) -> torch.Tensor: + """Clone tensor from shm to owned memory, then release shm handle.""" + tensor = self.tensor.clone() + if self._shm_handle is not None: + self._shm_handle.close() + try: + self._shm_handle.unlink() + except FileNotFoundError: + pass # Another rank already unlinked + self._shm_handle = None + return tensor + + def __del__(self): + # Only close; never unlink. Unlinking is materialize()'s job. + if getattr(self, "_shm_handle", None) is not None: + self._shm_handle.close() + self._shm_handle = None def _get_is_default_transport(): @@ -1720,15 +1757,35 @@ def wrap_shm_features(obj): return obj +def has_shm_features(recv_reqs): + """Return True if any request in the list contains ShmPointerMMData.""" + for req in recv_reqs: + if hasattr(req, "batch"): + if has_shm_features(req.batch): + return True + elif hasattr(req, "mm_inputs") and req.mm_inputs: + for item in req.mm_inputs.get("mm_items", []): + if isinstance(item.feature, ShmPointerMMData): + return True + return False + + def unwrap_shm_features(obj): """ Restore ShmPointerMMData wrappers back into standard torch.Tensors. + Handles both single requests and batch requests. 
""" if _get_is_default_transport() or get_global_server_args().skip_tokenizer_init: return obj + # Handle batch requests + if hasattr(obj, "batch"): + for sub_obj in obj.batch: + unwrap_shm_features(sub_obj) + return obj + # Handle single requests if hasattr(obj, "mm_inputs") and obj.mm_inputs: mm_items = obj.mm_inputs.get("mm_items", []) for item in mm_items: if isinstance(item.feature, ShmPointerMMData): - item.feature = item.feature.tensor + item.feature = item.feature.materialize() return obj diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 61f22efd80b9..4da03863068b 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -40,7 +40,6 @@ BaseBatchReq, BaseReq, BatchEmbeddingOutput, - BatchMultimodalOutput, BatchStrOutput, BatchTokenIDOutput, ) @@ -282,17 +281,6 @@ def _handle_output_by_index(output, i): output, "token_steps", i, check_length=False ), ) - elif isinstance(output, BatchMultimodalOutput): - new_output = BatchMultimodalOutput( - rids=[output.rids[i]], - finished_reasons=_extract_field_by_index(output, "finished_reasons", i), - outputs=_extract_field_by_index(output, "outputs", i), - prompt_tokens=_extract_field_by_index(output, "prompt_tokens", i), - completion_tokens=_extract_field_by_index(output, "completion_tokens", i), - cached_tokens=_extract_field_by_index(output, "cached_tokens", i), - placeholder_tokens_idx=None, - placeholder_tokens_val=None, - ) else: new_output = output return new_output diff --git a/python/sglang/srt/managers/multimodal_processor.py b/python/sglang/srt/managers/multimodal_processor.py index b2c9e68cb9f9..0554ed34e397 100644 --- a/python/sglang/srt/managers/multimodal_processor.py +++ b/python/sglang/srt/managers/multimodal_processor.py @@ -4,6 +4,7 @@ import logging import pkgutil +from sglang.srt.configs.model_config import ModelImpl from 
sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor from sglang.srt.server_args import ServerArgs @@ -41,14 +42,41 @@ def import_processors(package_name: str, overwrite: bool = False): def get_mm_processor( - hf_config, server_args: ServerArgs, processor, transport_mode, **kwargs + hf_config, + server_args: ServerArgs, + processor, + transport_mode, + model_config=None, + **kwargs, ) -> BaseMultimodalProcessor: + model_impl = str(getattr(server_args, "model_impl", "auto")).lower() + uses_transformers_backend = model_impl == "transformers" + if model_impl == "auto" and model_config is not None: + from sglang.srt.model_loader.utils import get_resolved_model_impl + + uses_transformers_backend = ( + get_resolved_model_impl(model_config) == ModelImpl.TRANSFORMERS + ) + for model_cls, processor_cls in PROCESSOR_MAPPING.items(): - if model_cls.__name__ in hf_config.architectures: + if model_cls.__name__ not in hf_config.architectures: + continue + if not uses_transformers_backend or getattr( + processor_cls, "supports_transformers_backend", False + ): return processor_cls( hf_config, server_args, processor, transport_mode, **kwargs ) + if uses_transformers_backend: + from sglang.srt.multimodal.processors.transformers_auto import ( + TransformersAutoMultimodalProcessor, + ) + + return TransformersAutoMultimodalProcessor( + hf_config, server_args, processor, transport_mode, **kwargs + ) + raise ValueError( f"No processor registered for architecture: {hf_config.architectures}.\n" f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}" diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6a019496aeb5..7cfa7d50aa00 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -198,7 +198,6 @@ def to_json(self): class Modality(Enum): IMAGE = auto() - MULTI_IMAGES = auto() VIDEO = auto() AUDIO = auto() @@ 
-225,9 +224,10 @@ class MultimodalInputFormat(Enum): @dataclasses.dataclass class MultimodalDataItem: """ - One MultimodalDataItem contains all inputs for one modality. - For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem. - One for images and one for audio. + One MultimodalDataItem represents a single multimodal input (one image, one video, or one audio). + For example, if there are 3 images and 1 audio, there will be 4 MultimodalDataItems. + + Each item has its own hash and pad_value, enabling per-image RadixAttention caching. We put the common fields first and the model-specific fields in model_specific_data. """ @@ -305,7 +305,7 @@ def is_audio(self): return self.modality == Modality.AUDIO def is_image(self): - return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES] + return self.modality == Modality.IMAGE def is_video(self): return self.modality == Modality.VIDEO @@ -330,11 +330,27 @@ def from_dict(obj: dict): ret.validate() return ret - def merge(self, other): - self.feature += other.feature - self.offsets += other.offsets - self.hash = hash((self.hash, other.hash)) - self.set_pad_value() + def reconstruct(self): + if not isinstance(self.feature, CudaIpcTensorTransportProxy): + return + + reconstruct_device = torch.cuda.current_device() + if isinstance(self.feature, CudaIpcTensorTransportProxy): + self.feature = self.feature.reconstruct_on_target_device(reconstruct_device) + if isinstance(self.precomputed_embeddings, CudaIpcTensorTransportProxy): + self.precomputed_embeddings = ( + self.precomputed_embeddings.reconstruct_on_target_device( + reconstruct_device + ) + ) + for extra_key in self.model_specific_data: + if isinstance( + self.model_specific_data[extra_key], CudaIpcTensorTransportProxy + ): + extra_data = self.model_specific_data[ + extra_key + ].reconstruct_on_target_device(reconstruct_device) + self.model_specific_data[extra_key] = extra_data @dataclasses.dataclass @@ -373,15 +389,9 @@ def 
release_features(self): @staticmethod def from_dict(obj: dict): - # Check if MM splitting is enabled - if not envs.SGLANG_ENABLE_MM_SPLITTING.get(): - mm_items = obj["mm_items"] - else: - from sglang.srt.managers.mm_utils import get_new_expanded_mm_items - - original_mm_items = obj["mm_items"] - # Now, `mm_items` contains one item per image. - mm_items = get_new_expanded_mm_items(original_mm_items) + mm_items = obj["mm_items"] + for mm_item in mm_items: + mm_item.reconstruct() ret = MultimodalInputs( mm_items=mm_items, @@ -805,7 +815,7 @@ def __init__( self.init_diffusion_llm(dllm_config) # For hisparse - self.staging = False + self.hisparse_staging = False @property def seqlen(self) -> int: @@ -1673,13 +1683,6 @@ def prepare_for_extend(self): pixel_values = getattr(mm_item, "feature", None) if isinstance(pixel_values, torch.Tensor): mm_item.feature = pixel_values.to(self.device, non_blocking=True) - elif isinstance(pixel_values, CudaIpcTensorTransportProxy): - mm_item.feature = pixel_values.reconstruct_on_target_device( - torch.cuda.current_device() - ) - # The reference by CudaIpcTensorTransportProxy was cut off, - # proactively delete to avoid slow gc. 
- del pixel_values if get_global_server_args().language_only: precomputed_embeddings = getattr( mm_item, "precomputed_embeddings", None @@ -2248,6 +2251,7 @@ def merge_batch(self, other: "ScheduleBatch"): self.has_stream |= other.has_stream self.has_grammar |= other.has_grammar self.return_hidden_states |= other.return_hidden_states + self.is_prefill_only = self.is_prefill_only and other.is_prefill_only if self.spec_info: self.spec_info.merge_batch(other.spec_info) @@ -2365,13 +2369,6 @@ def maybe_evict_swa(self): sliding_window_size = self.tree_cache.sliding_window_size server_args = get_global_server_args() - if ( - self.forward_mode.is_decode() - and not server_args.disable_piecewise_cuda_graph - and not self.tree_cache.is_chunk_cache() - ): - return - for idx, req in enumerate(self.reqs): if self.forward_mode.is_decode(): # We set evict_swa condition here with two reasons: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3e6924807ce1..58c56a516ecf 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -38,7 +38,8 @@ from torch.distributed import barrier from sglang.jit_kernel.ngram_embedding import update_token_table -from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.configs.model_config import ModelConfig, ModelImpl +from sglang.srt.constants import HEALTH_CHECK_RID_PREFIX from sglang.srt.constrained.grammar_manager import GrammarManager from sglang.srt.disaggregation.decode import ( DecodePreallocQueue, @@ -119,8 +120,6 @@ LoadLoRAAdapterReqOutput, OpenSessionReqInput, PauseGenerationReqInput, - PinPrefixReqInput, - PinPrefixReqOutput, ProfileReq, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, @@ -141,7 +140,11 @@ UpdateWeightsFromIPCReqInput, UpdateWeightsFromTensorReqInput, ) -from sglang.srt.managers.mm_utils import init_mm_embedding_cache, unwrap_shm_features +from sglang.srt.managers.mm_utils import ( + 
has_shm_features, + init_mm_embedding_cache, + unwrap_shm_features, +) from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.managers.overlap_utils import FutureMap from sglang.srt.managers.prefill_delayer import ( @@ -182,6 +185,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.session_aware_cache import SessionAwareCache from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors +from sglang.srt.model_loader.utils import get_resolved_model_impl from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin from sglang.srt.observability.req_time_stats import ( real_time, @@ -207,10 +211,8 @@ get_available_gpu_memory, get_bool_env_var, get_int_env_var, - get_numa_node, is_mps, kill_itself_when_parent_died, - numa_bind_to_node, point_to_point_pyobj, require_mlp_sync, set_gpu_proc_affinity, @@ -224,6 +226,7 @@ get_tokenizer_from_processor, ) from sglang.srt.utils.network import get_zmq_socket +from sglang.srt.utils.numa_utils import get_numa_node_if_available, numa_bind_to_node from sglang.srt.utils.tensor_bridge import use_mlx from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import TypeBasedDispatcher, get_exception_traceback @@ -697,6 +700,9 @@ def init_model_worker(self): def init_cache_with_memory_pool(self): server_args = self.server_args + uses_transformers_backend = ( + get_resolved_model_impl(self.model_config) == ModelImpl.TRANSFORMERS + ) # Hybrid memory pool self.is_hybrid_swa = self.tp_worker.is_hybrid_swa @@ -716,9 +722,21 @@ def init_cache_with_memory_pool(self): self.tp_worker.get_memory_pool() ) - # Create cache + self.disable_radix_cache = server_args.disable_radix_cache or ( + self.model_config.is_multimodal and uses_transformers_backend + ) + if self.disable_radix_cache and not server_args.disable_radix_cache: + logger.warning( + "Radix cache is disabled for multimodal models 
with the " + "Transformers backend to avoid multimodal prefix-cache mismatches." + ) + + effective_chunked_prefill_size = server_args.chunked_prefill_size + if self.model_config.is_multimodal and uses_transformers_backend: + effective_chunked_prefill_size = None + params = CacheInitParams( - disable=server_args.disable_radix_cache, + disable=self.disable_radix_cache, req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, page_size=self.page_size, @@ -734,14 +752,11 @@ def init_cache_with_memory_pool(self): enable_mamba_extra_buffer=server_args.enable_mamba_extra_buffer(), pp_rank=self.pp_rank, pp_size=self.pp_size, - chunked_prefill_size=server_args.chunked_prefill_size, + chunked_prefill_size=effective_chunked_prefill_size, sliding_window_size=self.sliding_window_size, ) - if ( - server_args.chunked_prefill_size is not None - and server_args.disable_radix_cache - ): + if effective_chunked_prefill_size is not None and self.disable_radix_cache: if not self.is_hybrid_swa: from sglang.srt.mem_cache.chunk_cache import ChunkCache @@ -842,9 +857,22 @@ def init_running_status(self): self._engine_paused = False def init_chunked_prefill(self): - # Init chunked prefill self.chunked_prefill_size = self.server_args.chunked_prefill_size - if self.chunked_prefill_size <= 0: # -1 means disable + uses_transformers_backend = ( + get_resolved_model_impl(self.model_config) == ModelImpl.TRANSFORMERS + ) + if ( + self.chunked_prefill_size is not None + and self.chunked_prefill_size > 0 + and self.model_config.is_multimodal + and uses_transformers_backend + ): + logger.warning( + "Chunked prefill is disabled for multimodal models with the " + "Transformers backend to avoid partial multimodal chunk mismatches." 
+ ) + self.chunked_prefill_size = None + elif self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0: self.chunked_prefill_size = None self.chunked_req = None self.is_mixed_chunk = ( @@ -1181,7 +1209,6 @@ def init_request_dispatcher(self): (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped), (AttachHiCacheStorageReqInput, self.attach_hicache_storage_wrapped), (DetachHiCacheStorageReqInput, self.detach_hicache_storage_wrapped), - (PinPrefixReqInput, self.pin_prefix_wrapped), (AbortReq, self.abort_request), (OpenSessionReqInput, self.open_session), (CloseSessionReqInput, self.close_session), @@ -1255,19 +1282,6 @@ def get_init_info(self) -> Dict[str, Any]: "max_req_input_len": self.max_req_input_len, } - if self.server_args.remote_instance_weight_loader_use_transfer_engine(): - ( - remote_instance_transfer_engine_session_id, - remote_instance_transfer_engine_weights_info_dict, - ) = self.get_remote_instance_transfer_engine_info() - result_dict.update( - { - "tp_rank": self.tp_rank, - "remote_instance_transfer_engine_session_id": remote_instance_transfer_engine_session_id, - "remote_instance_transfer_engine_weights_info_dict": remote_instance_transfer_engine_weights_info_dict, - } - ) - return result_dict def run_event_loop(self) -> None: @@ -1424,7 +1438,6 @@ def recv_requests( if self.recv_limit_reached(len(recv_reqs)): break recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) - recv_req = unwrap_shm_features(recv_req) except zmq.ZMQError: break recv_reqs.append(recv_req) @@ -1511,6 +1524,34 @@ def recv_requests( prepare_abort(req, error_msg, status_code=status_code) self.stream_output([req], req.return_logprob) + # Unwrap shared memory features AFTER all broadcasts complete, + # so that ShmPointerMMData metadata (not full tensor data) is what + # gets serialized during broadcast_pyobj. 
+ if recv_reqs: + # Barrier for the non-DP-attention path only: there is a single + # broadcast_pyobj on tp_cpu_group where the source rank returns + # the original objects immediately while other ranks are still in + # pickle.loads (-> __setstate__ -> shm_open). Without a barrier + # the source can call materialize() / shm_unlink before others + # open the segment. recv_reqs is consistent across all ranks + # here (same broadcast), so the guard is deadlock-free. + # + # Under DP-attention no barrier is needed: the control_reqs + # broadcast on tp_cpu_group (step 3) is a collective that forces + # every rank to complete the earlier attn_tp / attn_cp work_reqs + # deserializations (steps 1-2, which call shm_open) before any + # rank returns from step 3. POSIX guarantees shm_unlink only + # removes the name; already-open handles stay valid. + if ( + not self.server_args.enable_dp_attention + and self.tp_size > 1 + and self.model_config.is_multimodal + and has_shm_features(recv_reqs) + ): + barrier(group=self.tp_cpu_group) + for req in recv_reqs: + unwrap_shm_features(req) + return recv_reqs def _split_work_and_control_reqs(self, recv_reqs: List): @@ -1709,6 +1750,7 @@ def handle_generate_request( stream=recv_req.stream, lora_id=recv_req.lora_id, input_embeds=recv_req.input_embeds, + token_type_ids=recv_req.token_type_ids, custom_logit_processor=recv_req.custom_logit_processor, require_reasoning=recv_req.require_reasoning, return_hidden_states=recv_req.return_hidden_states, @@ -1791,10 +1833,12 @@ def handle_generate_request( SessionController.adjust_mm_offsets(recv_req, req, image_inputs) # The following steps are already fast, execute locally on each rank. - # Expand a single image token into multiple dummy tokens for receiving image embeddings - req.origin_input_ids = self.pad_input_ids_func( - req.origin_input_ids, image_inputs - ) + # Expand a single image token into multiple dummy tokens for receiving image embeddings. 
+ # The pad function is model-specific and can be None for some backends. + if self.pad_input_ids_func: + req.origin_input_ids = self.pad_input_ids_func( + req.origin_input_ids, image_inputs + ) req.extend_image_inputs(image_inputs) self._maybe_compute_mrope_positions(req) @@ -2147,6 +2191,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: chunked_req_to_exclude.add(self.chunked_req) self.stash_chunked_request(self.chunked_req) + # HiSparse has its own prefill-to-decode transition; skip last_batch merge. if self.enable_hisparse: ready_reqs = self.hisparse_coordinator.collect_ready_reqs() if len(ready_reqs) > 0: @@ -2156,31 +2201,35 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: else: self.running_batch.merge_batch(new_batch) self.running_batch.hisparse_coordinator = self.hisparse_coordinator - else: - if self.last_batch and self.last_batch.forward_mode.is_extend(): - if self.last_batch.chunked_req is not None: - # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. - # We need to discard it. - chunked_req_to_exclude.add(self.last_batch.chunked_req) - - if self.dllm_config is not None and self.last_batch.reqs: - chunked_req_to_exclude.update(self.last_batch.reqs) - - # Filter batch - last_bs = self.last_batch.batch_size() - self.last_batch.filter_batch( - chunked_req_to_exclude=list(chunked_req_to_exclude) - ) - if self.last_batch.batch_size() < last_bs: - self.running_batch.batch_is_full = False - # Merge the new batch into the running batch. 
- if not self.last_batch.is_empty(): - if self.running_batch.is_empty(): - self.running_batch = self.last_batch - else: - # Merge running_batch with prefill batch - self.running_batch.merge_batch(self.last_batch) + if ( + not self.enable_hisparse + and self.last_batch + and self.last_batch.forward_mode.is_extend() + ): + if self.last_batch.chunked_req is not None: + # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. + # We need to discard it. + chunked_req_to_exclude.add(self.last_batch.chunked_req) + + if self.dllm_config is not None and self.last_batch.reqs: + chunked_req_to_exclude.update(self.last_batch.reqs) + + # Filter batch + last_bs = self.last_batch.batch_size() + self.last_batch.filter_batch( + chunked_req_to_exclude=list(chunked_req_to_exclude) + ) + if self.last_batch.batch_size() < last_bs: + self.running_batch.batch_is_full = False + + # Merge the new batch into the running batch. + if not self.last_batch.is_empty(): + if self.running_batch.is_empty(): + self.running_batch = self.last_batch + else: + # Merge running_batch with prefill batch + self.running_batch.merge_batch(self.last_batch) # For prefill-only batch, filter out finished requests since they # won't go through the decode step. 
This keeps running_batch accurate @@ -2410,6 +2459,13 @@ def _get_new_batch_prefill_raw( ) > 0 or (not self.running_batch.is_empty()) else: self.running_batch.batch_is_full = True + # revert matched mamba idx to avoid memory leak, if req is not added + added = len(adder.can_run_list) > 0 and req is adder.can_run_list[-1] + if not added and req.mamba_pool_idx is not None: + self.tree_cache.req_to_token_pool.mamba_pool.free( + req.mamba_pool_idx.unsqueeze(-1) + ) + req.mamba_pool_idx = None break # Update waiting queue @@ -2979,37 +3035,6 @@ def detach_hicache_storage_wrapped( return DetachHiCacheStorageReqOutput(success=False, message=msg) - def pin_prefix_wrapped(self, recv_req: PinPrefixReqInput): - if not hasattr(self.tree_cache, "pin_prefix"): - return PinPrefixReqOutput( - success=False, - nodes_pinned=0, - message="PIN requires --enable-hierarchical-cache", - ) - if getattr(self.tree_cache, "_max_pinned_tokens", 0) <= 0: - return PinPrefixReqOutput( - success=False, - nodes_pinned=0, - message="Pinning is disabled (SGLANG_HICACHE_MAX_PINNED_RATIO is 0)", - ) - nodes_pinned, reject_reason = self.tree_cache.pin_prefix( - recv_req.token_ids, recv_req.ttl_seconds - ) - if nodes_pinned == 0: - return PinPrefixReqOutput( - success=False, - nodes_pinned=0, - message=reject_reason or "No matching prefix found in cache to pin", - ) - msg = f"Pinned {nodes_pinned} nodes (ttl={recv_req.ttl_seconds}s)" - if reject_reason: - msg += f"; {reject_reason}" - return PinPrefixReqOutput( - success=True, - nodes_pinned=nodes_pinned, - message=msg, - ) - def flush_cache(self): """Flush the memory pool and cache.""" if self.is_fully_idle(): @@ -3233,6 +3258,16 @@ def _pause_engine(self) -> Tuple[List[Req], int]: def pause_generation(self, recv_req: PauseGenerationReqInput): self._engine_paused = True + if recv_req.mode == "in_place": + # In-place pause: just set the flag and return immediately. 
+ # All scheduler state (running_batch, last_batch, chunked_req, + # result_queue) is left untouched. On resume, the normal event + # loop (get_next_batch_to_run) handles last_batch merge, + # chunked_req cleanup, and overlap result processing through + # the standard code paths. This avoids duplicating batch + # manipulation logic and the accounting bugs that come with it. + return + if self.enable_overlap and self.last_batch: # Process the results of the last batch tmp_batch, tmp_result = self.result_queue.popleft() @@ -3240,13 +3275,14 @@ def pause_generation(self, recv_req: PauseGenerationReqInput): if self.last_batch and self.last_batch.forward_mode.is_extend(): chunked_req_to_exclude = set() - if recv_req.mode == "in_place": - if self.chunked_req is not None: - chunked_req_to_exclude.add(self.chunked_req) self.last_batch.filter_batch( chunked_req_to_exclude=list(chunked_req_to_exclude) ) - self.running_batch.merge_batch(self.last_batch) + if not self.last_batch.is_empty(): + if self.running_batch.is_empty(): + self.running_batch = self.last_batch + else: + self.running_batch.merge_batch(self.last_batch) self.last_batch = None self.cur_batch = None @@ -3367,9 +3403,6 @@ def update_cache_from_scheduler( ): pass - def get_remote_instance_transfer_engine_info(self): - return self.tp_worker.get_remote_instance_transfer_engine_info() - class IdleSleeper: """ @@ -3403,7 +3436,7 @@ def maybe_sleep(self): def is_health_check_generate_req(recv_req): rid = getattr(recv_req, "rid", None) - return rid is not None and rid.startswith("HEALTH_CHECK") + return rid is not None and rid.startswith(HEALTH_CHECK_RID_PREFIX) def is_work_request(recv_req): @@ -3538,12 +3571,7 @@ def run_scheduler_process( set_gpu_proc_affinity( server_args.pp_size, server_args.tp_size, server_args.nnodes, gpu_id ) - numa_node = None - if (numa_nodes := server_args.numa_node) is not None: - numa_node = numa_nodes[gpu_id] - elif envs.SGLANG_AUTO_NUMA_BIND.get(): - numa_node = get_numa_node(gpu_id) - 
logger.info(f"auto get NUMA node {numa_node} for GPU {gpu_id}") + numa_node = get_numa_node_if_available(server_args, gpu_id) if numa_node is not None and not envs.SGLANG_NUMA_BIND_V2.get(): numa_bind_to_node(numa_node) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 864acfcc9b1b..e3b8dc87e26d 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -55,15 +55,10 @@ def _get_cached_tokens_details(self, req: Req) -> Optional[dict]: """Get detailed cache breakdown for a request, if available. Returns: - - None if HiCache is not enabled - - {"device": X, "host": Y} if HiCache enabled but L3 storage is not - - {"device": X, "host": Y, "storage": Z, "storage_backend": "..."} if L3 enabled + - None if no cached tokens at all + - {"device": X, "host": Y} without storage breakdown + - {"device": X, "host": Y, "storage": Z} with storage breakdown """ - # Only show details if HiCache is enabled - if not getattr(self, "enable_hierarchical_cache", False): - return None - - # Only show if there are any cached tokens if ( req.cached_tokens_device > 0 or req.cached_tokens_host > 0 @@ -78,6 +73,13 @@ def _get_cached_tokens_details(self, req: Req) -> Optional[dict]: details["storage"] = req.cached_tokens_storage details["storage_backend"] = self._get_storage_backend_type() return details + + if req.cached_tokens > 0: + return { + "device": req.cached_tokens, + "host": 0, + } + return None def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch): @@ -152,6 +154,18 @@ def process_batch_result_prefill( logits_output.input_token_logprobs = tuple( logits_output.input_token_logprobs.tolist() ) + if logits_output.next_token_top_logprobs_val: + logits_output.next_token_top_logprobs_val = [ + v.tolist() for v in logits_output.next_token_top_logprobs_val + ] + 
logits_output.next_token_top_logprobs_idx = [ + x.tolist() for x in logits_output.next_token_top_logprobs_idx + ] + if logits_output.next_token_token_ids_logprobs_val: + logits_output.next_token_token_ids_logprobs_val = [ + v.tolist() + for v in logits_output.next_token_token_ids_logprobs_val + ] hidden_state_offset = 0 @@ -377,7 +391,7 @@ def process_batch_result_decode( if batch.return_logprob: next_token_logprobs = logits_output.next_token_logprobs.tolist() - if batch.is_spec_v2 and logits_output.next_token_top_logprobs_val: + if logits_output.next_token_top_logprobs_val: logits_output.next_token_top_logprobs_val = [ v.tolist() for v in logits_output.next_token_top_logprobs_val ] @@ -385,7 +399,7 @@ def process_batch_result_decode( x.tolist() for x in logits_output.next_token_top_logprobs_idx ] - if batch.is_spec_v2 and logits_output.next_token_token_ids_logprobs_val: + if logits_output.next_token_token_ids_logprobs_val: logits_output.next_token_token_ids_logprobs_val = [ v.tolist() for v in logits_output.next_token_token_ids_logprobs_val @@ -941,10 +955,6 @@ def stream_output_generation( if req is skip_req: continue - # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here. 
- if self.model_config.is_multimodal_gen and req.to_finish: - continue - if req.finished(): if req.finished_output: # With the overlap schedule, a request will try to output twice and hit this line twice @@ -963,8 +973,7 @@ def stream_output_generation( # origin stream_interval logic should_output = ( len(req.output_ids) % stream_interval == 1 - if not self.model_config.is_multimodal_gen - and stream_interval > 1 + if stream_interval > 1 else len(req.output_ids) % stream_interval == 0 ) @@ -974,8 +983,6 @@ def stream_output_generation( else: should_output = ( len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0 - if not self.model_config.is_multimodal_gen - else False ) if should_output: @@ -991,10 +998,7 @@ def stream_output_generation( decoded_texts.append(req.decoded_text) decode_ids, read_offset = req.init_incremental_detokenize() - if self.model_config.is_multimodal_gen: - decode_ids_list.append(decode_ids) - else: - decode_ids_list.append(decode_ids[req.send_decode_id_offset :]) + decode_ids_list.append(decode_ids[req.send_decode_id_offset :]) # Exclude the tokens after stop condition output_ids_ = req.output_ids_through_stop @@ -1030,6 +1034,8 @@ def stream_output_generation( and not req.input_logprob_sent # Decode server does not send input logprobs and self.disaggregation_mode != DisaggregationMode.DECODE + # Only send when input logprobs have been computed (after prefill) + and req.input_token_logprobs_val is not None ): input_token_logprobs_val.append(req.input_token_logprobs_val) input_token_logprobs_idx.append(req.input_token_logprobs_idx) @@ -1051,39 +1057,38 @@ def stream_output_generation( input_token_ids_logprobs_idx.append([]) if req.return_logprob: + logprob_end = max(len(output_ids_), 1) output_token_logprobs_val.append( req.output_token_logprobs_val[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_token_logprobs_idx.append( req.output_token_logprobs_idx[ - send_output_token_logprobs_offset: + 
send_output_token_logprobs_offset:logprob_end ] ) output_top_logprobs_val.append( req.output_top_logprobs_val[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_top_logprobs_idx.append( req.output_top_logprobs_idx[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_token_ids_logprobs_val.append( req.output_token_ids_logprobs_val[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) output_token_ids_logprobs_idx.append( req.output_token_ids_logprobs_idx[ - send_output_token_logprobs_offset: + send_output_token_logprobs_offset:logprob_end ] ) - req.send_output_token_logprobs_offset = len( - req.output_token_logprobs_val - ) + req.send_output_token_logprobs_offset = logprob_end else: output_token_logprobs_val.append([]) output_token_logprobs_idx.append([]) @@ -1120,8 +1125,6 @@ def stream_output_generation( # Send to detokenizer if reqs or is_idle_batch: - if self.model_config.is_multimodal_gen: - return self.send_to_detokenizer.send_output( BatchTokenIDOutput( rids=rids, diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index 8d01f7792583..113073e3bd93 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -358,6 +358,8 @@ def check_tree_cache(self: Scheduler): self.tree_cache.sanity_check() def self_check_during_idle(self: Scheduler): + if self.enable_hisparse and self.hisparse_coordinator.has_ongoing_staging(): + return if self.disaggregation_mode == DisaggregationMode.PREFILL: if len(self.disagg_prefill_inflight_queue) > 0: return @@ -371,9 +373,6 @@ def self_check_during_idle(self: Scheduler): queue_size += len(self.decode_offload_manager.ongoing_offload) if queue_size: return - elif self.enable_hisparse: - if self.hisparse_coordinator.has_ongoing_staging(): - 
return self.check_memory() self.check_tree_cache() diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index 087337684081..544c6094014c 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -59,8 +59,6 @@ LoadLoRAAdapterReqOutput, LoRAUpdateOutput, OpenSessionReqInput, - PinPrefixReqInput, - PinPrefixReqOutput, ProfileReq, ProfileReqOutput, ProfileReqType, @@ -216,9 +214,6 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs): self.detach_hicache_storage_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) - self.pin_prefix_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) self.profile_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -309,10 +304,6 @@ def _get_communicator_dispatcher(self: TokenizerManager): DetachHiCacheStorageReqOutput, self.detach_hicache_storage_communicator.handle_recv, ), - ( - PinPrefixReqOutput, - self.pin_prefix_communicator.handle_recv, - ), ( FlushCacheReqOutput, self.flush_cache_communicator.handle_recv, @@ -421,19 +412,6 @@ async def detach_hicache_storage( self.server_args.hicache_storage_backend_extra_config = None return out - async def pin_prefix( - self: TokenizerManager, token_ids: List[int], ttl_seconds: int = 300 - ) -> PinPrefixReqOutput: - """Pin a prefix by token_ids to resist eviction.""" - results = await self.pin_prefix_communicator( - PinPrefixReqInput(token_ids=token_ids, ttl_seconds=ttl_seconds) - ) - all_success, all_message = _Communicator.merge_results(results) - total = sum(r.nodes_pinned for r in results) - return PinPrefixReqOutput( - success=all_success, nodes_pinned=total, message=all_message - ) - async def start_profile( self: TokenizerManager, output_dir: Optional[str] = None, diff --git a/python/sglang/srt/managers/tokenizer_manager.py 
b/python/sglang/srt/managers/tokenizer_manager.py index 1c1efd92592c..63714087ad04 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -32,6 +32,7 @@ from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union import fastapi +import pybase64 import uvloop import zmq import zmq.asyncio @@ -43,13 +44,11 @@ from sglang.srt.environ import envs from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer -from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, ActiveRanksOutput, BatchEmbeddingOutput, - BatchMultimodalOutput, BatchStrOutput, BatchTokenIDOutput, BatchTokenizedEmbeddingReqInput, @@ -132,12 +131,15 @@ class ReqState: finished: bool event: asyncio.Event obj: Union[GenerateReqInput, EmbeddingReqInput] + + # For performance metrics time_stats: APIServerReqTimeStats last_completion_tokens: int = 1 ttft_observed: bool = False # For streaming output last_output_offset: int = 0 + last_text_offset: int = 0 # For incremental state update. # TODO(lianmin): do not initialize some lists if not needed. 
@@ -212,12 +214,12 @@ def __init__( # Init PD disaggregation and encoder disaggregation self.init_disaggregation() + # Subprocess liveness watchdog — set by Engine or http_server after construction + self._subprocess_watchdog = None + # Init metric collector and watchdog self.init_metric_collector_watchdog() - if self.enable_metrics: - start_cpu_monitor_thread("tokenizer") - # Init request dispatcher self.init_request_dispatcher() @@ -230,7 +232,6 @@ def init_model_config(self): self.served_model_name = server_args.served_model_name self.model_config = model_config_class.from_server_args(server_args) self.is_generation = self.model_config.is_generation - self.is_image_gen = self.model_config.is_image_gen self.context_len = self.model_config.context_len self.image_token_id = self.model_config.image_token_id self.max_req_input_len = None # Will be set later in engine.py @@ -265,12 +266,11 @@ def init_tokenizer_and_processor(self): # We create mm_processor for any skip_tokenizer_init to make sure we still encode # images even with skip_tokenizer_init=False. 
self.mm_processor = get_mm_processor( - self.model_config.hf_config, server_args, _processor, transport_mode - ) - self.mm_data_processor = AsyncMMDataProcessor( - self.mm_processor, - max_concurrent_calls=self.server_args.mm_max_concurrent_calls, - timeout_s=self.server_args.mm_per_request_timeout, + self.model_config.hf_config, + server_args, + _processor, + transport_mode, + model_config=self.model_config, ) if server_args.skip_tokenizer_init: @@ -338,10 +338,6 @@ def init_running_status(self): self.gracefully_exit = False self.last_receive_tstamp = real_time() - # For load balancing - self.current_load = 0 - self.current_load_lock = asyncio.Lock() - # Session self.session_futures = {} # session_id -> asyncio event @@ -440,6 +436,8 @@ def init_metric_collector_watchdog(self): collect_tokens_histogram=self.server_args.collect_tokens_histogram, ) + start_cpu_monitor_thread("tokenizer") + if self.server_args.gc_warning_threshold_secs > 0.0: configure_gc_warning(self.server_args.gc_warning_threshold_secs) self.soft_watchdog = Watchdog.create( @@ -457,7 +455,6 @@ def init_request_dispatcher(self): BatchStrOutput, BatchEmbeddingOutput, BatchTokenIDOutput, - BatchMultimodalOutput, ), self._handle_batch_output, ), @@ -488,7 +485,7 @@ async def generate_request( # Normalize the request obj.normalize_batch_and_arguments() self._set_default_priority(obj) - self._validate_rid(obj) + self._validate_rid_not_in_flight(obj) if isinstance(obj, GenerateReqInput) and obj.routed_dp_rank is not None: dp_size = self.server_args.dp_size @@ -730,10 +727,10 @@ async def _tokenize_one_request( need_wait_for_mm_inputs=obj.need_wait_for_mm_inputs, ) if mm_inputs is None: - mm_inputs: Dict = await self.mm_data_processor.process( + mm_inputs: Dict = await self.mm_processor.process_mm_data_async( image_data=obj.image_data, audio_data=obj.audio_data, - input_text_or_ids=(input_text or input_ids), + input_text=(input_text or input_ids), request_obj=obj, max_req_input_len=self.max_req_input_len, 
) @@ -744,16 +741,20 @@ async def _tokenize_one_request( ): # In language_only mode with zmq_to_scheduler, if we didn't dispatch # to encoder (e.g., only one image), process locally like non-language_only mode - mm_inputs: Dict = await self.mm_data_processor.process( + mm_inputs: Dict = await self.mm_processor.process_mm_data_async( image_data=obj.image_data, audio_data=obj.audio_data, - input_text_or_ids=(input_text or input_ids), + input_text=(input_text or input_ids), request_obj=obj, max_req_input_len=self.max_req_input_len, ) if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] + if mm_inputs and "token_type_ids" in mm_inputs: + token_type_ids = mm_inputs.pop("token_type_ids") + if not isinstance(token_type_ids, list): + token_type_ids = token_type_ids.flatten().tolist() if ( envs.SGLANG_MM_PRECOMPUTE_HASH.get() and mm_inputs @@ -770,20 +771,16 @@ async def _tokenize_one_request( obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids ) - def _validate_rid(self, obj: Union[GenerateReqInput, EmbeddingReqInput]) -> None: - """Validate the request ID (rid) uniqueness.""" - rid = obj.rid - if rid is None: + def _validate_rid_not_in_flight( + self, obj: Union[GenerateReqInput, EmbeddingReqInput] + ) -> None: + """Validate that request IDs are not already in flight.""" + if obj.rid is None: return - ids = rid if isinstance(rid, list) else [rid] - if len(ids) != len(set(ids)): - raise ValueError( - f"Duplicate request IDs detected within the request: {ids}" - ) - - for i in ids: - if i in self.rid_to_state: - raise ValueError(f"Duplicate request ID detected: {i}") + rids = obj.rid if isinstance(obj.rid, list) else [obj.rid] + conflicts = set(rids) & self.rid_to_state.keys() + if conflicts: + raise ValueError(f"Duplicate request IDs detected: {list(conflicts)}") def _validate_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int] @@ -982,6 +979,7 @@ def _create_tokenized_object( 
priority=obj.priority, extra_key=obj.extra_key, routing_key=obj.routing_key, + token_type_ids=token_type_ids, need_wait_for_mm_inputs=obj.need_wait_for_mm_inputs, num_items_assigned=obj.num_items_assigned, ) @@ -1147,90 +1145,109 @@ async def _wait_one_response( ) continue - # Drain all pending outputs atomically. For streaming, every - # chunk must be yielded to avoid dropping token deltas. For - # non-streaming only the latest cumulative output matters. - pending = state.out_list if is_stream else state.out_list[-1:] + # Drain all pending outputs atomically. + # With incremental streaming output, each chunk carries only a + # delta, so every queued chunk must be yielded to avoid dropping + # token ids. Without it, outputs are cumulative and only the + # latest chunk contains the full result, so we can safely skip + # intermediate ones. + incremental_stream = ( + is_stream and self.server_args.incremental_streaming_output + ) + out_list = state.out_list state.out_list = [] finished = state.finished state.event.clear() - for i, out in enumerate(pending): - is_last = i == len(pending) - 1 - - if finished and is_last: - # For non-streaming cases, response has not been sent yet (`response_sent_to_client_time` has not been set yet). - # Record response sent time right before we log finished results and metrics. - if not state.time_stats.response_sent_to_client_time: - state.time_stats.set_response_sent_to_client_time() - out["meta_info"][ - "response_sent_to_client_ts" - ] = state.time_stats.get_response_sent_to_client_realtime() - self.request_logger.log_finished_request( - obj, - out, - is_multimodal_gen=self.model_config.is_multimodal_gen, - request=request, + if incremental_stream and len(out_list) > 1: + if len(out_list) >= 20: + logger.warning( + "Streaming backlog: rid=%s, coalescing %d queued chunks into one. " + "This may inflate P99 ITL for affected requests.", + obj.rid, + len(out_list), ) + # Coalesce all deltas into a single chunk. 
Both text and + # output_ids are incremental, so we concatenate them; all + # other fields (meta_info, etc.) are taken from the last chunk. + out = dict(out_list[-1]) + if "output_ids" in out: + out["output_ids"] = [ + id for chunk in out_list for id in chunk["output_ids"] + ] + if "text" in out: + out["text"] = "".join(chunk["text"] for chunk in out_list) + else: + out = out_list[-1] - if self.request_metrics_exporter_manager.exporter_enabled(): - # Asynchronously write metrics for this request using the exporter manager. - asyncio.create_task( - self.request_metrics_exporter_manager.write_record(obj, out) - ) + if finished: + # For non-streaming cases, response has not been sent yet (`response_sent_to_client_time` has not been set yet). + # Record response sent time right before we log finished results and metrics. + if not state.time_stats.response_sent_to_client_time: + state.time_stats.set_response_sent_to_client_time() + out["meta_info"][ + "response_sent_to_client_ts" + ] = state.time_stats.get_response_sent_to_client_realtime() + self.request_logger.log_finished_request( + obj, + out, + request=request, + ) - # Check if this was an abort/error created by scheduler - if isinstance(out["meta_info"].get("finish_reason"), dict): - finish_reason = out["meta_info"]["finish_reason"] - if ( - finish_reason.get("type") == "abort" - and finish_reason.get("status_code") - == HTTPStatus.BAD_REQUEST - ): - if not is_stream: - raise ValueError(finish_reason["message"]) - else: - yield out - break - - if finish_reason.get("type") == "abort" and finish_reason.get( - "status_code" - ) in ( - HTTPStatus.SERVICE_UNAVAILABLE, - HTTPStatus.INTERNAL_SERVER_ERROR, - ): - # This is an abort request initiated by scheduler. - # Delete the key to prevent resending abort request to the scheduler and - # to ensure aborted request state is cleaned up. - if state.obj.rid in self.rid_to_state: - del self.rid_to_state[state.obj.rid] - - # Mark ongoing LoRA request as finished. 
- if self.server_args.enable_lora and state.obj.lora_path: - await self.lora_registry.release(state.obj.lora_id) - if not is_stream: - raise fastapi.HTTPException( - status_code=finish_reason["status_code"], - detail=finish_reason["message"], - ) - else: - yield out - break - yield out - break - - if is_stream: - # Record response sent time right before we send response. - if not state.time_stats.response_sent_to_client_time: - state.time_stats.set_response_sent_to_client_time() - out["meta_info"][ - "response_sent_to_client_ts" - ] = state.time_stats.get_response_sent_to_client_realtime() - yield out + if self.request_metrics_exporter_manager.exporter_enabled(): + # Asynchronously write metrics for this request using the exporter manager. + asyncio.create_task( + self.request_metrics_exporter_manager.write_record(obj, out) + ) - if finished: + # Check if this was an abort/error created by scheduler + if isinstance(out["meta_info"].get("finish_reason"), dict): + finish_reason = out["meta_info"]["finish_reason"] + if ( + finish_reason.get("type") == "abort" + and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST + ): + if not is_stream: + raise ValueError(finish_reason["message"]) + else: + yield out + break + + if finish_reason.get("type") == "abort" and finish_reason.get( + "status_code" + ) in ( + HTTPStatus.SERVICE_UNAVAILABLE, + HTTPStatus.INTERNAL_SERVER_ERROR, + ): + # This is an abort request initiated by scheduler. + # Delete the key to prevent resending abort request to the scheduler and + # to ensure aborted request state is cleaned up. + if state.obj.rid in self.rid_to_state: + del self.rid_to_state[state.obj.rid] + + # Mark ongoing LoRA request as finished. 
+ if self.server_args.enable_lora and state.obj.lora_path: + await self.lora_registry.release(state.obj.lora_id) + if not is_stream: + raise fastapi.HTTPException( + status_code=finish_reason["status_code"], + detail=finish_reason["message"], + ) + else: + yield out + break + yield out break + if is_stream: + # Record response sent time right before we send response. + if not state.time_stats.response_sent_to_client_time: + state.time_stats.set_response_sent_to_client_time() + out["meta_info"][ + "response_sent_to_client_ts" + ] = state.time_stats.get_response_sent_to_client_realtime() + yield out + if not is_stream: if ( request is not None @@ -1517,7 +1534,6 @@ def _handle_batch_output( recv_obj: Union[ BatchStrOutput, BatchEmbeddingOutput, - BatchMultimodalOutput, BatchTokenIDOutput, ], ): @@ -1574,7 +1590,11 @@ def _handle_batch_output( if getattr(recv_obj, "output_hidden_states", None): meta_info["hidden_states"] = recv_obj.output_hidden_states[i] if getattr(recv_obj, "routed_experts", None): - meta_info["routed_experts"] = recv_obj.routed_experts[i] + routed_experts_tensor = recv_obj.routed_experts[i] + if routed_experts_tensor is not None: + meta_info["routed_experts"] = pybase64.b64encode( + routed_experts_tensor.numpy().tobytes() + ).decode("utf-8") if getattr(recv_obj, "customized_info", None): for k, v in recv_obj.customized_info.items(): meta_info[k] = v[i] @@ -1589,12 +1609,15 @@ def _handle_batch_output( state.output_ids.extend(recv_obj.output_ids[i]) output_token_ids = state.output_ids[state.last_output_offset :] state.last_output_offset = len(state.output_ids) + output_text = state.text[state.last_text_offset :] + state.last_text_offset = len(state.text) else: state.output_ids.extend(recv_obj.output_ids[i]) output_token_ids = state.output_ids.copy() + output_text = state.text out_dict = { - "text": state.text, + "text": output_text, "output_ids": output_token_ids, "meta_info": meta_info, } @@ -1613,8 +1636,6 @@ def _handle_batch_output( 
"output_ids": output_token_ids, "meta_info": meta_info, } - elif isinstance(recv_obj, BatchMultimodalOutput): - raise NotImplementedError("BatchMultimodalOut not implemented") else: assert isinstance(recv_obj, BatchEmbeddingOutput) out_dict = { @@ -1883,7 +1904,6 @@ def _calculate_spec_decoding_metrics( recv_obj: Union[ BatchStrOutput, BatchEmbeddingOutput, - BatchMultimodalOutput, BatchTokenIDOutput, ], i: int, @@ -2289,14 +2309,14 @@ def _req_stats_init( external_trace_header = None if self.server_args.enable_trace: - if request: - external_trace_header = extract_trace_headers(request.headers) - obj.external_trace_header = external_trace_header - elif obj.external_trace_header: - # When the request comes form the rust grpc server or Engine there isn't a + if obj.external_trace_header: + # When the request comes from the rust grpc server or Engine there isn't a # real request object but we still need to propagate the trace context from # the trace context that is explicitly passed in external_trace_header = obj.external_trace_header + elif request: + external_trace_header = extract_trace_headers(request.headers) + obj.external_trace_header = external_trace_header if not hasattr(obj, "is_single") or obj.is_single: time_stats = APIServerReqTimeStats(disagg_mode=self.disaggregation_mode) @@ -2393,7 +2413,6 @@ def convert_to_span_attrs( recv_obj: Union[ BatchStrOutput, BatchEmbeddingOutput, - BatchMultimodalOutput, BatchTokenIDOutput, ], i: int, @@ -2405,9 +2424,10 @@ def convert_to_span_attrs( return span_attrs # Token usage attributes - span_attrs[SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS] = ( - recv_obj.completion_tokens[i] - ) + if not isinstance(recv_obj, BatchEmbeddingOutput): + span_attrs[SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS] = ( + recv_obj.completion_tokens[i] + ) span_attrs[SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS] = recv_obj.prompt_tokens[ i ] @@ -2539,6 +2559,10 @@ def running_phase_sigquit_handler(self, signum=None, frame=None): logger.error( 
f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed." ) + # Stop subprocess watchdog before killing processes to prevent false-positive + # crash detection during normal shutdown + if self.tokenizer_manager._subprocess_watchdog is not None: + self.tokenizer_manager._subprocess_watchdog.stop() self.tokenizer_manager.dump_requests_before_crash() kill_process_tree(os.getpid()) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 46a87aeb8762..7f63610da8ee 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -441,12 +441,6 @@ def _forward_batch_generation_dllm( can_run_cuda_graph=can_run_cuda_graph, ) - def get_remote_instance_transfer_engine_info(self): - return ( - self.model_runner.remote_instance_transfer_engine_session_id, - self.model_runner.remote_instance_transfer_engine_weight_info, - ) - def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, diff --git a/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py b/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py index f96c5ac8b477..a08b2a33ca78 100644 --- a/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py @@ -47,7 +47,6 @@ split_node_hash_value, ) from sglang.srt.observability.metrics_collector import StorageMetricsCollector -from sglang.srt.utils import bind_to_closest_numa_node_cuda if TYPE_CHECKING: from sglang.srt.mem_cache.cache_init_params import CacheInitParams @@ -104,9 +103,6 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): "switching to page first direct layout" ) - if not server_args.disable_hicache_numa_detect: - bind_to_closest_numa_node_cuda() - self.page_size = params.page_size self.hybrid_kv_cache = params.token_to_kv_pool_allocator.get_kvcache() if not isinstance(self.hybrid_kv_cache, HybridLinearKVPool): @@ -300,7 +296,7 @@ def write_backup(self, node: TreeNode, 
write_back=False): extra_pools=extra_pools, ) if host_indices is not None: - node.host_value = host_indices + node.host_value = host_indices.clone() if extra_pools is not None: self.mamba_backup_commit(node, extra_pools) assert len(node.host_value) > 0 @@ -515,13 +511,22 @@ def check_hicache_events(self): self.cache_controller.storage_backend.get_stats() ) - def _protect_host_node(self, node: TreeNode): + def _protect_host_node(self, node: TreeNode, protect_mamba: bool = True): node.protect_host() self.evictable_full_host_leaves.discard(node) + if protect_mamba: + node.protect_host_mamba() + if self.mamba_host_lru_list.in_list(node): + self.mamba_host_lru_list.remove_node(node) - def _release_host_node(self, node: TreeNode): + def _release_host_node(self, node: TreeNode, release_mamba: bool = True): node.release_host() - if node.host_ref_counter == 0: + if release_mamba: + node.release_host_mamba() + if node.host_mamba_ref_counter == 0 and node.mamba_host_value is not None: + if not self.mamba_host_lru_list.in_list(node): + self.mamba_host_lru_list.insert_mru(node) + if node.host_ref_counter == 0 and node.host_mamba_ref_counter == 0: self._update_full_host_leaf_status(node) def _discard_from_leaf_sets(self, node: TreeNode): @@ -548,6 +553,7 @@ def _update_full_host_leaf_status(self, node: TreeNode): or not node.backuped or node == self.root_node or node.host_ref_counter > 0 + or node.host_mamba_ref_counter > 0 ): self.evictable_full_host_leaves.discard(node) return @@ -636,7 +642,10 @@ def _evict_host_leaf(self, node: TreeNode) -> int: assert node.mamba_value is None, f"has device mamba, {node.id=}" assert ( node.host_ref_counter == 0 - ), f"in use, {node.id=} {node.host_ref_counter=}" + ), f"host kv in use, {node.id=} {node.host_ref_counter=}" + assert ( + node.host_mamba_ref_counter == 0 + ), f"host mamba in use, {node.id=} {node.host_mamba_ref_counter=}" full_num_evicted = self.cache_controller.evict_host(node.host_value) node.host_value = None @@ -669,7 +678,11 
@@ def _delete_tombstone_leaf(self, node: TreeNode) -> None: self._discard_from_leaf_sets(node) - if node.backuped and node.host_ref_counter == 0: + if ( + node.backuped + and node.host_ref_counter == 0 + and node.host_mamba_ref_counter == 0 + ): self.cache_controller.evict_host(node.host_value) node.host_value = None @@ -786,15 +799,21 @@ def evict_mamba_host(self, num_mamba_hosts: int) -> int: num_evicted = 0 while num_evicted < num_mamba_hosts and self.mamba_host_lru_list.in_list(x): x_next = self.mamba_host_lru_list.get_prev_no_lock(x) - if x.host_ref_counter > 0: - x = x_next - continue - if x in self.evictable_full_host_leaves: + # Leaf: evictable_full_host_leaves guarantees both counters == 0 + assert ( + x.host_ref_counter == 0 + ), f"evict host leaf: host_ref_counter != 0 with {x.id=} {x.host_ref_counter=}" + assert ( + x.host_mamba_ref_counter == 0 + ), f"evict host leaf: host_mamba_ref_counter != 0 with {x.id=} {x.host_mamba_ref_counter=}" self._evict_host_leaf(x) num_evicted += 1 else: - # internal host node: free host mamba only (tombstone) + # Internal host node + assert ( + x.host_mamba_ref_counter == 0 + ), f"evict host mamba internal: host_mamba_ref_counter != 0 with {x.id=} {x.host_mamba_ref_counter=}" self.mamba_host_lru_list.remove_node(x) self.mamba_pool_host.free(x.mamba_host_value) x.mamba_host_value = None @@ -834,7 +853,7 @@ def evict_mamba(self, mamba_num: int) -> int: # Leaf: evict KV + mamba atomically assert ( x.full_lock_ref == 0 - ), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=}" + ), f"evict device leaf: full_lock_ref mismatch with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}" x_next = self.mamba_lru_list.get_prev_no_lock(x) _, mamba_evicted = self._evict_device_leaf(x) @@ -1471,9 +1490,10 @@ def _force_release_pending_storage_ops(self): logger.exception("Force release pending prefetch ops failed.") try: - for ack_id, node in list(self.ongoing_backup.items()): + for ack_id, entry in list(self.ongoing_backup.items()): 
try: - self._release_host_node(node) + node, mamba_host_protected = entry + self._release_host_node(node, release_mamba=mamba_host_protected) except Exception: logger.exception( "Failed to release host protection for backup op %s", ack_id @@ -1525,7 +1545,8 @@ def _drain_backup(): ack_id = operation.id entry = self.ongoing_backup.pop(ack_id, None) if entry is not None: - self._release_host_node(entry) + node, mamba_host_protected = entry + self._release_host_node(node, release_mamba=mamba_host_protected) if log_metrics and self.enable_storage_metrics: self.storage_metrics_collector.log_backuped_tokens( operation.completed_tokens @@ -1725,8 +1746,9 @@ def write_backup_storage(self, node: TreeNode): prefix_keys, extra_pools=extra_pools, ) - self.ongoing_backup[operation_id] = node - self._protect_host_node(node) + mamba_host_protected = extra_pools is not None + self.ongoing_backup[operation_id] = (node, mamba_host_protected) + self._protect_host_node(node, protect_mamba=mamba_host_protected) def prefetch_from_storage( self, @@ -1747,7 +1769,7 @@ def prefetch_from_storage( ): return - self._protect_host_node(last_host_node) + self._protect_host_node(last_host_node, protect_mamba=False) # Allocate host KV memory host_indices = self._alloc_with_evict( @@ -1756,16 +1778,21 @@ def prefetch_from_storage( self.evict_host, ) if host_indices is None: - self._release_host_node(last_host_node) + self._release_host_node(last_host_node, release_mamba=False) return # Allocate host mamba slot extra_pools = self.mamba_prefetch_alloc(new_input_tokens, last_hash) if extra_pools is None: self.cache_controller.mem_pool_host.free(host_indices) - self._release_host_node(last_host_node) + self._release_host_node(last_host_node, release_mamba=False) return + # mamba is also being loaded, protect host mamba as well + last_host_node.protect_host_mamba() + if self.mamba_host_lru_list.in_list(last_host_node): + self.mamba_host_lru_list.remove_node(last_host_node) + operation = 
self.cache_controller.prefetch( req_id, host_indices, @@ -2017,7 +2044,7 @@ def mamba_backup_commit( return host_indices = transfers[0].host_indices if node.mamba_host_value is None and host_indices is not None: - node.mamba_host_value = host_indices + node.mamba_host_value = host_indices.clone() self.mamba_host_lru_list.insert_mru(node) def mamba_archive_transfers(self, node: TreeNode) -> Optional[list[PoolTransfer]]: diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 8e39fd0acc14..ba88c5f61b0a 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -307,17 +307,21 @@ def __init__( ): self.file_path = envs.SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR.get() or file_path - tp_rank, tp_size, model_name, is_mla_model = ( + tp_rank, tp_size, pp_rank, pp_size, model_name, is_mla_model = ( storage_config.tp_rank, storage_config.tp_size, + storage_config.pp_rank, + storage_config.pp_size, storage_config.model_name, storage_config.is_mla_model, ) model_name = "-".join(model_name.split("/")) if model_name else "" - if is_mla_model: - self.config_suffix = f"_{model_name}" - else: - self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}" + enable_pp = pp_size > 1 + self.config_suffix = f"_{model_name}" + if not is_mla_model: + self.config_suffix += f"_{tp_rank}_{tp_size}" + if enable_pp: + self.config_suffix += f"_{pp_size}_{pp_rank}" if not os.path.exists(self.file_path) and tp_rank == 0: os.makedirs(self.file_path) logger.info(f"Created HiCacheFile storage directory at {self.file_path}") diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 2d3e579bb15f..3c1e97daab1b 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -8,11 +8,10 @@ import threading import time from queue import Empty -from typing import TYPE_CHECKING, Dict, List, Optional, 
Tuple +from typing import TYPE_CHECKING, Dict, List, Optional import torch -from sglang.srt.environ import envs from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation from sglang.srt.mem_cache.base_prefix_cache import ( DecLockRefParams, @@ -45,7 +44,6 @@ ) from sglang.srt.mem_cache.utils import convert_to_bigram_key from sglang.srt.observability.metrics_collector import StorageMetricsCollector -from sglang.srt.utils import bind_to_closest_numa_node_cuda if TYPE_CHECKING: from sglang.srt.mem_cache.cache_init_params import CacheInitParams @@ -59,9 +57,6 @@ class HiRadixCache(RadixCache): def __init__(self, params: CacheInitParams, server_args: ServerArgs): self._enable_metrics_flag = params.enable_metrics - if not server_args.disable_hicache_numa_detect: - bind_to_closest_numa_node_cuda() - self.page_size = params.page_size self.kv_cache = params.token_to_kv_pool_allocator.get_kvcache() @@ -165,18 +160,6 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): self.evictable_host_leaves = set() - # Pin budget: max tokens that can be pinned = ratio * host pool capacity. 
- pin_ratio = envs.SGLANG_HICACHE_MAX_PINNED_RATIO.get() - if pin_ratio < 0 or pin_ratio >= 1: - raise ValueError( - f"SGLANG_HICACHE_MAX_PINNED_RATIO must be in [0, 1), got {pin_ratio}" - ) - self._max_pinned_tokens = int(self.token_to_kv_pool_host.size * pin_ratio) - self.pinned_size_ = 0 - logger.info( - "Pin budget: %d tokens (ratio=%.3f)", self._max_pinned_tokens, pin_ratio - ) - super().__init__(params=params) def shutdown(self): @@ -593,7 +576,6 @@ def reset(self): # Clear per-request tracking dicts self.prefetch_loaded_tokens_by_reqid.clear() self.evictable_host_leaves.clear() - self.pinned_size_ = 0 super().reset() def get_height(self, node: TreeNode): @@ -637,7 +619,7 @@ def write_backup(self, node: TreeNode, write_back=False): node_id=node.id, ) if host_indices is not None: - node.host_value = host_indices + node.host_value = host_indices.clone() assert len(node.host_value) > 0 self.ongoing_write_through[node.id] = node if not write_back: @@ -733,79 +715,6 @@ def loading_check(self): def evictable_size(self): return self.evictable_size_ - def _is_pinned(self, node: TreeNode) -> bool: - """Check if a node has an active (non-expired) pin.""" - return node.pin_expiry > 0 and time.monotonic() <= node.pin_expiry - - def _clear_pin(self, node: TreeNode): - """Clear expired pin state and release host_ref_counter hold.""" - if node.pin_expiry > 0: - self.pinned_size_ = max(0, self.pinned_size_ - len(node.key)) - node.host_ref_counter = max(0, node.host_ref_counter - 1) - node.pin_expiry = 0.0 - node.pin_ttl = 0 - - def pin_prefix( - self, token_ids: List[int], ttl_seconds: int = 300 - ) -> Tuple[int, Optional[str]]: - """Pin nodes along a prefix path. 
Returns (nodes_pinned, reject_reason).""" - if self.disable or not token_ids: - return (0, None) - - key, _ = self.maybe_bigram_convert(self._to_radix_key(token_ids)) - if self.page_size != 1: - page_aligned_len = len(key) // self.page_size * self.page_size - key = key[:page_aligned_len] - if len(key) == 0: - return (0, None) - - expiry = time.monotonic() + ttl_seconds - nodes_pinned = 0 - budget_exceeded = False - node = self.root_node - child_key = self.get_child_key_fn(key) - - while len(key) > 0 and child_key in node.children: - child = node.children[child_key] - prefix_len = self.key_match_fn(child.key, key) - - # First pin on this node: check budget, then acquire hold - if child.pin_expiry == 0: - if self.pinned_size_ + len(child.key) > self._max_pinned_tokens: - budget_exceeded = True - break - child.host_ref_counter += 1 - self.pinned_size_ += len(child.key) - - # Eagerly back up to host so eviction finds pinned nodes - # already backuped and never enters the write_back drain - # path, which would leak lock_ref on in-flight - # write-through entries. No-op under write_back policy. - self._inc_hit_count(child) - - # Extend expiry and store TTL for refresh-on-hit - child.pin_expiry = max(child.pin_expiry, expiry) - child.pin_ttl = max(child.pin_ttl, ttl_seconds) - nodes_pinned += 1 - - if prefix_len < len(child.key): - break - - node = child - key = key[prefix_len:] - if len(key): - child_key = self.get_child_key_fn(key) - - logger.info( - "[PIN] pin_prefix: nodes_pinned=%d, ttl=%ds", nodes_pinned, ttl_seconds - ) - if budget_exceeded: - msg = f"Pin budget exhausted ({self.pinned_size_}/{self._max_pinned_tokens} tokens pinned)" - if nodes_pinned == 0: - return (0, msg) - return (nodes_pinned, f"prefix partially pinned; {msg}") - return (nodes_pinned, None) - def _to_radix_key(self, token_ids: List[int]) -> RadixKey: """Convert raw token_ids to a RadixKey for tree walking. 
@@ -884,26 +793,6 @@ def evict(self, params: EvictParams) -> EvictResult: if x.lock_ref > 0: continue - if self._is_pinned(x): - # Still active: demote to host if possible - if x.backuped: - num_evicted += self._evict_backuped(x) - continue - written = self.write_backup(x, write_back=True) - if written > 0: - num_evicted += written - write_back_nodes.append(x) - continue # backup succeeded, pin holds on host - # Host full -- drop pin so GPU can be freed - self._clear_pin(x) - logger.warning( - "[PIN] evict: can't backup node %d to host, releasing pin", - x.id, - ) - elif x.pin_expiry > 0: - # Expired pin: clear and fall through to normal eviction - self._clear_pin(x) - if not x.backuped: if self.cache_controller.write_policy == "write_back": # write to host if the node is not backuped @@ -969,11 +858,6 @@ def evict_host(self, num_tokens: int): if not x.evicted: continue - # Expire stale pins before checking host_ref_counter - if x.pin_expiry > 0 and time.monotonic() > x.pin_expiry: - self._clear_pin(x) - - # node is protected from eviction as it has ongoing prefetch, backup, or pin if x.host_ref_counter > 0: continue @@ -1352,9 +1236,6 @@ def _insert_helper_host( while len(key) > 0 and child_key in node.children.keys(): node = node.children[child_key] node.last_access_time = time.monotonic() - # Refresh pin TTL on host insert hit - if self._is_pinned(node): - node.pin_expiry = time.monotonic() + node.pin_ttl prefix_len = self.key_match_fn(node.key, key) key = key[prefix_len:] host_value = host_value[prefix_len:] @@ -1390,9 +1271,6 @@ def _match_prefix_helper(self, node: TreeNode, key: RadixKey): while len(key) > 0 and child_key in node.children.keys(): child = node.children[child_key] child.last_access_time = time.monotonic() - # Refresh pin TTL on cache hit - if self._is_pinned(child): - child.pin_expiry = time.monotonic() + child.pin_ttl prefix_len = self.key_match_fn(child.key, key) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, 
prefix_len) @@ -1417,11 +1295,6 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int): new_node.children = {self.get_child_key_fn(key[split_len:]): child} new_node.parent = child.parent new_node.lock_ref = child.lock_ref - new_node.pin_expiry = child.pin_expiry - new_node.pin_ttl = child.pin_ttl - # If child is pinned, new parent inherits a host_ref_counter hold - if child.pin_expiry > 0: - new_node.host_ref_counter += 1 new_node.key = child.key[:split_len] new_node.hit_count = child.hit_count diff --git a/python/sglang/srt/mem_cache/hisparse_memory_pool.py b/python/sglang/srt/mem_cache/hisparse_memory_pool.py index 5af8d257ad6b..0f2a53917175 100644 --- a/python/sglang/srt/mem_cache/hisparse_memory_pool.py +++ b/python/sglang/srt/mem_cache/hisparse_memory_pool.py @@ -193,11 +193,38 @@ def alloc(self, need_size: int): "Page size = 1 is not supported in HiSparse allocator" ) + def alloc_logical_only( + self, + prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, + extend_num_tokens: int, + ): + """Allocate only logical indices without hisparse device indices. + + Used in the direct-to-host transfer path where KV data is written + directly to host memory by the prefill node, skipping GPU staging. + """ + return self.logical_attn_allocator.alloc_extend( + prefix_lens, + prefix_lens_cpu, + seq_lens, + seq_lens_cpu, + last_loc, + extend_num_tokens, + ) + def alloc_device_buffer(self, allocated_indices, need_size: int): assert need_size % self.page_size == 0 # clear original reference and isolate the buffer from outside addressing, allocate new buffer if needed hisparse_indices = self.full_to_hisparse_device_index_mapping[allocated_indices] self.full_to_hisparse_device_index_mapping[allocated_indices] = 0 + # Filter valid (non-zero) hisparse indices. + # In the direct-to-host path, mapping is all zeros since no hisparse + # device indices were pre-allocated. 
+ hisparse_indices = hisparse_indices[hisparse_indices > 0] if len(hisparse_indices) >= need_size: buffer_indices = hisparse_indices[:need_size] self.free_hisparse_indices(hisparse_indices[need_size:]) diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index d02eb7d9b3d3..d07702cf1efd 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -87,6 +87,7 @@ def __init__(self, id: Optional[int] = None): self.hit_count = 0 self.host_ref_counter = 0 + self.host_mamba_ref_counter = 0 # store the host indices of KV cache self.host_value = None # store hash values of each pages @@ -122,16 +123,27 @@ def mamba_backuped(self): return self.mamba_host_value is not None def protect_host(self): - """Protect the host value from eviction.""" + """Protect the host KV value from eviction.""" self.host_ref_counter += 1 def release_host(self): - """Release the host value, allowing it to be evicted.""" + """Release the host KV value, allowing it to be evicted.""" if self.host_ref_counter > 0: self.host_ref_counter -= 1 else: raise RuntimeError("Host reference counter is already zero.") + def protect_host_mamba(self): + """Protect the host mamba value from eviction.""" + self.host_mamba_ref_counter += 1 + + def release_host_mamba(self): + """Release the host mamba value, allowing it to be evicted.""" + if self.host_mamba_ref_counter > 0: + self.host_mamba_ref_counter -= 1 + else: + raise RuntimeError("Host mamba reference counter is already zero.") + def get_last_hash_value(self) -> Optional[str]: """Returns the hash value of the last page in this node.""" if self.hash_value is None or len(self.hash_value) == 0: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 72cb64a1257c..439b35f1d3a9 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -45,11 +45,13 
@@ quantize_k_cache, quantize_k_cache_separate, ) +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.utils import ( get_mla_kv_buffer_triton, maybe_init_custom_mem_pool, set_mla_kv_buffer_triton, + set_mla_kv_buffer_triton_fp8_quant, set_mla_kv_scale_buffer_triton, ) from sglang.srt.utils import ( @@ -75,6 +77,7 @@ _is_cpu = is_cpu() _cpu_has_amx_support = cpu_has_amx_support() _is_hip = is_hip() +_is_fp8_fnuz = is_fp8_fnuz() def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]): @@ -220,6 +223,7 @@ def __init__( size: int, spec_state_size: int, cache_params: BaseLinearStateParams, + mamba_layer_ids: List[int], device: str, enable_memory_saver: bool = False, speculative_num_draft_tokens: Optional[int] = None, @@ -231,7 +235,7 @@ def __init__( self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=enable_memory_saver ) - num_mamba_layers = len(cache_params.layers) + num_mamba_layers = len(mamba_layer_ids) self.size = size self.device = device @@ -463,9 +467,11 @@ def __init__( device: str, enable_memory_saver: bool, cache_params: BaseLinearStateParams, + mamba_layer_ids: List[int], enable_mamba_extra_buffer: bool, speculative_num_draft_tokens: int = None, enable_overlap_schedule: bool = True, + start_layer: Optional[int] = None, ): super().__init__( size=size, @@ -477,13 +483,13 @@ def __init__( self.mamba_ping_pong_track_buffer_size = 2 if enable_overlap_schedule else 1 self.enable_mamba_extra_buffer = enable_mamba_extra_buffer self.enable_memory_saver = enable_memory_saver - # TODO: Support PP - self.start_layer = 0 + self.start_layer = start_layer if start_layer is not None else 0 self.layer_transfer_counter = None self._init_mamba_pool( size=mamba_size, mamba_spec_state_size=mamba_spec_state_size, cache_params=cache_params, + mamba_layer_ids=mamba_layer_ids, device=device, enable_mamba_extra_buffer=enable_mamba_extra_buffer, 
speculative_num_draft_tokens=speculative_num_draft_tokens, @@ -494,6 +500,7 @@ def _init_mamba_pool( size: int, mamba_spec_state_size: int, cache_params: BaseLinearStateParams, + mamba_layer_ids: List[int], device: str, enable_mamba_extra_buffer: bool, speculative_num_draft_tokens: int = None, @@ -502,11 +509,12 @@ def _init_mamba_pool( size=size, spec_state_size=mamba_spec_state_size, cache_params=cache_params, + mamba_layer_ids=mamba_layer_ids, device=device, enable_memory_saver=self.enable_memory_saver, speculative_num_draft_tokens=speculative_num_draft_tokens, ) - self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)} + self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layer_ids)} self.device = device self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros( @@ -1244,13 +1252,14 @@ def __init__( use_mla: bool = False, kv_lora_rank: int = None, qk_rope_head_dim: int = None, + start_layer: Optional[int] = None, ): self.size = size self.dtype = dtype self.device = device self.full_layer_nums = len(full_attention_layer_ids) self.page_size = page_size - self.start_layer = 0 # TODO: Support PP + self.start_layer = start_layer if start_layer is not None else 0 self.layer_transfer_counter = None self.head_num = head_num self.head_dim = head_dim @@ -1576,22 +1585,36 @@ def set_mla_kv_buffer( layer_id = layer.layer_id if self.nsa_kv_cache_store_fp8: - # OPTIMIZATION: Quantize k_nope and k_rope separately to avoid concat overhead - # This also enables reuse of set_mla_kv_buffer_triton two-tensor write path - # quantize_k_cache_separate returns (nope_part, rope_part) as uint8 bytes - cache_k_nope_fp8, cache_k_rope_fp8 = quantize_k_cache_separate( - cache_k_nope, cache_k_rope - ) + if _is_hip: + # HIP FP8 path uses raw MLA KV layout (nope + rope) without per-block scales. + # Fuse BF16/FP16 -> FP8 cast with paged KV write. 
+ fp8_dtype = ( + torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn + ) + set_mla_kv_buffer_triton_fp8_quant( + self.kv_buffer[layer_id - self.start_layer], + loc, + cache_k_nope, + cache_k_rope, + fp8_dtype, + ) + else: + # OPTIMIZATION: Quantize k_nope and k_rope separately to avoid concat overhead + # This also enables reuse of set_mla_kv_buffer_triton two-tensor write path + # quantize_k_cache_separate returns (nope_part, rope_part) as uint8 bytes + cache_k_nope_fp8, cache_k_rope_fp8 = quantize_k_cache_separate( + cache_k_nope, cache_k_rope + ) - # Reuse existing two-tensor write kernel (works with FP8 byte layout) - # cache_k_nope_fp8: (num_tokens, 1, 528) uint8 [nope_fp8(512) | scales(16)] - # cache_k_rope_fp8: (num_tokens, 1, 128) uint8 [rope_bf16_bytes(128)] - set_mla_kv_buffer_triton( - self.kv_buffer[layer_id - self.start_layer], - loc, - cache_k_nope_fp8, - cache_k_rope_fp8, - ) + # Reuse existing two-tensor write kernel (works with FP8 byte layout) + # cache_k_nope_fp8: (num_tokens, 1, 528) uint8 [nope_fp8(512) | scales(16)] + # cache_k_rope_fp8: (num_tokens, 1, 128) uint8 [rope_bf16_bytes(128)] + set_mla_kv_buffer_triton( + self.kv_buffer[layer_id - self.start_layer], + loc, + cache_k_nope_fp8, + cache_k_rope_fp8, + ) else: if cache_k_nope.dtype != self.dtype: cache_k_nope = cache_k_nope.to(self.dtype) diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 10fc35239d89..1a9708c41430 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -21,9 +21,15 @@ from sglang.jit_kernel.hicache import ( transfer_hicache_all_layer as jit_transfer_hicache_all_layer, ) +from sglang.jit_kernel.hicache import ( + transfer_hicache_all_layer_mla as jit_transfer_hicache_all_layer_mla, +) from sglang.jit_kernel.hicache import ( transfer_hicache_one_layer as jit_transfer_hicache_one_layer, ) +from sglang.jit_kernel.hicache import ( + 
transfer_hicache_one_layer_mla as jit_transfer_hicache_one_layer_mla, +) from sglang.srt.mem_cache.memory_pool import ( KVCache, MambaPool, @@ -309,8 +315,16 @@ def __init__( element_size=self.element_dim * self.dtype.itemsize ) - self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)] - self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)] + if self.layout == "page_first": + # Transpose [page, layer, ...] -> [layer, page, ...] to get per-layer views + # This swaps strides without copying data + k_transposed = self.k_buffer.transpose(0, 1) + v_transposed = self.v_buffer.transpose(0, 1) + self.k_data_refs = [k_transposed[i] for i in range(self.layer_num)] + self.v_data_refs = [v_transposed[i] for i in range(self.layer_num)] + else: + self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)] + self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)] self.k_data_ptrs = torch.tensor( [x.data_ptr() for x in self.k_data_refs], dtype=torch.uint64, @@ -409,17 +423,31 @@ def load_to_device_per_layer( item_size=self.token_stride_size, ) elif self.layout == "page_first": - transfer_kv_per_layer_pf_lf( - src_k=self.k_buffer, - dst_k=device_pool.k_buffer[layer_id], - src_v=self.v_buffer, - dst_v=device_pool.v_buffer[layer_id], - src_indices=host_indices, - dst_indices=device_indices, - layer_id=layer_id, - item_size=self.token_stride_size, - src_layout_dim=self.layout_dim, - ) + if self.can_use_jit: + # Transpose [page, layer, ...] -> [layer, page, ...] then + # index by layer_id to get a per-layer view with strided layout. + # The kernel handles different src/dst strides automatically. 
+ jit_transfer_hicache_one_layer( + k_cache_dst=device_pool.k_buffer[layer_id], + v_cache_dst=device_pool.v_buffer[layer_id], + k_cache_src=self.k_data_refs[layer_id], + v_cache_src=self.v_data_refs[layer_id], + indices_dst=device_indices, + indices_src=host_indices, + element_dim=self.element_dim, + ) + else: + transfer_kv_per_layer_pf_lf( + src_k=self.k_buffer, + dst_k=device_pool.k_buffer[layer_id], + src_v=self.v_buffer, + dst_v=device_pool.v_buffer[layer_id], + src_indices=host_indices, + dst_indices=device_indices, + layer_id=layer_id, + item_size=self.token_stride_size, + src_layout_dim=self.layout_dim, + ) elif self.layout == "page_head": transfer_kv_per_layer_ph_lf( src_k=self.k_buffer, @@ -510,17 +538,32 @@ def backup_from_device_all_layer( num_layers=self.layer_num, ) elif self.layout == "page_first": - transfer_kv_all_layer_lf_pf( - src_k_layers=device_pool.k_data_ptrs, - dst_k=self.k_buffer, - src_v_layers=device_pool.v_data_ptrs, - dst_v=self.v_buffer, - src_indices=device_indices, - dst_indices=host_indices, - item_size=self.token_stride_size, - dst_layout_dim=self.layout_dim, - num_layers=self.layer_num, - ) + if self.can_use_jit: + # Use transposed data ptrs so the kernel writes to + # [layer, page, item] view with stride layout_dim per token. 
+ jit_transfer_hicache_all_layer( + k_ptr_dst=self.k_data_ptrs, + v_ptr_dst=self.v_data_ptrs, + indices_dst=host_indices, + k_ptr_src=device_pool.k_data_ptrs, + v_ptr_src=device_pool.v_data_ptrs, + indices_src=device_indices, + kv_cache_src_stride_bytes=self.token_stride_size, + kv_cache_dst_stride_bytes=self.layout_dim, + element_size=self.element_dim * self.dtype.itemsize, + ) + else: + transfer_kv_all_layer_lf_pf( + src_k_layers=device_pool.k_data_ptrs, + dst_k=self.k_buffer, + src_v_layers=device_pool.v_data_ptrs, + dst_v=self.v_buffer, + src_indices=device_indices, + dst_indices=host_indices, + item_size=self.token_stride_size, + dst_layout_dim=self.layout_dim, + num_layers=self.layer_num, + ) elif self.layout == "page_head": transfer_kv_all_layer_lf_ph( src_k_layers=device_pool.k_data_ptrs, @@ -766,13 +809,31 @@ def __init__( device, allocator_type, ) - self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)] + self.can_use_jit = _is_cuda and can_use_hicache_jit_kernel( + element_size=self.kv_cache_dim * self.dtype.itemsize + ) + + if self.layout == "page_first" and self.can_use_jit: + # Transpose [page, layer, ...] -> [layer, page, ...] 
to get per-layer views + # This swaps strides without copying data + transposed = self.kv_buffer.transpose(0, 1) + self.data_refs = [transposed[i] for i in range(self.layer_num)] + else: + self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)] self.data_ptrs = torch.tensor( [x.data_ptr() for x in self.data_refs], dtype=torch.uint64, device=self.device_pool.device, ) + def get_contiguous_buf_infos(self): + """Return (data_ptrs, data_lens, item_lens) in the same format as device pool, + for registering host memory with the disaggregation transfer engine.""" + data_ptrs = [int(self.data_ptrs[i].item()) for i in range(self.layer_num)] + data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)] + item_lens = [self.token_stride_size] * self.layer_num + return data_ptrs, data_lens, item_lens + def get_size_per_token(self): self.kv_lora_rank = self.device_pool.kv_lora_rank self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim @@ -864,23 +925,41 @@ def load_to_device_per_layer( ): if io_backend == "kernel": if self.layout == "layer_first": - transfer_kv_per_layer_mla( - src=self.kv_buffer[layer_id], - dst=device_pool.kv_buffer[layer_id], - src_indices=host_indices, - dst_indices=device_indices, - item_size=self.token_stride_size, - ) + if self.can_use_jit: + jit_transfer_hicache_one_layer_mla( + cache_dst=device_pool.kv_buffer[layer_id], + cache_src=self.kv_buffer[layer_id], + indices_dst=device_indices, + indices_src=host_indices, + element_dim=self.kv_cache_dim, + ) + else: + transfer_kv_per_layer_mla( + src=self.kv_buffer[layer_id], + dst=device_pool.kv_buffer[layer_id], + src_indices=host_indices, + dst_indices=device_indices, + item_size=self.token_stride_size, + ) elif self.layout == "page_first": - transfer_kv_per_layer_mla_pf_lf( - src=self.kv_buffer, - dst=device_pool.kv_buffer[layer_id], - src_indices=host_indices, - dst_indices=device_indices, - layer_id=layer_id, - item_size=self.token_stride_size, - src_layout_dim=self.layout_dim, - 
) + if self.can_use_jit: + jit_transfer_hicache_one_layer_mla( + cache_dst=device_pool.kv_buffer[layer_id], + cache_src=self.data_refs[layer_id], + indices_dst=device_indices, + indices_src=host_indices, + element_dim=self.kv_cache_dim, + ) + else: + transfer_kv_per_layer_mla_pf_lf( + src=self.kv_buffer, + dst=device_pool.kv_buffer[layer_id], + src_indices=host_indices, + dst_indices=device_indices, + layer_id=layer_id, + item_size=self.token_stride_size, + src_layout_dim=self.layout_dim, + ) else: raise ValueError(f"Unsupported layout: {self.layout}") elif io_backend == "direct": @@ -929,24 +1008,46 @@ def backup_from_device_all_layer( ): if io_backend == "kernel": if self.layout == "layer_first": - transfer_kv_all_layer_mla( - src_layers=device_pool.data_ptrs, - dst_layers=self.data_ptrs, - src_indices=device_indices, - dst_indices=host_indices, - item_size=self.token_stride_size, - num_layers=self.layer_num, - ) + if self.can_use_jit: + jit_transfer_hicache_all_layer_mla( + ptr_dst=self.data_ptrs, + indices_dst=host_indices, + ptr_src=device_pool.data_ptrs, + indices_src=device_indices, + cache_dst_stride_bytes=self.token_stride_size, + cache_src_stride_bytes=self.token_stride_size, + element_size=self.kv_cache_dim * self.dtype.itemsize, + ) + else: + transfer_kv_all_layer_mla( + src_layers=device_pool.data_ptrs, + dst_layers=self.data_ptrs, + src_indices=device_indices, + dst_indices=host_indices, + item_size=self.token_stride_size, + num_layers=self.layer_num, + ) elif self.layout == "page_first": - transfer_kv_all_layer_mla_lf_pf( - src_layers=device_pool.data_ptrs, - dst=self.kv_buffer, - src_indices=device_indices, - dst_indices=host_indices, - item_size=self.token_stride_size, - dst_layout_dim=self.layout_dim, - num_layers=self.layer_num, - ) + if self.can_use_jit: + jit_transfer_hicache_all_layer_mla( + ptr_dst=self.data_ptrs, + indices_dst=host_indices, + ptr_src=device_pool.data_ptrs, + indices_src=device_indices, + 
cache_src_stride_bytes=self.token_stride_size, + cache_dst_stride_bytes=self.layout_dim, + element_size=self.kv_cache_dim * self.dtype.itemsize, + ) + else: + transfer_kv_all_layer_mla_lf_pf( + src_layers=device_pool.data_ptrs, + dst=self.kv_buffer, + src_indices=device_indices, + dst_indices=host_indices, + item_size=self.token_stride_size, + dst_layout_dim=self.layout_dim, + num_layers=self.layer_num, + ) else: raise ValueError(f"Unsupported layout: {self.layout}") elif io_backend == "direct": diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 501fc4223eee..7d1616037243 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -128,10 +128,6 @@ def __init__(self, id: Optional[int] = None, priority: int = 0): self.key: RadixKey = None self.value: Optional[torch.Tensor] = None self.lock_ref = 0 - self.pin_expiry: float = ( - 0.0 # absolute expiry time (time.monotonic()), 0 = not pinned - ) - self.pin_ttl: int = 0 # original TTL in seconds, for refresh-on-hit self.last_access_time = time.monotonic() self.creation_time = time.monotonic() diff --git a/python/sglang/srt/mem_cache/utils.py b/python/sglang/srt/mem_cache/utils.py index 3aec3dd89b51..ba08819c3c63 100644 --- a/python/sglang/srt/mem_cache/utils.py +++ b/python/sglang/srt/mem_cache/utils.py @@ -109,6 +109,93 @@ def set_mla_kv_buffer_triton( ) +@triton.jit +def set_mla_kv_buffer_fp8_quant_kernel( + kv_buffer_fp8_ptr, + cache_k_nope_ptr, + cache_k_rope_ptr, + loc_ptr, + buffer_stride: tl.constexpr, + nope_stride: tl.constexpr, + rope_stride: tl.constexpr, + nope_dim: tl.constexpr, + rope_dim: tl.constexpr, + BLOCK: tl.constexpr, +): + """Fuse BF16/FP16->FP8 cast with paged KV write.""" + pid_loc = tl.program_id(0) + pid_blk = tl.program_id(1) + + base = pid_blk * BLOCK + offs = base + tl.arange(0, BLOCK) + total_dim = nope_dim + rope_dim + mask = offs < total_dim + + loc = tl.load(loc_ptr + 
pid_loc).to(tl.int64) + dst_ptr = kv_buffer_fp8_ptr + loc * buffer_stride + offs + + if base + BLOCK <= nope_dim: + src = tl.load( + cache_k_nope_ptr + pid_loc * nope_stride + offs, + mask=mask, + other=0.0, + ) + elif base >= nope_dim: + offs_rope = offs - nope_dim + src = tl.load( + cache_k_rope_ptr + pid_loc * rope_stride + offs_rope, + mask=mask, + other=0.0, + ) + else: + is_nope = offs < nope_dim + src_nope = tl.load( + cache_k_nope_ptr + pid_loc * nope_stride + offs, + mask=mask & is_nope, + other=0.0, + ) + src_rope = tl.load( + cache_k_rope_ptr + pid_loc * rope_stride + (offs - nope_dim), + mask=mask & ~is_nope, + other=0.0, + ) + src = tl.where(is_nope, src_nope, src_rope) + + # Destination pointer is FP8-typed view; tl.store performs downcast. + tl.store(dst_ptr, src, mask=mask) + + +def set_mla_kv_buffer_triton_fp8_quant( + kv_buffer: torch.Tensor, + loc: torch.Tensor, + cache_k_nope: torch.Tensor, + cache_k_rope: torch.Tensor, + fp8_dtype: torch.dtype, +): + """Fuse BF16/FP16 MLA K quantization with paged KV write.""" + kv_buffer_fp8 = kv_buffer.view(fp8_dtype) + + nope_dim = cache_k_nope.shape[-1] + rope_dim = cache_k_rope.shape[-1] + total_dim = nope_dim + rope_dim + BLOCK = 128 + n_loc = loc.numel() + grid = (n_loc, triton.cdiv(total_dim, BLOCK)) + + set_mla_kv_buffer_fp8_quant_kernel[grid]( + kv_buffer_fp8, + cache_k_nope, + cache_k_rope, + loc, + kv_buffer_fp8.stride(0), + cache_k_nope.stride(0), + cache_k_rope.stride(0), + nope_dim, + rope_dim, + BLOCK=BLOCK, + ) + + @triton.jit def set_mla_kv_scale_buffer_kernel( kv_buffer_ptr, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index ea8c93963a32..066d4fedac43 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -590,7 +590,12 @@ def __init__(self, model_runner: ModelRunner): else self.dllm_config.block_size ) - self.encoder_len_fill_value = 0 + # 
Non-zero encoder length ensures cross-attention kernels are captured in the graph. + self.encoder_len_fill_value = ( + getattr(model_runner.model_config.hf_config, "max_source_positions", 0) + if self.is_encoder_decoder + else 0 + ) if self.enable_torch_compile: set_torch_compile_config() diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index bec834965397..eaecdc54bcf4 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -715,15 +715,21 @@ def _compute_spec_mrope_positions( else: # target_verify or draft_decode seq_positions = batch.spec_info.positions.view(batch_size, -1) - mrope_deltas = [ - ( - torch.tensor([0], dtype=torch.int64) - if mm_inputs[i] is None - else mm_inputs[i].mrope_position_delta.squeeze(0) + # Split text-only and mixed batches here because SpecV2 text-only batches can avoid an extra D2H. + if all(mm_input is None for mm_input in mm_inputs): + mrope_delta_tensor = torch.zeros( + (batch_size, 1), dtype=torch.int64, device=device ) - for i in range(batch_size) - ] - mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device) + else: + mrope_deltas = [ + ( + torch.zeros(1, dtype=torch.int64) + if mm_inputs[i] is None + else mm_inputs[i].mrope_position_delta.squeeze(0) + ) + for i in range(batch_size) + ] + mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device) next_input_positions = ( (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1) ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fc9afafac90b..af14370f8267 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -520,9 +520,11 @@ def initialize(self, pre_model_load_memory: float): and self.remote_instance_transfer_engine is not None and 
self.remote_instance_transfer_engine_weight_info is None ): + # Register memory and upstream the transfer engine info to the bootstrap server self.remote_instance_transfer_engine_weight_info = register_memory_region( self.model, self.remote_instance_transfer_engine ) + self._register_to_engine_info_bootstrap() # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to @@ -700,6 +702,52 @@ def remote_instance_init_transfer_engine(self): local_ip, self.remote_instance_transfer_engine.get_rpc_port() ).to_host_port_str() + def _register_to_engine_info_bootstrap(self): + """Register transfer engine info with the EngineInfoBootstrapServer via HTTP PUT. + + The bootstrap server runs on node_rank==0. For multi-node setups, the + host is derived from dist_init_addr. For single-node, use 127.0.0.1. + """ + import requests as http_requests + + if self.server_args.dist_init_addr: + # Multi-node: bootstrap server is on the head node (node_rank==0). + # Derive host from dist_init_addr (shared across all nodes). 
+ bootstrap_host = ( + NetworkAddress.parse(self.server_args.dist_init_addr).resolved().host + ) + else: + bootstrap_host = "127.0.0.1" + + bootstrap_port = self.server_args.engine_info_bootstrap_port + bootstrap_na = NetworkAddress(bootstrap_host, bootstrap_port) + url = f"{bootstrap_na.to_url()}/register_transfer_engine_info" + + payload = { + "tp_rank": self.tp_rank, + "transfer_engine_info": { + "session_id": self.remote_instance_transfer_engine_session_id, + "weights_info_dict": self.remote_instance_transfer_engine_weight_info, + }, + } + + try: + resp = http_requests.put(url, json=payload, timeout=5) + if resp.status_code == 200: + logger.info( + f"Registered transfer engine info for tp_rank={self.tp_rank} " + f"with bootstrap server at {bootstrap_na}" + ) + else: + logger.error( + f"Failed to register transfer engine info for tp_rank={self.tp_rank}: " + f"{resp.status_code}, {resp.text}" + ) + except Exception as e: + logger.error( + f"Failed to register transfer engine info for tp_rank={self.tp_rank}: {e}" + ) + def _publish_modelexpress_metadata(self): """Publish TransferEngine metadata to ModelExpress server (seed mode).""" try: @@ -1108,6 +1156,10 @@ def load_model(self): self.remote_instance_transfer_engine_weight_info = ( self.loader.remote_instance_transfer_engine_weight_info ) + # Cache needs to be cleared after loading model weights (in the self.loader.load_model function). + # To avoid conflict with memory_saver_adapter.region, empty_cache operation is now moved here. 
+ if _is_npu: + torch.npu.empty_cache() monkey_patch_vllm_parallel_state(reverse=True) # Publish metadata to ModelExpress if running as seed source @@ -1967,6 +2019,22 @@ def kernel_warmup(self): if self._should_run_flashinfer_autotune(): self._flashinfer_autotune() + self._warmup_fused_sampling() + + def _warmup_fused_sampling(self): + """Pre-compile and autotune fused sampling Triton kernels.""" + if _is_hip: + return + from sglang.srt.layers.fused_sampling import warmup_fused_temperature_softmax + + logits_warmup_dtype = ( + torch.float32 if self.server_args.enable_fp32_lm_head else self.dtype + ) + warmup_fused_temperature_softmax( + self.model_config.vocab_size, + logits_dtype=logits_warmup_dtype, + ) + def _should_run_flashinfer_autotune(self) -> bool: """Check if flashinfer autotune should be run.""" if self.server_args.disable_flashinfer_autotune: @@ -1978,6 +2046,8 @@ def _should_run_flashinfer_autotune(self) -> bool: if backend_str not in [ "flashinfer_trtllm", + # TODO: Enable for flashinfer_trtllm_routed once https://github.com/flashinfer-ai/flashinfer/issues/2749 is fixed. + # "flashinfer_trtllm_routed", "flashinfer_mxfp4", # TODO: flashinfer_cutlass will cause some flashinfer compilation errors. To be fixed. # "flashinfer_cutlass", @@ -2048,6 +2118,16 @@ def _dummy_run(self, batch_size: int, run_ctx=None): if self.server_args.enable_torch_compile: set_torch_compile_config() + should_disable_torch_compile = not getattr( + self.model, "_can_torch_compile", True + ) + if should_disable_torch_compile: + log_info_on_rank0( + logger, + "Transformers backend model reports it is not torch.compile " + "compatible (e.g. dynamic rope scaling). 
Disabling torch.compile.", + ) + self.server_args.enable_torch_compile = False if self.eagle_use_aux_hidden_state: self.model.set_eagle3_layers_to_capture() @@ -2068,7 +2148,11 @@ def _dummy_run(self, batch_size: int, run_ctx=None): is_encoder_decoder=self.model_config.is_encoder_decoder, require_mlp_tp_gather=require_mlp_tp_gather_, seq_len_fill_value=seq_len_fill_value, - encoder_len_fill_value=0, + encoder_len_fill_value=( + getattr(self.model_config.hf_config, "max_source_positions", 0) + if self.model_config.is_encoder_decoder + else 0 + ), num_tokens_per_bs=num_tokens_per_bs, cache_loc_dtype=torch.int64, enable_mamba_track=False, @@ -2379,6 +2463,14 @@ def init_piecewise_cuda_graphs(self): # Collect attention layers and moe layers from the model self.model.model = resolve_language_model(self.model) language_model = getattr(self.model, "language_model", self.model) + + # Some draft models (e.g. eagle3) don't have a standard 'layers' attribute + if not hasattr(language_model.model, "layers"): + logger.warning( + "Disable piecewise CUDA graph because the model does not have a 'layers' attribute" + ) + return + self.attention_layers = [] self.moe_layers = [] self.moe_fusions = [] @@ -2412,6 +2504,8 @@ def init_piecewise_cuda_graphs(self): if attn_layer is not None: self.attention_layers.append(attn_layer) + elif hasattr(layer, "mixer"): + self.attention_layers.append(None) moe_block = None moe_fusion = None diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 7bc66f08d58f..a6baa4817ace 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -8,6 +8,7 @@ from sglang.srt.configs.model_config import get_nsa_index_head_dim, is_deepseek_nsa from sglang.srt.distributed.parallel_state import get_world_group +from sglang.srt.environ import envs from sglang.srt.layers.dp_attention 
import get_attention_tp_size from sglang.srt.mem_cache.allocator import ( PagedTokenToKVPoolAllocator, @@ -32,6 +33,7 @@ from sglang.srt.utils.common import ( get_available_gpu_memory, is_float4_e2m1fn_x2, + is_hip, is_npu, ) @@ -66,6 +68,7 @@ def __post_init__(self): logger = logging.getLogger(__name__) _is_npu = is_npu() +_is_hip = is_hip() class ModelRunnerKVCacheMixin: @@ -255,9 +258,17 @@ def calculate_mla_kv_cache_dim(self: ModelRunner) -> int: ): return kv_cache_dim + # On HIP with TileLang backend, keep the default MLA KV cache dimension. + # FP8 attention uses the nope(512 fp8) + rope(64 fp8) layout, without extra per-block scales. + if _is_hip and ( + self.server_args.nsa_prefill_backend == "tilelang" + or self.server_args.nsa_decode_backend == "tilelang" + ): + return kv_cache_dim + quant_block_size = NSATokenToKVPool.quant_block_size rope_storage_dtype = NSATokenToKVPool.rope_storage_dtype - # Calculate override_kv_cache_dim for FP8 storage for non-trtllm attention backends: + # Calculate override_kv_cache_dim for FP8 storage in backends that use scaled KV layout (excluding TRTLLM and HIP+TileLang). 
# kv_lora_rank + scale storage (kv_lora_rank // quant_block_size * 4 bytes) + rope dimension storage # Note: rope dimension is stored in original dtype (bf16), not quantized to fp8 if kv_cache_dtype == torch.float8_e4m3fn: @@ -392,7 +403,11 @@ def _init_pools(self: ModelRunner): # subscribe memory for pre-allocated requests # if max_num_reqs <= 32, we pre-allocate 2x requests - pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0 + + pre_alloc_size = envs.SGLANG_DISAGGREGATION_NUM_PRE_ALLOCATE_REQS.get() + pre_alloc_size = ( + max_num_reqs * 2 if max_num_reqs <= 32 else pre_alloc_size + ) if config := self.mambaish_config: self.req_to_token_pool = HybridMambaDecodeReqToTokenPool( size=max_num_reqs, @@ -401,11 +416,19 @@ def _init_pools(self: ModelRunner): device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, cache_params=config.mamba2_cache_params, + mamba_layer_ids=( + [ + i + for i in config.mamba2_cache_params.layers + if self.start_layer <= i < self.end_layer + ] + ), speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, enable_mamba_extra_buffer=self.server_args.enable_mamba_extra_buffer(), pre_alloc_size=pre_alloc_size, enable_overlap_schedule=not self.server_args.disable_overlap_schedule, mamba_size=self.server_args.max_mamba_cache_size, + start_layer=self.start_layer, ) else: self.req_to_token_pool = DecodeReqToTokenPool( @@ -426,9 +449,17 @@ def _init_pools(self: ModelRunner): device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, cache_params=config.mamba2_cache_params, + mamba_layer_ids=( + [ + i + for i in config.mamba2_cache_params.layers + if self.start_layer <= i < self.end_layer + ] + ), enable_mamba_extra_buffer=self.server_args.enable_mamba_extra_buffer(), speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, enable_overlap_schedule=not self.server_args.disable_overlap_schedule, + start_layer=self.start_layer, ) else: self.req_to_token_pool = 
ReqToTokenPool( @@ -643,6 +674,7 @@ def _init_pools(self: ModelRunner): mamba_pool=self.req_to_token_pool.mamba_pool, enable_memory_saver=self.server_args.enable_memory_saver, use_mla=self.use_mla_backend, + start_layer=self.start_layer, **extra_args, ) else: diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index 18fc12785705..b1935d21c462 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -682,8 +682,20 @@ def replay_prepare( next_token_logits_buffer = None + # Normalize MIXED→EXTEND so dynamo's guard (captured with EXTEND=1) doesn't fail on MIXED=3. + pcg_forward_mode = ( + ForwardMode.EXTEND + if forward_batch.forward_mode == ForwardMode.MIXED + else forward_batch.forward_mode + ) + pcg_global_forward_mode = ( + ForwardMode.EXTEND + if forward_batch.global_forward_mode == ForwardMode.MIXED + else forward_batch.global_forward_mode + ) + static_forward_batch = ForwardBatch( - forward_mode=forward_batch.forward_mode, + forward_mode=pcg_forward_mode, batch_size=bs, input_ids=input_ids, input_embeds=input_embeds, @@ -721,7 +733,7 @@ def replay_prepare( spec_info=forward_batch.spec_info, capture_hidden_mode=forward_batch.capture_hidden_mode, num_token_non_padded=forward_batch.num_token_non_padded, - global_forward_mode=forward_batch.global_forward_mode, + global_forward_mode=pcg_global_forward_mode, lora_ids=forward_batch.lora_ids, sampling_info=forward_batch.sampling_info, mm_inputs=forward_batch.mm_inputs, @@ -742,9 +754,13 @@ def replay( forward_batch: ForwardBatch, **kwargs, ) -> Union[LogitsProcessorOutput, PPProxyTensors, EmbeddingPoolerOutput]: + num_tokens = len(forward_batch.input_ids) + index = bisect.bisect_left(self.capture_num_tokens, num_tokens) + static_num_tokens = self.capture_num_tokens[index] with enable_piecewise_cuda_graph(): - # Due to the dispatch kernel 
for MLA model, we init the metadata with original forward_batch - self.model_runner.attn_backend.init_forward_metadata(forward_batch) + # Prepare static buffers first so set_forward_context can carry num_tokens + # into call_begin_forward (via ForwardContext.num_tokens), eliminating the + # need for a separate global and allowing pre-calculation of dummy-page count. static_forward_batch = self.replay_prepare(forward_batch, **kwargs) # Replay with set_forward_context( @@ -753,7 +769,10 @@ def replay( self.quant_config, self.moe_layers, self.moe_fusions, + num_tokens=static_num_tokens, ): + # Due to the dispatch kernel for MLA model, we init the metadata with original forward_batch + self.model_runner.attn_backend.init_forward_metadata(forward_batch) output = self.model_runner.model.forward( static_forward_batch.input_ids, static_forward_batch.positions, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 27d189d65622..b0884c68153e 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -479,7 +479,7 @@ def _get_weights_iterator( ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config - use_multithread = extra_config.get("enable_multithread_load", False) + use_multithread = extra_config.get("enable_multithread_load", True) hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt ) @@ -707,8 +707,6 @@ def load_weights_and_postprocess(model, weights, target_device): # parameters onto device for processing and back off after. 
with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) - if _is_npu: - torch.npu.empty_cache() class LayeredModelLoader(DefaultModelLoader): @@ -2865,6 +2863,245 @@ def _standard_quantization_workflow( return model.eval() +class RunaiModelStreamerLoader(BaseModelLoader): + """ + Model loader that uses Runai Model Streamer to load a model. + + Supports fast model loading from SSDs, shared filesystems and object storage (S3, GCS, Azure blob) with weight streaming. + + Configuration (via load_config.model_loader_extra_config): + - distributed (bool): Enable distributed streaming - True by default for url paths (object storage) + - concurrency (int): Number of concurrent downloads + - memory_limit (int): Memory limit for streaming buffer + + Note: Metadata files must be pre-downloaded via + ObjectStorageModel.download_and_get_path() before instantiation. + """ + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + model_config: Optional["ModelConfig"] = None + """The model configuration (for checking architecture, etc).""" + + @classmethod + def init_new(cls, model_config: ModelConfig, model): + model_weights = model_config.model_path + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + return cls( + model_weights, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + model_config=model_config, + ) + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = load_config.model_loader_extra_config + allowed_keys = {"distributed", "concurrency", "memory_limit"} + unexpected_keys = set(extra_config.keys()) - allowed_keys 
+ + if unexpected_keys: + raise ValueError( + f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{unexpected_keys}" + ) + + set_runai_streamer_env(load_config) + + self._is_distributed = None + if load_config.model_loader_extra_config: + extra_config = load_config.model_loader_extra_config + + if "distributed" in extra_config and isinstance( + extra_config.get("distributed"), bool + ): + self._is_distributed = extra_config.get("distributed") + + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str] + ) -> Tuple[str, List[str]]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + from sglang.srt.utils.runai_utils import is_runai_obj_uri, list_safetensors + + is_object_storage_path = is_runai_obj_uri(model_name_or_path) + if self._is_distributed is None: + self._is_distributed = is_object_storage_path + is_local = os.path.isdir(model_name_or_path) + safetensors_pattern = "*.safetensors" + index_file = SAFE_WEIGHTS_INDEX_NAME + + hf_folder = ( + model_name_or_path + if (is_local or is_object_storage_path) + else download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [safetensors_pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + ) + + server_args = get_global_server_args() + if server_args and server_args.model_checksum is not None: + from sglang.srt.utils.model_file_verifier import verify + + checksums_source = server_args.model_checksum or model_name_or_path + verify(model_path=hf_folder, checksums_source=checksums_source) + + hf_weights_files = list_safetensors(path=hf_folder) + + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. 
+ if not is_local and not is_object_storage_path: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file + ) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`" + ) + + return hf_folder, hf_weights_files + + def _get_weights_iterator( + self, source: "Source" + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + from sglang.srt.model_loader.weight_utils import ( + runai_safetensors_weights_iterator, + ) + + hf_folder, hf_weights_files = self._prepare_weights( + source.model_or_path, source.revision + ) + + if source.model_config is not None: + hf_weights_files = maybe_add_mtp_safetensors( + hf_weights_files, + hf_folder, + "model.safetensors.index.json", + source.model_config.hf_config, + ) + + weights_iterator = runai_safetensors_weights_iterator( + hf_weights_files, self._is_distributed, self.target_device_str + ) + + if self.load_config.draft_model_idx is not None: + import re + + def filter_weights(original_weights_iterator): + pattern = r"model.mtp.layers.(\d+)." 
+ for name, tensor in original_weights_iterator: + group = re.match(pattern, name) + if group is not None: + idx = int(group.group(1)) + if idx != self.load_config.draft_model_idx: + continue + new_name = name.replace(group.group(), "model.mtp.layers.0.") + else: + new_name = name + yield (new_name, tensor) + + weights_iterator = filter_weights(weights_iterator) + + def apply_prefix(original_weights_iterator): + yield from ( + (source.prefix + name, tensor) + for (name, tensor) in original_weights_iterator + ) + + return apply_prefix(weights_iterator) + + def _get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + + primary_weights = RunaiModelStreamerLoader.Source.init_new(model_config, model) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast( + Iterable[RunaiModelStreamerLoader.Source], + getattr(model, "secondary_weights", ()), + ) + for source in secondary_weights: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model_path, model_config.revision) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + + if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant: + # Load base model using shared method + raise NotImplementedError( + "Runai Model Streamer Loader does not support ModelOpt quantization yet" + ) + + assert device_config.device_type in ("cuda", "cpu"), ( + f"Runai Model Streamer only supports CUDA and CPU, " + f"got {device_config.device_type}" + ) + + if device_config.device_type == "cuda": + self.target_device_str = ( + device_config.device_type + ":" + str(device_config.gpu_id) + ) + else: + self.target_device_str = "cpu" + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = 
_initialize_model( + model_config, + self.load_config, + ) + + DefaultModelLoader.load_weights_and_postprocess( + model, self._get_all_weights(model_config, model), target_device + ) + + return model.eval() + + def get_model_loader( load_config: LoadConfig, model_config: Optional[ModelConfig] = None ) -> BaseModelLoader: @@ -2949,4 +3186,7 @@ def get_model_loader( except ImportError: raise ValueError("Failed to import sglang.private.private_model_loader") + if load_config.load_format == LoadFormat.RUNAI_STREAMER: + return RunaiModelStreamerLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py index 8a945bb4c2e3..2a0aeb047ed6 100644 --- a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py +++ b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py @@ -106,21 +106,6 @@ def get_remote_instance_transfer_engine_info_per_rank(seed_url: str, rank: int): return None, None -def parse_remote_instance_transfer_engine_info_from_scheduler_infos(scheduler_infos): - remote_instance_transfer_engine_info = {} - for data in scheduler_infos: - if ( - "tp_rank" in data - and "remote_instance_transfer_engine_session_id" in data - and "remote_instance_transfer_engine_weights_info_dict" in data - ): - remote_instance_transfer_engine_info[data["tp_rank"]] = ( - data["remote_instance_transfer_engine_session_id"], - data["remote_instance_transfer_engine_weights_info_dict"], - ) - return remote_instance_transfer_engine_info - - def register_memory_region(model, transfer_engine): if importlib.util.find_spec("torch") is None: return register_memory_region_v1(model, transfer_engine) diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py index 18739ed6954e..832ff1424b65 100644 --- a/python/sglang/srt/model_loader/utils.py +++ 
b/python/sglang/srt/model_loader/utils.py @@ -27,9 +27,87 @@ def set_default_torch_dtype(dtype: torch.dtype): torch.set_default_dtype(old_dtype) +def _is_moe_model(model_config: ModelConfig, architectures: list[str]) -> bool: + lowered_arches = [arch.lower() for arch in architectures] + if any("moe" in arch or "mixtral" in arch for arch in lowered_arches): + return True + + text_config = model_config.hf_text_config + expert_attrs = ( + "num_local_experts", + "num_experts", + "num_experts_per_tok", + "moe_intermediate_size", + "n_routed_experts", + ) + for attr in expert_attrs: + value = getattr(text_config, attr, None) + if value is None: + continue + if isinstance(value, bool): + if value: + return True + continue + if isinstance(value, (int, float)): + threshold = 0 if attr == "moe_intermediate_size" else 1 + if value > threshold: + return True + continue + if isinstance(value, (list, tuple, set, dict)): + if len(value) > 0: + return True + continue + if isinstance(value, str) and value == "": + continue + if value is not None: + return True + return False + + +def _is_sequence_classification_model(architectures: list[str]) -> bool: + return any( + "sequenceclassification" in lowered or "rewardmodel" in lowered + for lowered in (arch.lower() for arch in architectures) + ) + + +def _get_transformers_backend_arch( + model_config: ModelConfig, architectures: list[str] +) -> str: + is_pooling = not model_config.is_generation + is_multimodal = model_config.is_multimodal or ( + model_config.hf_config is not model_config.hf_text_config + ) + is_moe = _is_moe_model(model_config, architectures) + base_arch = "ForCausalLM" + if is_pooling: + base_arch = ( + "ForSequenceClassification" + if _is_sequence_classification_model(architectures) + else "EmbeddingModel" + ) + + arch = "Transformers" + if is_multimodal: + arch += "MultiModal" + if is_moe: + arch += "MoE" + return arch + base_arch + + +def _model_impl_from_architecture(architecture: str) -> ModelImpl: + if 
architecture.startswith("Transformers"): + return ModelImpl.TRANSFORMERS + if architecture.startswith("MindSpore"): + return ModelImpl.MINDSPORE + return ModelImpl.SGLANG + + def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]): - for i, arch in enumerate(architectures): - if arch == "TransformersForCausalLM": + backend_arch = _get_transformers_backend_arch(model_config, architectures) + + for arch in architectures: + if arch.startswith("Transformers"): continue auto_map: dict[str, str] = ( getattr(model_config.hf_config, "auto_map", None) or dict() @@ -42,15 +120,33 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str # "AutoModel": "--", # "AutoModelFor": "--", # }, - auto_modules = { - name: get_class_from_dynamic_module( - module, model_config.model_path, revision=model_config.revision + auto_modules = {} + try: + auto_modules = { + name: get_class_from_dynamic_module( + module, model_config.model_path, revision=model_config.revision + ) + for name, module in sorted(auto_map.items(), key=lambda x: x[0]) + } + except Exception as e: + logger.warning( + "Failed to load dynamic modules from auto_map for '%s': %s. " + "Skipping remote model compatibility checks.", + arch, + e, ) - for name, module in sorted(auto_map.items(), key=lambda x: x[0]) - } model_module = getattr(transformers, arch, None) if model_module is None: - if "AutoModel" not in auto_map: + has_auto_model = "AutoModel" in auto_modules + if not has_auto_model and model_config.model_impl == ModelImpl.TRANSFORMERS: + logger.warning( + "Cannot resolve model class for '%s' and no auto_map.AutoModel " + "is present. Skipping compatibility gate because " + "--model-impl=transformers is explicitly requested.", + arch, + ) + continue + if not has_auto_model and "AutoModel" not in auto_map: raise ValueError( f"Cannot find model module. 
'{arch}' is not a registered " "model in the Transformers library (only relevant if the " @@ -58,16 +154,25 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str "not present in the model config's 'auto_map' (relevant " "if the model is custom)." ) + if not has_auto_model: + raise ValueError( + f"Cannot find model module. '{arch}' is not a registered " + "model in the Transformers library and loading the custom " + f"model from auto_map failed. The remote model code may be " + f"incompatible with the installed transformers version." + ) model_module = auto_modules["AutoModel"] if model_config.model_impl == ModelImpl.TRANSFORMERS: if hasattr(model_module, "is_backend_compatible") and ( not model_module.is_backend_compatible() ): - raise ValueError( - f"The Transformers implementation of {arch} is not " - "compatible with SGLang." + logger.warning( + "The Transformers implementation of %s reports it is not " + "backend-compatible (_supports_attention_backend=False). " + "Proceeding anyway because --model-impl=transformers was " + "explicitly requested. 
The model may not work correctly.", + arch, ) - architectures[i] = "TransformersForCausalLM" if model_config.model_impl == ModelImpl.AUTO: if hasattr(model_module, "is_backend_compatible") and ( not model_module.is_backend_compatible() @@ -82,8 +187,7 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str "performance may not be optimal.", arch, ) - architectures[i] = "TransformersForCausalLM" - return architectures + return [backend_arch] def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: @@ -114,7 +218,29 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], architectures = ["MindSporeForCausalLM"] elif not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS: architectures = resolve_transformers_arch(model_config, architectures) - return ModelRegistry.resolve_model_cls(architectures) + model_cls, resolved_arch = ModelRegistry.resolve_model_cls(architectures) + setattr(model_config, "_resolved_model_arch", resolved_arch) + setattr( + model_config, + "_resolved_model_impl", + _model_impl_from_architecture(resolved_arch), + ) + return model_cls, resolved_arch + + +def get_resolved_model_impl(model_config: ModelConfig) -> ModelImpl: + resolved_model_impl = getattr(model_config, "_resolved_model_impl", None) + if resolved_model_impl is not None: + return resolved_model_impl + + resolved_arch = getattr(model_config, "_resolved_model_arch", None) + if resolved_arch is None: + _, resolved_arch = get_model_architecture(model_config) + + resolved_model_impl = _model_impl_from_architecture(resolved_arch) + setattr(model_config, "_resolved_model_arch", resolved_arch) + setattr(model_config, "_resolved_model_impl", resolved_model_impl) + return resolved_model_impl def get_architecture_class_name(model_config: ModelConfig) -> str: diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 
60e646fcf7c0..4c93176612a0 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -58,6 +58,7 @@ log_info_on_rank0, print_warning_once, ) +from sglang.srt.utils.common import is_cuda_alike from sglang.utils import is_in_ci try: @@ -186,6 +187,16 @@ def get_quant_config( if not isinstance(hf_quant_config, dict): hf_quant_config = hf_quant_config.to_dict() hf_quant_config["packed_modules_mapping"] = packed_modules_mapping + # For modelopt, route to FP4 vs FP8 config based on quant_algo + if model_config.quantization.startswith("modelopt"): + quant_algo = hf_quant_config.get("quant_algo") + if quant_algo is None: + quant_algo = hf_quant_config.get("quantization", {}).get("quant_algo") + if quant_algo is not None: + if quant_algo == "FP8" or model_config.quantization == "modelopt_fp8": + return ModelOptFp8Config.from_config(hf_quant_config) + if "FP4" in quant_algo: + return ModelOptFp4Config.from_config(hf_quant_config) return quant_cls.from_config(hf_quant_config) # In case of bitsandbytes/QLoRA, get quant config from the adapter model. 
@@ -1088,7 +1099,7 @@ def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: def runai_safetensors_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: List[str], is_distributed: bool = False, device: str = "cpu" ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" from runai_model_streamer import SafetensorsStreamer @@ -1096,17 +1107,30 @@ def runai_safetensors_weights_iterator( enable_tqdm = ( not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) + device = device if is_distributed and is_cuda_alike() else "cpu" with SafetensorsStreamer() as streamer: - for st_file in tqdm( + + streamer.stream_files( hf_weights_files, + device=device, + is_distributed=is_distributed, + ) + total_tensors = sum( + len(tensors_meta) + for tensors_meta in streamer.files_to_tensors_metadata.values() + ) + + tensor_iter = tqdm( + streamer.get_tensors(), + total=total_tensors, desc="Loading safetensors using Runai Model Streamer", - disable=not enable_tqdm, bar_format=BAR_FORMAT, - position=tqdm._get_free_pos(), - ): - streamer.stream_file(st_file) - yield from streamer.get_tensors() + disable=not enable_tqdm, + mininterval=2, + ) + + yield from tensor_iter def set_runai_streamer_env(load_config: LoadConfig): diff --git a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py index fe1b3c966a71..cb83a13e9af0 100644 --- a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py +++ b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py @@ -292,9 +292,11 @@ def forward_absorb_prepare( q_nope_out = q_nope_out.transpose(0, 1) + skip_rope_for_nsa_tilelang_fused = self._skip_rope_for_nsa_tilelang_fused() if ( self.rotary_emb is not None and (not self._fuse_rope_for_trtllm_mla(forward_batch)) + and (not 
skip_rope_for_nsa_tilelang_fused) and (not _use_aiter or not _is_gfx95_supported or self.use_nsa) ): q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) @@ -332,24 +334,69 @@ def forward_absorb_core( save_kv_cache = True if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS: - extra_args = {} - if self._fuse_rope_for_trtllm_mla(forward_batch): - extra_args = { - "cos_sin_cache": self.rotary_emb.cos_sin_cache, - "is_neox": self.rotary_emb.is_neox_style, - "llama_4_scaling": llama_4_scaling, - } - - attn_output = self.attn_mqa( - q_nope_out, - k_nope, - k_nope, - forward_batch, - q_rope=q_pe, - k_rope=k_pe, - **extra_args, - **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), - ) + if self._skip_rope_for_nsa_tilelang_fused() and self.rotary_emb is not None: + cos = self.rotary_emb.cos_cache + sin = self.rotary_emb.sin_cache + kv_cache_dtype = ( + fp8_dtype if self.kv_cache_dtype == "fp8_e4m3" else q_nope_out.dtype + ) + q_cat, _, k_pe_fused, _ = fused_qk_rope_cat_and_cache_mla( + q_nope_out, + q_pe, + k_nope, + k_pe, + forward_batch.token_to_kv_pool.get_key_buffer( + self.attn_mqa.layer_id + ), + forward_batch.out_cache_loc, + positions, + cos, + sin, + self.attn_mqa.k_scale, + self.rotary_emb.is_neox_style, + q_out_dtype=kv_cache_dtype, + ) + q_nope_fused = q_cat[..., : self.kv_lora_rank] + q_pe_fused = q_cat[..., self.kv_lora_rank :] + save_kv_cache = False + if llama_4_scaling is not None: + q_nope_fused *= llama_4_scaling + attn_output = self.attn_mqa( + q_nope_fused, + None, + None, + forward_batch, + q_rope=q_pe_fused, + k_rope=k_pe_fused, + save_kv_cache=save_kv_cache, + **( + dict(topk_indices=topk_indices) + if topk_indices is not None + else {} + ), + ) + else: + extra_args = {} + if self._fuse_rope_for_trtllm_mla(forward_batch): + extra_args = { + "cos_sin_cache": self.rotary_emb.cos_sin_cache, + "is_neox": self.rotary_emb.is_neox_style, + "llama_4_scaling": llama_4_scaling, + } + attn_output = self.attn_mqa( + 
q_nope_out, + k_nope, + k_nope, + forward_batch, + q_rope=q_pe, + k_rope=k_pe, + **extra_args, + **( + dict(topk_indices=topk_indices) + if topk_indices is not None + else {} + ), + ) else: if _use_aiter_gfx95: cos = self.rotary_emb.cos_cache @@ -532,3 +579,17 @@ def _fuse_rope_for_trtllm_mla( ) and forward_batch.attn_backend.data_type == torch.float8_e4m3fn ) + + def _skip_rope_for_nsa_tilelang_fused(self: DeepseekV2AttentionMLA) -> bool: + """ + Check if we should skip rope and use fused rope+cache path for TileLang NSA on gfx95. + """ + server_args = get_global_server_args() + return ( + _use_aiter_gfx95 + and self.current_attention_backend == "nsa" + and ( + server_args.nsa_decode_backend == "tilelang" + or server_args.nsa_prefill_backend == "tilelang" + ) + ) diff --git a/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py b/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py index 12ce382ed28d..b72e8290d773 100644 --- a/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py +++ b/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py @@ -50,6 +50,7 @@ _is_fp8_fnuz, _is_hip, _is_npu, + _is_xpu, _use_aiter_gfx95, awq_dequantize_func, enable_nextn_moe_bf16_cast_to_fp8, @@ -497,7 +498,7 @@ def post_load_weights( ) if ( - _is_cuda + (_is_cuda or _is_xpu) and weight_block_size[0] == 128 and weight_block_size[1] == 128 ): diff --git a/python/sglang/srt/models/deepseek_common/utils.py b/python/sglang/srt/models/deepseek_common/utils.py index 73f26c2b057e..a5579d5287c9 100644 --- a/python/sglang/srt/models/deepseek_common/utils.py +++ b/python/sglang/srt/models/deepseek_common/utils.py @@ -31,6 +31,7 @@ is_hip, is_npu, is_nvidia_cublas_version_ge_12_9, + is_xpu, ) _is_hip = is_hip() @@ -40,6 +41,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_is_xpu = is_xpu() _device_sm = get_device_sm() _is_gfx95_supported = is_gfx95_supported() 
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index d57eb882296c..28029a0c75e9 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -15,9 +15,11 @@ """Inference-only DeepSeek NextN Speculative Decoding.""" import logging +import os from typing import Iterable, Optional, Tuple import torch +from safetensors.torch import load_file from torch import nn from transformers import PretrainedConfig @@ -99,6 +101,13 @@ def __init__( self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) + self.rot_weight = None + if _is_npu: + rot_weight_path = get_global_server_args().model_path + "/rot.safetensors" + if os.path.isfile(rot_weight_path): + self.rot_weight = load_file(rot_weight_path) + self.rot_weight = self.rot_weight["rot.weight"].npu() + self.alt_stream = ( torch.cuda.Stream() if _is_cuda or envs.SGLANG_NPU_USE_MULTI_STREAM.get() @@ -112,6 +121,7 @@ def __init__( ): layer_name = "layers." 
+ str(config.num_hidden_layers) + self.quant_config = quant_config self.decoder = DeepseekV2DecoderLayer( config, 0, @@ -137,6 +147,9 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: + if _is_npu and self.quant_config is None: + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0" zero_allocator = BumpAllocator( buffer_size=2, dtype=torch.float32, @@ -155,7 +168,13 @@ def forward( torch.cat( ( self.enorm(hidden_states), - self.hnorm(forward_batch.spec_info.hidden_states), + self.hnorm( + forward_batch.spec_info.hidden_states + if self.rot_weight is None + else torch.matmul( + forward_batch.spec_info.hidden_states, self.rot_weight + ) + ), ), dim=-1, ) @@ -189,6 +208,9 @@ def forward( torch.cuda.current_stream(), ) + if _is_npu and self.quant_config is None: + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "0" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1" return hidden_states diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0d88f541a1dd..6fae78248777 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -137,6 +137,7 @@ _is_gfx95_supported, _is_hip, _is_npu, + _is_xpu, _use_aiter, _use_aiter_gfx95, ) @@ -152,9 +153,11 @@ use_intel_amx_backend, ) +if _use_aiter: + from sglang.srt.layers.rocm_linear_utils import aiter_dsv3_router_gemm + if _use_aiter_gfx95: from sglang.srt.layers.rocm_linear_utils import ( - aiter_dsv3_router_gemm, get_dsv3_gemm_output_zero_allocator_size, ) @@ -326,14 +329,8 @@ def forward( logits = dsv3_router_gemm( hidden_states, self.weight, out_dtype=torch.float32 ) - elif ( - _use_aiter_gfx95 - and hidden_states.shape[0] <= 256 - and self.weight.shape[0] <= 256 - ): - logits = aiter_dsv3_router_gemm( - hidden_states, self.weight, gemm_output_zero_allocator - ) + elif _use_aiter: + logits = aiter_dsv3_router_gemm(hidden_states, self.weight) else: 
logits = F.linear(hidden_states, self.weight, None) @@ -677,6 +674,7 @@ def _post_combine_hook( ) if ( not _is_cuda + and not _is_xpu and not _use_aiter or isinstance(self.experts.quant_method, KTEPWrapperMethod) ): diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 121a31630ae3..ab40fdfcd9d5 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -34,6 +34,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.environ import envs from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo @@ -91,6 +92,7 @@ is_cuda, is_hip, is_non_idle_and_non_empty, + is_npu, log_info_on_rank0, make_layers, ) @@ -102,10 +104,19 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_is_npu = is_npu() _device_sm = get_device_sm() logger = logging.getLogger(__name__) +if _is_npu: + from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope + + from sglang.srt.hardware_backend.npu.utils import ( + process_shared_expert, + wait_share_stream, + ) + class Glm4MoeMLP(nn.Module): def __init__( @@ -278,17 +289,39 @@ def forward_prepare( if hidden_states.shape[0] == 0: return hidden_states, forward_batch, None qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - if self.use_qk_norm: - q, k = apply_qk_norm( - q=q, - k=k, - q_norm=self.q_norm, - k_norm=self.k_norm, - head_dim=self.head_dim, - alt_stream=self.alt_stream, + + if ( + not _is_npu + or forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed() + ): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q, k 
= apply_qk_norm( + q=q, + k=k, + q_norm=self.q_norm, + k_norm=self.k_norm, + head_dim=self.head_dim, + alt_stream=self.alt_stream, + ) + q, k = self.rotary_emb(positions, q, k) + else: + if self.attn.layer_id == forward_batch.token_to_kv_pool.start_layer: + self.rotary_emb.get_cos_sin_with_position(positions) + q, k, v = split_qkv_rmsnorm_rope( + qkv, + self.rotary_emb.position_sin, + self.rotary_emb.position_cos, + self.q_size, + self.kv_size, + self.head_dim, + eps=self.q_norm.variance_epsilon, + q_weight=self.q_norm.weight, + k_weight=self.k_norm.weight, + q_bias=getattr(self.q_norm, "bias", None), + k_bias=getattr(self.k_norm, "bias", None), ) - q, k = self.rotary_emb(positions, q, k) + inner_state = q, k, v, forward_batch return None, forward_batch, inner_state @@ -327,9 +360,14 @@ def __init__( self.e_score_correction_bias = nn.Parameter( torch.empty((config.n_routed_experts), dtype=torch.float32) ) + # GLM requires FP32 gate projection; cache to avoid per-forward cast. + # FIXME: if gate weight is updated at runtime (e.g. expert rebalancing), _weight_fp32 must be invalidated. 
+ self.register_buffer("_weight_fp32", None, persistent=False) def forward(self, hidden_states): - logits = F.linear(hidden_states, self.weight, None) + if self._weight_fp32 is None: + self._weight_fp32 = self.weight.data.to(torch.float32) + logits = F.linear(hidden_states.to(torch.float32), self._weight_fp32, None) return logits @@ -557,10 +595,24 @@ def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> torch.Tensor: shared_output = None + enable_npu_dual_stream = ( + _is_npu + and ( + forward_batch.forward_mode.is_extend() + or forward_batch.forward_mode.is_target_verify() + ) + and envs.SGLANG_NPU_USE_MULTI_STREAM.get() + ) + if hidden_states.shape[0] > 0: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) - shared_output = self._forward_shared_experts(hidden_states) + if enable_npu_dual_stream: + shared_output = process_shared_expert( + hidden_states, self._forward_shared_experts + ) + else: + shared_output = self._forward_shared_experts(hidden_states) topk_output = self.topk( hidden_states, router_logits, @@ -571,10 +623,13 @@ def forward_deepep( ) else: topk_output = self.topk.empty_topk_output(hidden_states.device) + final_hidden_states = self.experts( hidden_states=hidden_states, topk_output=topk_output, ) + if enable_npu_dual_stream: + wait_share_stream() if shared_output is not None: x = shared_output diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py index 1f6e753646cb..6d2ef89724da 100644 --- a/python/sglang/srt/models/glm4_moe_nextn.py +++ b/python/sglang/srt/models/glm4_moe_nextn.py @@ -14,6 +14,7 @@ """Inference-only GLM-4.5, GLM-4.6 Speculative Decoding.""" +import contextlib import logging from typing import Iterable, Optional, Tuple @@ -22,6 +23,7 @@ from transformers import PretrainedConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.environ import temp_set_env from 
sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm @@ -126,7 +128,10 @@ def __init__( nn.Module.__init__(self) self.config = config self.tp_size = get_tensor_model_parallel_world_size() - self.quant_config = quant_config + self.needs_quant_draft = ( + get_global_server_args().speculative_draft_model_quantization + ) + quant_config = quant_config if self.needs_quant_draft else None self.model = Glm4MoeModelNextN( config, quant_config, prefix=add_prefix("model", prefix) ) @@ -150,7 +155,19 @@ def forward( positions: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, forward_batch) + # Support unquant speculative draft model + if self.needs_quant_draft: + cxt = contextlib.nullcontext() + else: + unquant_patch = { + "SGLANG_DEEPEP_BF16_DISPATCH": "1", + "DEEP_NORMAL_MODE_USE_INT8_QUANT": "0", + } + cxt = temp_set_env(allow_sglang=True, **unquant_patch) + + with cxt: + hidden_states = self.model(input_ids, positions, forward_batch) + return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch ) diff --git a/python/sglang/srt/models/glm4v.py b/python/sglang/srt/models/glm4v.py index e2ff4da96d31..9bb5a92b2cdc 100644 --- a/python/sglang/srt/models/glm4v.py +++ b/python/sglang/srt/models/glm4v.py @@ -562,8 +562,6 @@ def __init__( use_data_parallel=self.use_data_parallel, ) - vision_utils.update_vit_attn_dummy_heads_config(self.config) - self.model = Glm4Model( config, quant_config=quant_config, diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 593ef4b9f932..c2e1ac233028 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -17,6 +17,7 @@ import logging import math +import re from collections.abc import Iterable from functools import partial from typing import 
Any, Dict, List, Optional, Tuple, Union @@ -651,6 +652,13 @@ def forward( class GptOssForCausalLM(nn.Module): fall_back_to_pt_during_load = False + _lora_pattern_moe = re.compile( + r"^(?:model\.layers\.\d+\.(?:self_attn\.(?:qkv_proj|o_proj)|mlp\.experts)|lm_head|model\.embed_tokens)$" + ) + + def should_apply_lora(self, module_name: str) -> bool: + return bool(self._lora_pattern_moe.match(module_name)) + def __init__( self, config: GptOssConfig, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 8a6a5f0d5c88..408811d71460 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -61,7 +61,6 @@ from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_npu -from sglang.srt.utils.hf_transformers_utils import get_rope_config _is_npu = is_npu() @@ -478,7 +477,10 @@ def __init__( self.layer_id = layer_id self.alt_stream = alt_stream or torch.cuda.Stream() - rope_theta, _ = get_rope_config(config) + rope_theta = getattr(config, "rope_theta", None) + if rope_theta is None: + rope_params = getattr(config, "rope_parameters", None) + rope_theta = rope_params["rope_theta"] if rope_params else 10000 self.self_attn = Grok1Attention( config=config, hidden_size=self.hidden_size, diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index c76312c0a833..e07ca7418f3f 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -53,8 +53,26 @@ ) from sglang.srt.utils import add_prefix, flatten_nested_list, logger +_KNOWN_BROKEN_AUTOMODEL_CONFIG = "VoxtralRealtimeTextConfig" +_KNOWN_BROKEN_AUTOMODEL_ERROR = "Could not find VoxtralRealtimeTextModel" + class LlavaBaseForCausalLM(nn.Module): + @staticmethod + def _infer_image_aspect_ratio(mm_items): + """Determine image_aspect_ratio from processor metadata or item count.""" + # Check if processor 
stored the aspect_ratio it used + for item in mm_items: + ar = item.model_specific_data.get("image_aspect_ratio") + if ar is not None: + return ar + # Fallback: multi-image or video → pad, single image → anyres + image_items = [item for item in mm_items if item.is_image()] + has_video = any(item.is_video() for item in mm_items) + if len(image_items) > 1 or has_video: + return "pad" + return "anyres" + def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): image_sizes = flatten_nested_list( [item.image_sizes for item in image_inputs.mm_items] @@ -63,13 +81,8 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): pad_values = [item.pad_value for item in image_inputs.mm_items] # hardcode for spatial_unpad + anyres - if any( - item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO - for item in image_inputs.mm_items - ): - image_aspect_ratio = "pad" - else: - image_aspect_ratio = "anyres" + # Use per-item aspect_ratio from processor if available, else infer + image_aspect_ratio = self._infer_image_aspect_ratio(image_inputs.mm_items) offset_list = [] image_inputs.image_pad_len = [] for image_idx, image_s in enumerate(image_sizes): @@ -165,13 +178,9 @@ def forward( # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - # Got List[List[str]] extend it to List[str] - # The length of the List should be equal to batch size - modalities_list = [] + # Compute max image offset per request to determine need_vision max_image_offset = [] for im in image_inputs: - if im: - modalities_list.extend([item.modality for item in im.mm_items]) if im and im.image_offsets: max_image_offset.append( np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)) @@ -184,6 +193,18 @@ def forward( if need_vision.any(): bs = forward_batch.batch_size + + # Build per-image lists filtered by need_vision + modalities_list = [] + aspect_ratios = [] # per-image aspect ratio + for i in range(bs): + if 
need_vision[i] and image_inputs[i]: + items = image_inputs[i].mm_items + ar = self._infer_image_aspect_ratio(items) + for item in items: + modalities_list.append(item.modality) + aspect_ratios.append(ar) + pixel_values = flatten_nested_list( [ [item.feature for item in image_inputs[i].mm_items] @@ -191,12 +212,12 @@ def forward( if need_vision[i] ] ) + # Per-image sizes (each entry is [(w,h)] for one image) image_sizes = [ - flatten_nested_list( - [item.image_sizes for item in image_inputs[i].mm_items] - ) + item.image_sizes for i in range(bs) if need_vision[i] + for item in image_inputs[i].mm_items ] ########## Encode Image ######## @@ -225,18 +246,7 @@ def forward( new_image_features = [] height = width = self.num_patches_per_side for image_idx, image_feature in enumerate(image_features): - if modalities_list[image_idx] == Modality.IMAGE: - image_aspect_ratio = ( - self.config.image_aspect_ratio - ) # single image - elif ( - modalities_list[image_idx] == Modality.MULTI_IMAGES - or modalities_list[image_idx] == Modality.VIDEO - ): - image_aspect_ratio = "pad" # multi image - # image_aspect_ratio = ( - # "anyres" if len(image_sizes[image_idx]) == 1 else "pad" - # ) + image_aspect_ratio = aspect_ratios[image_idx] if ( image_feature.shape[0] > 1 and "anyres" in image_aspect_ratio @@ -385,6 +395,7 @@ def forward( extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu + # Fill in the image features using flat indexing (one pt per image) pt = 0 for i in range(bs): if not need_vision[i]: @@ -393,20 +404,25 @@ def forward( start_idx = extend_start_loc_cpu[i] seq_len = extend_seq_lens[i] prefix_len = prefix_lens_cpu[i] + n_images = len(image_inputs[i].image_offsets) + + for j in range(n_images): + image_offset = image_inputs[i].image_offsets[j] - # Multiple images - for image_idx, image_offset in enumerate( - image_inputs[i].image_offsets - ): 
if ( - image_offset + image_inputs[i].image_pad_len[image_idx] + image_offset + image_inputs[i].image_pad_len[j] <= prefix_len ): + pt += 1 continue if image_offset >= prefix_len + seq_len: + pt += n_images - j break - tmp_image_feature = image_features[pt][image_idx] + tmp_image_feature = image_features[pt] + # Squeeze batch dim from per-image features [1, feat, hidden] + if tmp_image_feature.ndim == 3: + tmp_image_feature = tmp_image_feature[0] pad_len = tmp_image_feature.shape[0] input_offset = image_offset - prefix_len @@ -429,7 +445,7 @@ def forward( print( f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}" ) - pt += 1 + pt += 1 return self.language_model( input_ids, positions, forward_batch, input_embeds=input_embeds @@ -657,7 +673,22 @@ def _config_cls_name_to_arch_name_mapping( ) -> Dict[str, str]: mapping = {} for config_cls in auto_model_type._model_mapping.keys(): - archs = auto_model_type._model_mapping.get(config_cls, None) + try: + archs = auto_model_type._model_mapping.get(config_cls, None) + except ValueError as exc: + if ( + auto_model_type is not AutoModel + or config_cls.__name__ != _KNOWN_BROKEN_AUTOMODEL_CONFIG + or _KNOWN_BROKEN_AUTOMODEL_ERROR not in str(exc) + ): + raise + logger.warning( + "Skipping broken %s mapping for config %s: %s", + auto_model_type.__name__, + config_cls.__name__, + exc, + ) + continue if archs is not None: if isinstance(archs, tuple): mapping[config_cls.__name__] = tuple( diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index cd4489152b8a..588c356a473c 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -993,7 +993,11 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): slice_end_id: int = image_inputs.slice_end_id media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)] - pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + # Only increment data_idx on 
im_start (not slice_start) so all slices + # within one image share the same pad_value for per-image caching. + pattern = MultiModalityDataPaddingPatternTokenPairs( + media_token_pairs, data_start_token_ids=[im_start_id] + ) return pattern.pad_input_tokens(input_ids, image_inputs) @@ -1155,7 +1159,11 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): slice_end_id: int = image_inputs.slice_end_id media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)] - pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + # Only increment data_idx on im_start (not slice_start) so all slices + # within one image share the same pad_value for per-image caching. + pattern = MultiModalityDataPaddingPatternTokenPairs( + media_token_pairs, data_start_token_ids=[im_start_id] + ) return pattern.pad_input_tokens(input_ids, image_inputs) @@ -1321,7 +1329,11 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): slice_end_id: int = image_inputs.slice_end_id media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)] - pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + # Only increment data_idx on im_start (not slice_start) so all slices + # within one image share the same pad_value for per-image caching. 
+ pattern = MultiModalityDataPaddingPatternTokenPairs( + media_token_pairs, data_start_token_ids=[im_start_id] + ) return pattern.pad_input_tokens(input_ids, image_inputs) diff --git a/python/sglang/srt/models/minimax_m2.py b/python/sglang/srt/models/minimax_m2.py index 470b1e0f7ff8..e5ef2d75c7dc 100644 --- a/python/sglang/srt/models/minimax_m2.py +++ b/python/sglang/srt/models/minimax_m2.py @@ -73,6 +73,7 @@ is_non_idle_and_non_empty, make_layers, ) +from sglang.srt.utils.hf_transformers_utils import get_rope_config logger = logging.getLogger(__name__) @@ -570,7 +571,7 @@ def __init__( # RoPE settings - support partial RoPE # FIXME: minimax_m2 config use external config that not compatible with transformers v5 - self.rope_theta = config.rope_theta + self.rope_theta, self.rope_scaling = get_rope_config(config) self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_dim = getattr( config, "rotary_dim", self.head_dim @@ -600,13 +601,12 @@ def __init__( ) # Setup RoPE with partial rotary dimension - rope_scaling = getattr(config, "rope_scaling", None) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.rotary_dim, # Use partial rotary dimension max_position=self.max_position_embeddings, base=self.rope_theta, - rope_scaling=rope_scaling, + rope_scaling=self.rope_scaling, ) # QK Normalization layers diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 85a510fc7854..55659586d4bb 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -269,6 +269,7 @@ def __init__( ) -> None: super().__init__() self.config = config + self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size self.pp_group = get_pp_group() diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 4f018ea52d2f..7e557a8d5b31 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -19,6 +19,7 @@ from 
sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.rotary_embedding.mrope import MRotaryEmbedding from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors @@ -30,13 +31,25 @@ from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.utils import apply_qk_norm from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import add_prefix, is_cuda, is_npu +from sglang.srt.utils import add_prefix, get_bool_env_var, is_cuda, is_hip, is_npu Qwen3Config = None logger = logging.getLogger(__name__) _is_cuda = is_cuda() +_is_hip = is_hip() _is_npu = is_npu() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +_has_fused_qk_norm_mrope = False +if _use_aiter: + try: + from aiter import fused_qk_norm_mrope_3d_cache_pts_quant_shuffle + + _has_fused_qk_norm_mrope = True + logger.info("aiter fused_qk_norm_mrope_3d kernel available") + except ImportError: + pass if _is_npu: from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope @@ -138,6 +151,19 @@ def __init__( ) self.alt_stream = alt_stream + self.use_fused_qk_norm_mrope = ( + _has_fused_qk_norm_mrope + and isinstance(self.rotary_emb, MRotaryEmbedding) + and getattr(self.rotary_emb, "mrope_section", None) is not None + ) + if self.use_fused_qk_norm_mrope: + # Scale tensors MUST stay on CPU: the C++ kernel uses .item() + # which triggers hipMemcpy D2H + sync on CUDA tensors, breaking graph capture. + # Explicit device='cpu' is required because SGLang constructs models inside + # a `with torch.device('cuda'):` context that changes the default device. 
+ self._fused_k_scale = torch.tensor(1.0, dtype=torch.float32, device="cpu") + self._fused_v_scale = torch.tensor(1.0, dtype=torch.float32, device="cpu") + def forward_prepare_native(self, positions, hidden_states): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -172,6 +198,68 @@ def forward_prepare_npu(self, positions, hidden_states, forward_batch): ) return q, k, v + def forward_prepare_aiter_fused_mrope( + self, positions, hidden_states, forward_batch + ): + """Fused QK-norm + 3D mRoPE + KV cache write for decode (ROCm/aiter). + + The fused HIP kernel replaces split → QK norm → mRoPE → cache write, + so KV is already in the paged cache when this returns. + Returns (q, None, None); caller must pass save_kv_cache=False to attn. + """ + qkv, _ = self.qkv_proj(hidden_states) + num_tokens = qkv.shape[0] + + qkv_3d = qkv.view(num_tokens, -1, self.head_dim) + + token_to_kv_pool = forward_batch.token_to_kv_pool + k_cache, v_cache = token_to_kv_pool.get_kv_buffer(self.attn.layer_id) + slot_mapping = forward_batch.out_cache_loc + + cos_sin = self.rotary_emb.cos_sin_cache + if cos_sin.dtype != qkv.dtype: + cos_sin = cos_sin.to(dtype=qkv.dtype) + + q_out = torch.empty( + num_tokens, + self.num_heads, + self.head_dim, + dtype=qkv.dtype, + device=qkv.device, + ) + + fused_qk_norm_mrope_3d_cache_pts_quant_shuffle( + qkv_3d, + self.q_norm.weight, + self.k_norm.weight, + cos_sin, + positions, + num_tokens, + self.num_heads, + self.num_kv_heads, + self.num_kv_heads, + self.head_dim, + self.rotary_emb.is_neox_style, + self.rotary_emb.mrope_section, + self.rotary_emb.mrope_interleaved, + self.q_norm.variance_epsilon, + q_out, + k_cache, + v_cache, + slot_mapping, + self._fused_k_scale, + self._fused_v_scale, + None, + None, + False, + False, + 0, + 0, + ) + + q = q_out.reshape(num_tokens, -1) + return q, None, None + def forward( self, positions: torch.Tensor, @@ -181,7 +269,19 @@ def forward( if 
get_global_server_args().rl_on_policy_target is not None: hidden_states = hidden_states.bfloat16() - if ( + save_kv_cache = True + use_aiter_fused = ( + self.use_fused_qk_norm_mrope + and forward_batch.forward_mode.is_decode() + and get_global_server_args().rl_on_policy_target is None + ) + + if use_aiter_fused: + q, k, v = self.forward_prepare_aiter_fused_mrope( + positions, hidden_states, forward_batch + ) + save_kv_cache = False + elif ( not _is_npu or forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed() ): @@ -200,7 +300,7 @@ def forward( q = q.to(torch.bfloat16) k = k.to(torch.bfloat16) - attn_output = self.attn(q, k, v, forward_batch) + attn_output = self.attn(q, k, v, forward_batch, save_kv_cache=save_kv_cache) output, _ = self.o_proj(attn_output) return output diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index ebdd00e002bb..66a5a3b59732 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -67,7 +67,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_linear_attention import RadixLinearAttention from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.utils import PPMissingLayer +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors @@ -88,6 +88,7 @@ cpu_has_amx_support, is_cpu, is_cuda, + is_gfx95_supported, is_npu, make_layers, set_weight_attrs, @@ -98,6 +99,7 @@ _is_cuda = is_cuda() _is_npu = is_npu() _is_cpu = is_cpu() +_is_gfx95 = is_gfx95_supported() _is_amx_available = cpu_has_amx_support() cached_get_processor = lru_cache(get_processor) @@ -893,6 +895,14 @@ def forward( class Qwen3_5ForCausalLM(nn.Module): """Qwen3.5 Model with support for dense 
variant.""" + if _is_gfx95: + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"], + "in_proj_ba": ["in_proj_b", "in_proj_a"], + } + def __init__( self, config: Qwen3_5TextConfig, @@ -1052,6 +1062,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace(r"model.language_model.", r"model.") if ".self_attn." in name: name = name.replace(".self_attn", "") + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self, "start_layer") + and (layer_id < self.start_layer or layer_id >= self.end_layer) + ): + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -1189,6 +1206,14 @@ def load_fused_expert_weights( if ".self_attn." in name: name = name.replace(".self_attn", "") + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self, "start_layer") + and (layer_id < self.start_layer or layer_id >= self.end_layer) + ): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if "experts.gate_up_proj" in name or "experts.down_proj" in name: is_fused_expert = True @@ -1299,6 +1324,10 @@ def load_fused_expert_weights( class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration): + if _is_gfx95: + packed_modules_mapping = Qwen3_5ForCausalLM.packed_modules_mapping + hf_to_sglang_mapper = None + def __init__( self, config: Qwen3_5Config, @@ -1369,6 +1398,24 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace(r"model.language_model.", r"model.") if ".self_attn." 
in name: name = name.replace(".self_attn", "") + if ( + self.config.tie_word_embeddings + and self.pp_group.is_last_rank + and "model.embed_tokens.weight" in name + ): + if "lm_head.weight" in params_dict: + lm_head_param = params_dict["lm_head.weight"] + weight_loader = getattr( + lm_head_param, "weight_loader", default_weight_loader + ) + weight_loader(lm_head_param, loaded_weight) + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self, "start_layer") + and (layer_id < self.start_layer or layer_id >= self.end_layer) + ): + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -1414,6 +1461,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): class Qwen3_5MoeForConditionalGeneration(Qwen3VLForConditionalGeneration): """Qwen3.5 MoE Vision-Language Model.""" + if _is_gfx95: + packed_modules_mapping = Qwen3_5ForCausalLM.packed_modules_mapping + hf_to_sglang_mapper = None + def __init__( self, config: Qwen3_5MoeConfig, @@ -1523,6 +1574,25 @@ def load_fused_expert_weights( name = name.replace(r"model.language_model.", r"model.") if ".self_attn." 
in name: name = name.replace(".self_attn", "") + if ( + self.config.tie_word_embeddings + and self.pp_group.is_last_rank + and "model.embed_tokens.weight" in name + ): + if "lm_head.weight" in params_dict: + lm_head_param = params_dict["lm_head.weight"] + weight_loader = getattr( + lm_head_param, "weight_loader", default_weight_loader + ) + weight_loader(lm_head_param, loaded_weight) + + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self, "start_layer") + and (layer_id < self.start_layer or layer_id >= self.end_layer) + ): + continue for param_name, weight_name, shard_id in stacked_params_mapping: if name.endswith("experts.gate_up_proj") or name.endswith( diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 010a73074759..912891b6a7eb 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -513,6 +513,7 @@ def __init__( self.compatible_with_fused_qk_norm_rope = not isinstance( self.rotary_emb, MRotaryEmbedding ) and self.head_dim in (64, 128, 256) + _yarn_factor, _, _, _ = compute_yarn_parameters(config) self.use_fused_qk_norm_rope = ( get_global_server_args().enable_fused_qk_norm_rope and self.compatible_with_fused_qk_norm_rope @@ -521,6 +522,7 @@ def __init__( self.head_dim, self.rotary_emb.is_neox_style, torch.bfloat16, + _yarn_factor != 1.0, ) ) self._used_fused_qk_norm_rope_last_call = False diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 1b6ce088b8b0..6b92daab7cab 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -271,6 +271,16 @@ def __init__( prefix=add_prefix("in_proj_ba", prefix), ) + # Override weight_loader for packed checkpoint format. + # Must capture original_loader BEFORE overwriting. 
+ self._override_weight_loader( + self.in_proj_qkvz, self._make_packed_weight_loader(self.in_proj_qkvz) + ) + self._override_weight_loader( + self.in_proj_ba, self._make_packed_weight_loader(self.in_proj_ba) + ) + + # Conv1d weight loader setup query_key_settings = (self.key_dim, 0, False) value_settings = (self.value_dim, 0, False) @@ -344,7 +354,74 @@ def __init__( dt_bias=self.dt_bias, ) - def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + @staticmethod + def _override_weight_loader(module, new_loader): + """Override weight_loader on a module's weight parameter. + + ModelWeightParameter exposes weight_loader as a read-only property + backed by _weight_loader, while plain parameters store it as a + regular attribute. This helper handles both cases.""" + param = module.weight + if hasattr(param, "_weight_loader"): + param._weight_loader = new_loader + else: + param.weight_loader = new_loader + + @staticmethod + def _make_packed_weight_loader(module): + """Create a weight_loader that does contiguous TP slicing for fused + (packed-format) checkpoint weights (shard_id=None), and delegates + to the standard MergedColumnParallelLinear loader for split checkpoint + weights (shard_id=int/tuple).""" + original_loader = module.weight.weight_loader + + def weight_loader(param, loaded_weight, loaded_shard_id=None): + if loaded_shard_id is None: + # Fused checkpoint: weight is in packed (per-head-group) + # format. Do contiguous TP slice like ColumnParallelLinear. 
+ output_dim = getattr(param, "output_dim", None) + if output_dim is not None and module.tp_size > 1: + shard_size = param.data.shape[output_dim] + start_idx = module.tp_rank * shard_size + loaded_weight = loaded_weight.narrow( + output_dim, start_idx, shard_size + ) + assert param.data.shape == loaded_weight.shape, ( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + else: + # Split checkpoint (int or tuple shard_id) → standard path + original_loader(param, loaded_weight, loaded_shard_id) + + return weight_loader + + def create_qkvz_proj( + self, + hidden_size: int, + key_dim: int, + value_dim: int, + quant_config: QuantizationConfig | None, + prefix: str, + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + ) -> MergedColumnParallelLinear: + return MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[key_dim, key_dim, value_dim, value_dim], + bias=False, + quant_config=quant_config, + prefix=prefix, + tp_rank=tp_rank, + tp_size=tp_size, + ) + + def fix_query_key_value_ordering( + self, + mixed_qkvz: torch.Tensor, + mixed_ba: torch.Tensor, + ): """ Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. 
""" diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 1aec50d01d27..e23719e5c3e0 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -1141,6 +1141,19 @@ def separate_deepstack_embeds(self, embedding): input_deepstack_embeds = embedding[:, separate_index:] return input_embeds, input_deepstack_embeds + @property + def start_layer(self) -> int: + return getattr(getattr(self, "model", None), "start_layer", 0) + + @property + def end_layer(self) -> int: + model = getattr(self, "model", None) + end_layer = getattr(model, "end_layer", None) + if end_layer is not None: + return end_layer + cfg = getattr(model, "config", None) + return int(getattr(cfg, "num_hidden_layers", 0)) + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): pattern = MultiModalityDataPaddingPatternMultimodalTokens() return pattern.pad_input_tokens(input_ids, mm_inputs) diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py index cf1cb3879b42..3de2eefea316 100644 --- a/python/sglang/srt/models/qwen3_vl_moe.py +++ b/python/sglang/srt/models/qwen3_vl_moe.py @@ -26,6 +26,7 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.utils import get_layer_id from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel @@ -179,9 +180,8 @@ def __init__( ): super().__init__(config, quant_config, prefix, language_model_cls) - # Only allow LoRA on attention projections within text layers for MoE. 
_lora_pattern_moe = re.compile( - r"^model\.layers\.(\d+)\.self_attn\.(?:qkv_proj|o_proj)$" + r"^(?:model\.layers\.(\d+)\.(?:self_attn\.(?:qkv_proj|o_proj)|mlp\.experts)|lm_head|model\.embed_tokens)$" ) def should_apply_lora(self, module_name: str) -> bool: @@ -231,6 +231,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: name = name.replace(r"model.language_model.", r"model.") + layer_id = get_layer_id(name) + if ( + "visual" not in name + and layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue for param_name, weight_name, shard_id in stacked_params_mapping: if "experts.gate_up_proj" in name or "experts.down_proj" in name: diff --git a/python/sglang/srt/models/transformers.py b/python/sglang/srt/models/transformers.py index 0ea9da14a1be..d928870b9c8a 100644 --- a/python/sglang/srt/models/transformers.py +++ b/python/sglang/srt/models/transformers.py @@ -13,62 +13,299 @@ # ============================================================================== # Adapted from -# https://github.com/vllm-project/vllm/blob/a1a2aaadb9122f05667140e39cf67e5736c8b6d6/vllm/model_executor/models/transformers.py -"""Wrapper around `transformers` models""" +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/transformers +"""Wrapper around `transformers` models.""" +import inspect import logging import re -from typing import Iterable, Literal, Optional, Tuple, Union +from collections.abc import Iterable, Mapping +from contextlib import contextmanager +from typing import List, Literal, Optional, Tuple, Union import torch +import transformers from torch import nn from transformers import AutoModel, PretrainedConfig, PreTrainedModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from 
sglang.srt.distributed import divide, get_tensor_model_parallel_world_size +from sglang.srt.distributed import ( + divide, + get_moe_expert_parallel_world_size, + get_pp_group, + get_pp_indices, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, ReplicatedLinear, RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.topk import StandardTopKOutput +from sglang.srt.layers.moe.utils import filter_moe_weight_param_global_expert +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.utils import PPMissingLayer from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.utils import AutoWeightsLoader, WeightsMapper +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils.common import direct_register_custom_op +from sglang.srt.utils.hf_transformers_utils import get_hf_text_config + + +def can_enable_torch_compile(config: PretrainedConfig) -> bool: + """Check whether 
the model config is compatible with torch.compile. + + Dynamic rope scaling triggers data-dependent control flow that prevents + capturing a single computation graph, so we disable compilation for it. + """ + text_config = getattr(config, "text_config", config) + rope_scaling = getattr(text_config, "rope_scaling", None) + if isinstance(rope_scaling, dict): + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "")) + if rope_type == "dynamic": + return False + rope_params = getattr(text_config, "rope_parameters", None) + if isinstance(rope_params, dict): + if isinstance(next(iter(rope_params.values()), None), dict): + return not any( + rp.get("rope_type") == "dynamic" for rp in rope_params.values() + ) + if rope_params.get("rope_type") == "dynamic": + return False + return True + logger = logging.getLogger(__name__) +_TRANSFORMERS_MOE_LAYERS: dict[str, "TransformersFusedMoE"] = {} + def maybe_prefix(prefix: str, name: str) -> str: - """Add a prefix to a name if the prefix is non-empty. + return name if not prefix else f"{prefix}.{name}" - Args: - prefix: The prefix to add. If empty, no prefix will be added. - name: The name to potentially prefix. - Returns: - The string "prefix.name" if prefix was non-empty, otherwise just "name". 
- """ - return name if not prefix else f"{prefix}.{name}" +def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): + logger.debug("%s: %s -> %s", name, old_module, new_module) + + +def _getattr_first(obj, names, default=None): + """Return the first existing attribute from *names*, else *default*.""" + for name in names: + value = getattr(obj, name, None) + if value is not None: + return value + return default + + +def _resolve_attention_backend_model_cls(config: PretrainedConfig): + model_cls = getattr(transformers, getattr(config, "architectures", [""])[0], None) + if model_cls is not None: + return model_cls + + auto_map = getattr(config, "auto_map", {}) or {} + for key in ("AutoModel", "AutoModelForCausalLM"): + if key not in auto_map: + continue + try: + return get_class_from_dynamic_module( + auto_map[key], + getattr(config, "_name_or_path", ""), + ) + except Exception as e: + logger.warning( + "Failed to load dynamic module from auto_map[%s]: %s.", + key, + e, + ) + return None + + +def _encoder_accepts_feature_kwarg(encoder, feature_kwarg: str) -> bool: + try: + sig = inspect.signature(encoder) + except (TypeError, ValueError): + return False + + if feature_kwarg in sig.parameters: + return True + + has_var_keyword = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + if not has_var_keyword: + return False + + required_positional_params = [ + p + for p in sig.parameters.values() + if p.kind + in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + and p.default is inspect.Parameter.empty + ] + return len(required_positional_params) == 0 + + +@contextmanager +def _init_on_device_without_buffers(device: torch.device): + """Initialize model parameters on *device* while leaving buffers on CPU. 
+ Adapted from ``accelerate``.""" + old_register_parameter = nn.Module.register_parameter + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs + ) + + try: + nn.Module.register_parameter = register_empty_parameter + yield + finally: + nn.Module.register_parameter = old_register_parameter + + +Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] + + +def replace_linear_class( + linear: nn.Linear, + style: Style = "replicate", + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", +) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + + sglang_linear_cls, linear_kwargs = { + "colwise": (ColumnParallelLinear, {}), + "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), + "rowwise": (RowParallelLinear, {}), + "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), + "replicate": (ReplicatedLinear, {}), + }.get(style, (ReplicatedLinear, {})) + + class HFCompatibleLinear(sglang_linear_cls): + @property + def parent_cls(self) -> type: + return sglang_linear_cls + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return super().forward(input)[0] + + return HFCompatibleLinear( + input_size=linear.in_features, + output_size=linear.out_features, + bias=linear.bias is not None, + quant_config=quant_config, + prefix=prefix, + **linear_kwargs, + ) + + +def _normalize_tp_style(style: str) -> Style: + style = style.lower().replace("-", "_") + style = { + "colwiseparallel": "colwise", + "packed_colwise": "colwise", + "local_colwise": "colwise", + "rowwiseparallel": "rowwise", + 
"packed_rowwise": "rowwise", + "local_rowwise": "rowwise", + "local_packed_rowwise": "rowwise", + "isolated": "replicate", + "local": "replicate", + "replicated_with_grad_allreduce": "replicate", + "moe_tp_experts": "replicate", + }.get(style, style) + if style not in {"colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"}: + raise ValueError(f"Unsupported TP style '{style}' for Transformers backend.") + return style + + +def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> nn.Module: + eps = _getattr_first(rms_norm, ("eps", "variance_epsilon"), 1e-6) + kwargs = {"hidden_size": hidden_size, "eps": eps} + weight_meta = getattr(rms_norm, "weight", None) + if weight_meta is not None: + kwargs["hidden_size"] = weight_meta.size(0) + + try: + with torch.device("cpu"): + weight_test = getattr(rms_norm.__class__(1), "weight", None) + except Exception: + weight_test = None + is_gemma = weight_test is not None and torch.all(weight_test == 0) + + if is_gemma: + base_cls = GemmaRMSNorm + norm = base_cls( + **{k: v for k, v in kwargs.items() if k in ("hidden_size", "eps")} + ) + else: + kwargs["has_weight"] = getattr(rms_norm, "with_scale", True) + if weight_meta is not None: + kwargs["weight_dtype"] = weight_meta.dtype + else: + kwargs["has_weight"] = False + base_cls = RMSNorm + norm = base_cls(**kwargs) + + # Wrap to handle 3D inputs from Transformers backbone (batch dim) + class HFCompatibleRMSNorm(norm.__class__): + def forward(self, x, *args, **kwargs): + orig_shape = x.shape + if x.ndim > 2: + x = x.reshape(-1, x.shape[-1]).contiguous() + result = super().forward(x, *args, **kwargs) + if isinstance(result, tuple): + return tuple( + ( + r.reshape(orig_shape) + if torch.is_tensor(r) and r.shape != orig_shape + else r + ) + for r in result + ) + if torch.is_tensor(result) and result.shape != orig_shape: + return result.reshape(orig_shape) + return result + + norm.__class__ = HFCompatibleRMSNorm + return norm def sglang_flash_attention_forward( - # 
Transformers args module: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, - # sglang kwargs - forward_batch: ForwardBatch, - # Transformers kwargs scaling: float = None, - attention_instances: list[RadixAttention] = None, + attention_instances: Optional[Mapping[int, RadixAttention]] = None, + forward_batch: Optional[ForwardBatch] = None, **kwargs, ): self_attn: RadixAttention = attention_instances[module.layer_idx] @@ -83,63 +320,240 @@ def sglang_flash_attention_forward( ALL_ATTENTION_FUNCTIONS["sglang"] = sglang_flash_attention_forward -class HFColumnParallelLinear(ColumnParallelLinear): +class TransformersFusedMoE(nn.Module): + """FusedMoE wrapper for the Transformers modeling backend. - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input)[0] + Wraps SGLang's native MoE implementation and exposes the + ``(hidden_states, topk_ids, topk_weights)`` signature expected by + Transformers' ``experts.forward()``. A registered custom op + (``torch.ops.sglang.transformers_moe_forward``) is used so that + ``torch.compile`` can properly graph-break around the MoE kernel. 
+ """ + def __init__( + self, + *, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + layer_id: int, + reduce_results: bool, + quant_config: Optional[QuantizationConfig], + prefix: str, + activation: str, + with_bias: bool, + expert_mapping: list, + ) -> None: + super().__init__() + num_redundant = get_global_server_args().ep_num_redundant_experts + experts_cls = get_moe_impl_class(quant_config) + self.experts = experts_cls( + num_experts=num_experts + num_redundant, + top_k=top_k, + layer_id=layer_id, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=reduce_results, + quant_config=quant_config, + activation=activation, + with_bias=with_bias, + prefix=prefix, + ) + self.layer_name = prefix + self.num_experts = num_experts + self.top_k = top_k + self._expert_mapping = expert_mapping + _TRANSFORMERS_MOE_LAYERS[prefix] = self -class HFRowParallelLinear(RowParallelLinear): + @property + def tp_size(self) -> int: + return getattr(self.experts, "moe_tp_size", 1) - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input)[0] + @property + def ep_size(self) -> int: + return getattr(self.experts, "moe_ep_size", 1) + def maybe_all_reduce_tensor_model_parallel( + self, output: torch.Tensor + ) -> torch.Tensor: + if self.tp_size > 1: + return tensor_model_parallel_all_reduce(output) + return output -def replace_linear_class( - linear: nn.Linear, - style: Literal["colwise", "rowwise"], - quant_config: QuantizationConfig, -) -> Union[ColumnParallelLinear, RowParallelLinear]: - """ - Replace nn.Linear with one of vLLM's tensor parallel linear classes. - - Args: - linear (nn.Linear): `nn.Linear` to be replaced. - style (str): Tensor parallel style of the new linear, e.g. "colwise". - quant_config (QuantConfig): Quantization config for the new linear. - Returns: - Union[ColumnParallelLinear, RowParallelLinear]: The new linear. 
- """ + def get_expert_weights(self): + return getattr(self.experts, "get_expert_weights", lambda: None)() - if not isinstance(style, str): - raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + def get_moe_weights(self) -> list[torch.Tensor]: + num_local = getattr(self.experts, "num_local_experts", self.num_experts) + return [ + x.data + for name, x in self.experts.named_parameters() + if name not in ("correction_bias",) + and filter_moe_weight_param_global_expert(name, x, num_local) + ] - sglang_linear_cls = { - "colwise": ColumnParallelLinear, - "rowwise": RowParallelLinear, - }.get(style, ReplicatedLinear) + def forward( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + topk_ids = topk_ids.to(torch.int32) + topk_weights = topk_weights.to(torch.float32) + if hidden_states.is_cuda: + return torch.ops.sglang.transformers_moe_forward( + hidden_states, + topk_ids, + topk_weights, + self.layer_name, + ) + return _transformers_moe_forward( + hidden_states, + topk_ids, + topk_weights, + self.layer_name, + ) - class HFCompatibleLinear(sglang_linear_cls): - """ - Wrapper class that removes `output_bias` from returned output. 
- """ + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loaded: set[str] = set() + param_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + matched = False + for param_name, weight_name, expert_id, shard_id in self._expert_mapping: + if weight_name not in name: + continue + mapped_name = name.replace(weight_name, param_name) + param = param_dict.get(mapped_name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + try: + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + except TypeError: + weight_loader(param, loaded_weight) + loaded.add(name) + matched = True + break + if not matched: + direct_name = name if name in param_dict else f"experts.{name}" + if direct_name in param_dict: + param = param_dict[direct_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + try: + weight_loader(param, loaded_weight) + except TypeError: + default_weight_loader(param, loaded_weight) + loaded.add(name) + else: + logger.warning( + "MoE weight '%s' in layer '%s' could not be matched to any " + "parameter and will be skipped.", + name, + self.layer_name, + ) + return loaded - @property - def parent_cls(self) -> type: - return sglang_linear_cls - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input)[0] +def _transformers_moe_forward( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + self = _TRANSFORMERS_MOE_LAYERS[layer_name] + # Record expert distribution for EPLB + from sglang.srt.eplb.expert_distribution import ( + get_global_expert_distribution_recorder, + ) - return HFCompatibleLinear( - input_size=linear.in_features, - output_size=linear.out_features, - bias=linear.bias is not None, - quant_config=quant_config, + recorder = get_global_expert_distribution_recorder() + with 
recorder.with_current_layer(self.experts.layer_id): + recorder.on_select_experts(topk_ids=topk_ids) + topk_output = StandardTopKOutput( + topk_weights=topk_weights, + topk_ids=topk_ids, + router_logits=topk_weights, ) + return self.experts(hidden_states.clone(), topk_output) + + +def _transformers_moe_forward_fake( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="transformers_moe_forward", + op_func=_transformers_moe_forward, + mutates_args=["hidden_states"], + fake_impl=_transformers_moe_forward_fake, +) + +try: + from sglang.srt.compilation.compilation_config import SPLIT_OPS + + _MOE_SPLIT_OP = "sglang.transformers_moe_forward" + if _MOE_SPLIT_OP not in SPLIT_OPS: + SPLIT_OPS.append(_MOE_SPLIT_OP) +except ImportError: + pass + + +_BASE_DYNAMIC_ARG_DIMS: dict[str, int] = { + "input_ids": 0, + "positions": 0, + "input_embeds": 0, +} + +_MULTIMODAL_DYNAMIC_ARG_DIMS: dict[str, int] = { + "input_ids": 0, + "positions": -1, # last dim to support M-RoPE (Qwen2.5-VL 3Ɨseq layout) + "input_embeds": 0, +} + +class TransformersBase(nn.Module): + torch_compile_dynamic_arg_dims: dict[str, int] = _BASE_DYNAMIC_ARG_DIMS -class TransformersForCausalLM(nn.Module): + hf_to_sglang_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model.": "model.language_model.", + "model.transformer.": "model.", + "model.model.": "model.", + "model.lm_head.": "lm_head.", + "model.score.": "classifier.", + "model.classifier.": "classifier.", + "transformer.": "model.", + "model.": "model.", + "lm_head.": "lm_head.", + "score.": "classifier.", + "classifier.": "classifier.", + "": "model.", + } + ) + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + mapper = WeightsMapper() + for base in cls.__mro__: + base_mapper = getattr(base, "hf_to_sglang_mapper", None) + if base_mapper is not 
None: + mapper = mapper | base_mapper + cls.hf_to_sglang_mapper = mapper def __init__( self, @@ -152,138 +566,1067 @@ def __init__( self.quant_config = quant_config self.config = config - self.vocab_size = config.vocab_size - self.unpadded_vocab_size = config.vocab_size - - # model is loaded under set_default_torch_dtype(model_config.dtype) - self.model: PreTrainedModel = AutoModel.from_config( - self.config, - torch_dtype=torch.get_default_dtype(), - attn_implementation="sglang", - trust_remote_code=True, - ) + self.text_config = get_hf_text_config(config) + self.weight_mapper = self.hf_to_sglang_mapper + self.pp_group = get_pp_group() - # Attention modifications (assumes 1 attention op per hidden layer) - tp_size = get_tensor_model_parallel_world_size() + # Weight loading attrs + self.skip_prefixes: list[str] = [] + self.skip_substrs: list[str] = [] + self.ignore_unexpected_prefixes: list[str] = [] + self.ignore_unexpected_suffixes: list[str] = [] + self.skip_substrs.extend([".attn.bias", ".attn.masked_bias", ".masked_bias"]) + self.ignore_unexpected_prefixes.extend(["classifier.", "score."]) + + if self.quant_config is not None: + quant_method_name = self.quant_config.get_name() + if "gptq" in quant_method_name: + self.ignore_unexpected_suffixes.append(".bias") + if "fp8" in quant_method_name: + fp8_suffix_map = {".activation_scale": ".input_scale"} + use_mxfp8 = bool(getattr(self.quant_config, "use_mxfp8", False)) + weight_block_size = getattr( + self.quant_config, "weight_block_size", None + ) + if not use_mxfp8 and weight_block_size is None: + fp8_suffix_map[".weight_scale_inv"] = ".weight_scale" + self.weight_mapper = self.weight_mapper | WeightsMapper( + orig_to_new_suffix=fp8_suffix_map + ) - # MLP modifications - self.tensor_parallel(tp_size) + # Resolve model class for _supports_attention_backend check + model_cls = _resolve_attention_backend_model_cls(config) - head_dim = ( - (config.hidden_size // config.num_attention_heads) - if not hasattr(config, 
"head_dim") - else config.head_dim + supports_backend = ( + getattr(model_cls, "_supports_attention_backend", True) + if model_cls + else True ) - self.attention_instances = [ - RadixAttention( - num_heads=divide(config.num_attention_heads, tp_size), - head_dim=head_dim, - # NOTE: We use Llama scale as default, if it's set by - # Transformers, it's updated in sglang_flash_attention_forward - scaling=head_dim**-0.5, - num_kv_heads=divide(config.num_key_value_heads, tp_size), - layer_id=i, - quant_config=self.quant_config, - prefix=f"{i}.attn", + + # Initialize on meta device to avoid premature GPU allocation + self.text_config._attn_implementation = "sglang" + if supports_backend: + with _init_on_device_without_buffers(torch.device("meta")): + self.model: PreTrainedModel = AutoModel.from_config( + self.config, + torch_dtype=torch.get_default_dtype(), + trust_remote_code=True, + ) + else: + raise ValueError( + f"Model {model_cls} does not support custom attention backends " + "(_supports_attention_backend=False). The Transformers backend " + "requires custom attention support." ) - for i in range(config.num_hidden_layers) - ] - # Model modifications + self.vocab_size = getattr( + self.text_config, + "vocab_size", + self.model.get_input_embeddings().num_embeddings, + ) + self.unpadded_vocab_size = self.vocab_size + + # Embedding scale (e.g. 
Whisper) + input_embeddings = self.model.get_input_embeddings() + self.embed_scale = getattr(input_embeddings, "embed_scale", None) + + self.start_layer = 0 + self.end_layer = getattr(self.text_config, "num_hidden_layers", 0) + + # Pipeline parallel + self.pipeline_parallel() + # Module replacement (Linear → TP, RMSNorm → fused, MoE overridden by MoEMixin) + tp_size = get_tensor_model_parallel_world_size() + self.recursive_replace() + # Attention instances + self.attention_instances = self._create_attention_instances(tp_size) + # Vocab embeddings self.replace_vocab_embed_class(self.model) - # ForCausalLM modifications - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=self.quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - if config.tie_word_embeddings: - self.lm_head.weight = self.model.get_input_embeddings().weight + # Initialize remaining meta-device parameters to real device tensors + self._init_parameters(self.model) - self.logits_processor = LogitsProcessor(config) + self.lm_head: Optional[ParallelLMHead] = None + self.logits_processor: Optional[LogitsProcessor] = None + self.pooler: Optional[Pooler] = None + + self._compile_compatible = can_enable_torch_compile(config) + + @property + def _can_torch_compile(self) -> bool: + """Whether this model instance is safe to wrap with torch.compile.""" + return self._compile_compatible + + def _init_parameters(self, module: nn.Module): + """Materialize any parameters still on the meta device.""" + for name, param in module.named_parameters(recurse=False): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like( + param.data, + device="cuda", + ) + ) + setattr(module, name, new_param) + for child in module.children(): + self._init_parameters(child) def log_replacement(self, name: str, old_module: nn.Module, new_module: nn.Module): logger.debug("%s: %s -> %s", name, old_module, new_module) - def tensor_parallel(self, tp_size: int): - 
""" - Apply the model's tensor parallelization plan. - Currently only supports linear layers. - """ - tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {} + # -- TP plan handling --------------------------------------------------- + def _get_model_tp_plan(self) -> Mapping[str, str]: + plan = ( + getattr(self.model, "tp_plan", None) + or getattr(self.model, "_tp_plan", None) + or getattr(self.model.config, "base_model_tp_plan", None) + or getattr(self.text_config, "base_model_tp_plan", None) + ) + if plan: + return plan + + plan = self._infer_tp_plan_from_children() + return plan if plan else {} + + _LANGUAGE_MODEL_CHILD_NAMES = frozenset( + {"language_model", "text_model", "model", "lm"} + ) + + def _infer_tp_plan_from_children(self) -> dict[str, str]: + plan: dict[str, str] = {} + for child_name, child_module in self.model.named_children(): + child_plan = getattr(child_module, "_tp_plan", None) + if child_plan: + plan.update({f"{child_name}.{k}": v for k, v in child_plan.items()}) + continue + + child_config = getattr(child_module, "config", None) + if child_config is not None: + child_tp = getattr(child_config, "base_model_tp_plan", None) + if child_tp: + plan.update({f"{child_name}.{k}": v for k, v in child_tp.items()}) + continue + + if child_name not in self._LANGUAGE_MODEL_CHILD_NAMES: + continue + if child_config is None: + continue + model_type = getattr(child_config, "model_type", "") + base_type = ( + model_type.replace("_vl_text", "") + .replace("_vl", "") + .replace("_text", "") + ) + if base_type and base_type != model_type: + try: + from transformers import AutoConfig + + base_cfg = AutoConfig.for_model(base_type) + base_tp = getattr(base_cfg, "base_model_tp_plan", None) + if base_tp: + plan.update( + {f"{child_name}.{k}": v for k, v in base_tp.items()} + ) + except Exception as e: + logger.debug( + "Could not infer TP plan from base model type '%s': %s", + base_type, + e, + ) + return plan + + def _normalize_tp_plan(self, tp_plan: 
Mapping[str, str]) -> dict[str, Style]: + normalized = {} + for pattern, style in tp_plan.items(): + if pattern.startswith("^model\\."): + pattern = "^" + pattern[len("^model\\.") :] + elif pattern.startswith("model\\."): + pattern = pattern[len("model\\.") :] + elif pattern.startswith("model."): + pattern = pattern[len("model.") :] + normalized[pattern] = _normalize_tp_style(style) + return normalized + + # -- Recursive module replacement (Linear + RMSNorm) -------------------- + def recursive_replace(self): + tp_size = get_tensor_model_parallel_world_size() + tp_plan = self._normalize_tp_plan(self._get_model_tp_plan()) if not tp_plan and tp_size > 1: raise ValueError( f"{type(self.model)} does not support tensor parallel yet!" ) - def _tensor_parallel(module: nn.Module, prefix: str = ""): + # Prefix patterns to match from `self.model` + prefixed_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} + + def _recursive_replace(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): qual_name = maybe_prefix(prefix, child_name) - for pattern, style in tp_plan.items(): - if re.match(pattern, qual_name) and isinstance( - child_module, nn.Linear - ): - new_module = replace_linear_class( - child_module, style, self.quant_config - ) - setattr(module, child_name, new_module) - self.log_replacement(qual_name, child_module, new_module) + new_module = child_module + + if isinstance(child_module, nn.Linear): + pattern = next( + (p for p in prefixed_plan if re.match(p, qual_name)), + None, + ) + style = prefixed_plan.get(pattern, "replicate") + new_module = replace_linear_class( + child_module, + style, + self.quant_config, + prefix=qual_name, + ) + elif child_module.__class__.__name__.endswith("RMSNorm"): + new_module = replace_rms_norm_class( + child_module, + self.text_config.hidden_size, + ) else: - _tensor_parallel(child_module, prefix=qual_name) + _recursive_replace(child_module, prefix=qual_name) + + if new_module is not 
child_module: + setattr(module, child_name, new_module) + log_replacement(qual_name, child_module, new_module) + + _recursive_replace(self.model, prefix="model") + + # -- Pipeline parallel -------------------------------------------------- + def _get_model_pp_plan(self) -> Mapping[str, object]: + return ( + getattr(self.model, "_pp_plan", None) + or getattr(self.model, "pp_plan", None) + or getattr(self.model.config, "base_model_pp_plan", None) + or getattr(self.text_config, "base_model_pp_plan", None) + or {} + ) + + def _register_missing_prefix(self, prefix: str): + if not prefix.endswith("."): + prefix += "." + if prefix not in self.skip_prefixes: + self.skip_prefixes.append(prefix) + + @staticmethod + def _make_pp_missing_layer(original: nn.Module) -> PPMissingLayer: + """Create a PPMissingLayer that preserves plain attributes from + *original* so that the HF forward loop can still access per-layer + metadata (e.g. ``attention_type`` on Qwen2 decoder layers).""" + replacement = PPMissingLayer() + for key, value in original.__dict__.items(): + if key.startswith("_"): + continue + if isinstance(value, (nn.Module, nn.Parameter, torch.Tensor)): + continue + setattr(replacement, key, value) + return replacement + + def _get_submodule_or_none(self, name: str) -> Optional[nn.Module]: + try: + return self.model.get_submodule(name) + except AttributeError: + return None + + def _set_submodule(self, name: str, module: nn.Module): + if "." in name: + parent_name, child_name = name.rsplit(".", 1) + parent_module = self.model.get_submodule(parent_name) + else: + parent_module = self.model + child_name = name + setattr(parent_module, child_name, module) + + def pipeline_parallel(self): + if self.pp_group.world_size <= 1: + return + + pp_plan = self._get_model_pp_plan() + if not pp_plan: + raise ValueError( + f"{type(self.model)} does not support pipeline parallel yet!" 
+ ) + + pp_keys = [re.sub(r"^model\.", "", name) for name in pp_plan.keys()] + module_list_idx = None + module_list_name = None + for idx, name in enumerate(pp_keys): + if isinstance(self._get_submodule_or_none(name), nn.ModuleList): + if module_list_idx is not None: + raise ValueError( + "Pipeline parallel with multiple ModuleList blocks is not supported." + ) + module_list_idx = idx + module_list_name = name + + if module_list_idx is None or module_list_name is None: + raise ValueError(f"Could not find ModuleList in {type(self.model)}.") + + keep_prefix_modules = self.pp_group.is_first_rank or ( + getattr(self.text_config, "tie_word_embeddings", False) + and self.pp_group.is_last_rank + ) + for name in pp_keys[:module_list_idx]: + if keep_prefix_modules: + continue + self._set_submodule(name, PPMissingLayer()) + self._register_missing_prefix(maybe_prefix("model", name)) + + layers = self.model.get_submodule(module_list_name) + self.start_layer, self.end_layer = get_pp_indices( + len(layers), + self.pp_group.rank_in_group, + self.pp_group.world_size, + ) + for idx in range(len(layers)): + if self.start_layer <= idx < self.end_layer: + continue + layers[idx] = self._make_pp_missing_layer(layers[idx]) + self._register_missing_prefix( + maybe_prefix("model", f"{module_list_name}.{idx}") + ) + + for name in pp_keys[module_list_idx + 1 :]: + if self.pp_group.is_last_rank: + continue + self._set_submodule(name, PPMissingLayer()) + self._register_missing_prefix(maybe_prefix("model", name)) + + # -- Attention instances ------------------------------------------------ + def _create_attention_instances(self, tp_size: int) -> dict[int, RadixAttention]: + num_heads = self.text_config.num_attention_heads + num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads) + hidden_size = self.text_config.hidden_size + head_dim = getattr(self.text_config, "head_dim", hidden_size // num_heads) + + layer_types = getattr(self.text_config, "layer_types", None) or getattr( + 
self.config, "layer_types", None + ) + global_sliding_window = getattr( + self.text_config, "sliding_window", None + ) or getattr(self.config, "sliding_window", None) + + # Detect encoder-only models (non-causal attention everywhere) + is_encoder_only = any( + not getattr(m, "is_causal", True) + for m in self.model.modules() + if hasattr(m, "is_causal") + ) + if is_encoder_only and self.config != self.text_config: + is_encoder_only = False + if is_encoder_only: + logger.info( + "Detected encoder-only model (non-causal attention). " + "Using RadixAttention with is_cross_attention=True." + ) - _tensor_parallel(self.model) + instances = {} + for idx in range(self.start_layer, self.end_layer): + # Per-layer sliding window (e.g. Gemma2, Cohere) + per_layer_sliding_window = -1 + if ( + layer_types is not None + and idx < len(layer_types) + and layer_types[idx] == "sliding_attention" + and global_sliding_window is not None + ): + per_layer_sliding_window = global_sliding_window + instances[idx] = RadixAttention( + num_heads=divide(num_heads, tp_size), + head_dim=head_dim, + scaling=head_dim**-0.5, + num_kv_heads=divide(num_kv_heads, tp_size), + layer_id=idx, + quant_config=self.quant_config, + sliding_window_size=per_layer_sliding_window, + is_cross_attention=is_encoder_only, + prefix=f"{idx}.attn", + ) + return instances + + # -- Vocab embedding replacement ---------------------------------------- def replace_vocab_embed_class(self, module: nn.Module): - # Use native set input embeddings + old_module = self.model.get_input_embeddings() + if old_module is None or isinstance(old_module, PPMissingLayer): + return + embedding_dim = getattr(old_module, "embedding_dim", None) + if embedding_dim is None: + embedding_dim = _getattr_first( + self.text_config, + ("embedding_size", "hidden_size"), + None, + ) + assert embedding_dim is not None new_module = VocabParallelEmbedding( self.vocab_size, - self.config.hidden_size, - org_num_embeddings=self.config.vocab_size, + 
embedding_dim, + org_num_embeddings=self.vocab_size, quant_config=None, ) - self.log_replacement( - "input embedding", self.model.get_input_embeddings(), new_module - ) + + old_embed_scale = getattr(old_module, "embed_scale", None) + if old_embed_scale is not None: + base_cls = new_module.__class__ + + class ScaledEmbedding(base_cls): + def forward(self, input_): + return base_cls.forward(self, input_) * self.embed_scale + + new_module.__class__ = ScaledEmbedding + new_module.embed_scale = old_embed_scale + self.embed_scale = None + + self.log_replacement("input embedding", old_module, new_module) self.model.set_input_embeddings(new_module) + # -- Forward ------------------------------------------------------------ + def _format_position_ids(self, positions: torch.Tensor) -> torch.Tensor: + if positions.ndim == 2 and positions.shape[0] == 3: + return positions[:, None, ...] + if positions.ndim == 1: + return positions[None, ...] + return positions + + def _run_hf_backbone( + self, + input_ids: Optional[torch.Tensor], + input_embeds: Optional[torch.Tensor], + positions: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs, + ) -> torch.Tensor: + hf_input_ids = None if input_ids is None else input_ids[None, ...] + hf_input_embeds = None + if input_embeds is not None: + hf_input_embeds = input_embeds[None, ...] + hf_input_ids = None + + # Scale embeddings if needed + if ( + self.embed_scale is not None + and hf_input_ids is not None + and hf_input_embeds is None + ): + hf_input_embeds = ( + self.model.get_input_embeddings()(hf_input_ids) * self.embed_scale + ) + hf_input_ids = None + + return self.model( + input_ids=hf_input_ids, + inputs_embeds=hf_input_embeds, + use_cache=False, + position_ids=self._format_position_ids(positions), + return_dict=False, + forward_batch=forward_batch, + attention_instances=self.attention_instances, + **kwargs, + )[0][0, ...] 
+ + def _forward_hidden_states( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self._run_hf_backbone( + input_ids=input_ids, + input_embeds=input_embeds, + positions=positions, + forward_batch=forward_batch, + ) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, input_embeds: torch.Tensor = None, get_embedding: bool = False, - ) -> LogitsProcessorOutput: - assert get_embedding is False, "embedding is not supported yet" - aux_hidden_states = None - hidden_states = self.model( - input_ids[None, ...], - use_cache=False, - position_ids=positions[None, ...], + ) -> Union[LogitsProcessorOutput, EmbeddingPoolerOutput, PPProxyTensors]: + runtime_input_ids: Optional[torch.Tensor] = input_ids + runtime_input_embeds = input_embeds + if not self.pp_group.is_first_rank: + assert pp_proxy_tensors is not None + runtime_input_ids = None + runtime_input_embeds = pp_proxy_tensors["hidden_states"] + + hidden_states = self._forward_hidden_states( + input_ids=runtime_input_ids, + positions=positions, forward_batch=forward_batch, - attention_instances=self.attention_instances, - return_dict=False, - )[0][ - 0, ... 
- ] # we remove batch dimension for now + input_embeds=runtime_input_embeds, + ) + if not self.pp_group.is_last_rank: + return PPProxyTensors( + {"hidden_states": hidden_states, "residual": hidden_states} + ) + + if get_embedding: + assert ( + self.pooler is not None + ), "pooling is not enabled for this model class" + return self.pooler(hidden_states, forward_batch) + + assert self.logits_processor is not None and self.lm_head is not None return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states + input_ids, hidden_states, self.lm_head, forward_batch, None ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if name not in params_dict: - name = f"{self.model.base_model_prefix}.{name}" - if name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + # -- Weight loading ----------------------------------------------------- + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=self.skip_prefixes, + skip_substrs=self.skip_substrs, + ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, + ignore_unexpected_suffixes=self.ignore_unexpected_suffixes, + ) + return loader.load_weights(weights, mapper=self.weight_mapper) + + +class CausalMixin: + + def __init__(self, *args, prefix: str = "", **kwargs): + super().__init__(*args, prefix=prefix, **kwargs) + + tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) + if tie_word_embeddings: + self.skip_prefixes.append("lm_head.") + + if not self.pp_group.is_last_rank: + self._register_missing_prefix("lm_head") + return + + self.lm_head = ParallelLMHead( + self.vocab_size, + self.text_config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, 
"lm_head"), + ) + if tie_word_embeddings: + self.lm_head.weight = self.model.get_input_embeddings().weight + + logit_scale = getattr(self.text_config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.text_config, logit_scale=logit_scale + ) + + +class EmbeddingMixin: + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ignore_unexpected_prefixes.append("lm_head.") + if not self.pp_group.is_last_rank: + return + pooling_name = str(getattr(self.config, "pooling_type", "LAST")).upper() + pooling_type = PoolingType.CLS if pooling_name == "CLS" else PoolingType.LAST + normalize = bool(getattr(self.config, "normalize", True)) + self.pooler = Pooler(pooling_type=pooling_type, normalize=normalize) + + +class MoEMixin: + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def get_model_config_for_expert_location( + cls, config + ) -> Optional[ModelConfigForExpertLocation]: + text_config = getattr(config, "text_config", config) + num_experts = _getattr_first( + text_config, + ("num_local_experts", "num_experts", "n_routed_experts"), + ) + if num_experts is None: + return None + num_groups = getattr(text_config, "n_group", None) + return ModelConfigForExpertLocation( + num_layers=text_config.num_hidden_layers, + num_logical_experts=num_experts, + num_groups=num_groups, + ) + + @property + def routed_experts_weights_of_layer(self) -> dict[int, list[torch.Tensor]]: + return { + fused.experts.layer_id: fused.get_moe_weights() for fused in self.moe_layers + } + + def _get_expert_mapping(self, num_experts: int) -> List[Tuple[str, str, int, str]]: + ckpt_names = [ + ("gate_proj", "down_proj", "up_proj"), + ("w1", "w2", "w3"), + ("linear", "linear_1", "linear_v"), + ] + mapping: list = [] + for gate, down, up in ckpt_names: + mapping.extend( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name=gate, + ckpt_down_proj_name=down, + ckpt_up_proj_name=up, + num_experts=num_experts, + ) + 
) + # AutoWeightsLoader dispatches to TransformersFusedMoE (which IS the + # ``experts`` module) so the incoming weight names have the "experts." + # prefix already stripped. Remove it from weight_name in the mapping. + mapping = [ + (pn, wn.removeprefix("experts."), eid, sid) for pn, wn, eid, sid in mapping + ] + return mapping + + def recursive_replace(self): + """Replace experts modules with TransformersFusedMoE, then call + super().recursive_replace() for Linear/RMSNorm replacement.""" + text_config = self.text_config + + num_experts = _getattr_first( + text_config, + ("num_local_experts", "num_experts", "n_routed_experts"), + ) + assert num_experts is not None, "Cannot determine num_experts from config." + + top_k = _getattr_first(text_config, ("num_experts_per_tok", "top_k")) + assert top_k is not None, "Cannot determine top_k from config." + + hidden_size = text_config.hidden_size + intermediate_size = _getattr_first( + text_config, + ("moe_intermediate_size", "intermediate_size"), + ) + assert intermediate_size is not None, "Cannot determine intermediate_size." 
+ + num_shared_experts = _getattr_first( + text_config, + ("n_shared_experts", "moe_num_shared_experts"), + 0, + ) + reduce_results = num_shared_experts == 0 + + renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) + + # Activation function + activation = "silu" + wrapped_arch = self.config.architectures[0].lower() + if "gptoss" in wrapped_arch: + activation = "swigluoai" + elif "grok1" in wrapped_arch: + activation = "gelu" + + # Expert mapping for AutoWeightsLoader + expert_mapping = self._get_expert_mapping(num_experts) + + # EPLB / EP tracking + num_redundant = get_global_server_args().ep_num_redundant_experts + ep_size = get_moe_expert_parallel_world_size() + + self.mlp_moe_layers: list[nn.Module] = [] + self.moe_layers: list[TransformersFusedMoE] = [] + self.num_moe_layers = 0 + self.num_logical_experts = num_experts + self.num_physical_experts = num_experts + num_redundant + self.num_local_physical_experts = self.num_physical_experts // max(ep_size, 1) + self.num_shared_experts = num_shared_experts + self.num_redundant_experts = num_redundant + + def _add_all_reduce(mlp: nn.Module): + class MLPWithAllReduce(mlp.__class__): + def forward(self, *args, **kwargs): + output = super().forward(*args, **kwargs) + return self.experts.maybe_all_reduce_tensor_model_parallel(output) + + mlp.__class__ = MLPWithAllReduce + + def _recursive_replace(module: nn.Module, prefix: str): + for child_name, child_module in module.named_children(): + qual_name = maybe_prefix(prefix, child_name) + + is_modulelist = isinstance(child_module, nn.ModuleList) + params = list(child_module.parameters()) + is_3d = len(params) > 0 and all(p.ndim == 3 for p in params) + + if child_name == "experts" and (is_modulelist or is_3d): + mlp = module + experts = child_module + + has_bias = any("bias" in n for n, _ in experts.named_parameters()) + + nonlocal reduce_results + if reduce_results: + if any("shared_expert" in n for n, _ in mlp.named_parameters()): + reduce_results = False + 
self.num_shared_experts = 1 + + layer_id = self.num_moe_layers + + fused_experts = TransformersFusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + layer_id=layer_id, + reduce_results=reduce_results, + quant_config=self.quant_config, + prefix=qual_name, + activation=activation, + with_bias=has_bias, + expert_mapping=expert_mapping, + ) + mlp.experts = fused_experts + log_replacement(qual_name, experts, fused_experts) + + self.mlp_moe_layers.append(mlp) + self.moe_layers.append(fused_experts) + self.num_moe_layers += 1 + + if not reduce_results and ( + fused_experts.tp_size > 1 or fused_experts.ep_size > 1 + ): + _add_all_reduce(mlp) + else: + _recursive_replace(child_module, prefix=qual_name) + + _recursive_replace(self.model, prefix="model") + super().recursive_replace() + + +class MultiModalMixin: + torch_compile_dynamic_arg_dims: dict[str, int] = _MULTIMODAL_DYNAMIC_ARG_DIMS + + # Older VL checkpoints (e.g. Qwen2.5-VL) store text weights as + # "model.layers.*" but transformers >=5.0 nests the text model under + # "model.language_model.*". Map explicitly so these load correctly. 
+ hf_to_sglang_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model.": "model.language_model.", + "text_model.model.": "model.text_model.", + "text_model.lm_head.": "lm_head.", + "language_model.lm_head.": "lm_head.", + "vision_tower.": "model.vision_tower.", + "vision_model.": "model.vision_model.", + "vision_embed_tokens.": "model.vision_embed_tokens.", + "image_newline.": "model.image_newline.", + "vqmodel.": "model.vqmodel.", + "multi_modal_projector.": "model.multi_modal_projector.", + "visual.": "model.visual.", + "model.layers.": "model.language_model.layers.", + "model.embed_tokens.": "model.language_model.embed_tokens.", + "model.norm.": "model.language_model.norm.", + "model.rotary_emb.": "model.language_model.rotary_emb.", + } + ) + + _mm_feature_kwarg = { + "image": "pixel_values", + "video": "pixel_values_videos", + "audio": "input_features", + } + _mm_encoder_candidates = { + "image": ("get_image_features", "get_image_feature"), + "video": ("get_video_features", "get_video_feature"), + "audio": ("get_audio_features", "get_audio_feature"), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._mm_padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens() + + def _uses_mrope_positions(self) -> bool: + rope_scaling = getattr(self.text_config, "rope_scaling", None) + if isinstance(rope_scaling, Mapping) and "mrope_section" in rope_scaling: + return True + rope_type = str(getattr(self.text_config, "rope_type", "")).lower() + return "mrope" in rope_type + + def pad_input_ids(self, input_ids: list[int], mm_inputs: MultimodalInputs): + return input_ids + + def _get_modality_encoder(self, modality_name: str): + for name in self._mm_encoder_candidates[modality_name]: + fn = getattr(self.model, name, None) + if fn is not None: + return fn + raise AttributeError(f"No encoder method found for modality '{modality_name}'") + + def _get_modality_dtype_device( + self, modality_name: str + ) -> 
tuple[Optional[torch.dtype], Optional[torch.device]]: + module_candidates = { + "image": ("vision_tower", "vision_model"), + "video": ("video_tower", "vision_tower", "vision_model"), + "audio": ("audio_tower", "audio_model", "audio_encoder"), + } + modules = [] + for name in module_candidates.get(modality_name, ()): + module = getattr(self.model, name, None) + if module is not None: + modules.append(module) + modules.append(self.model) + + for module in modules: + for param in module.parameters(): + if torch.is_floating_point(param): + return param.dtype, param.device + for buf in module.buffers(): + if torch.is_floating_point(buf): + return buf.dtype, buf.device + return None, None + + def _cast_mm_value(self, value, dtype, device): + if torch.is_tensor(value): + if value.is_floating_point() and dtype is not None: + return value.to(dtype=dtype, device=device) + return value + if isinstance(value, dict): + return {k: self._cast_mm_value(v, dtype, device) for k, v in value.items()} + if isinstance(value, list): + return [self._cast_mm_value(v, dtype, device) for v in value] + if isinstance(value, tuple): + return tuple(self._cast_mm_value(v, dtype, device) for v in value) + return value + + def _to_tensor_output(self, output) -> torch.Tensor: + if hasattr(output, "pooler_output") and output.pooler_output is not None: + output = output.pooler_output + if isinstance(output, tuple): + output = output[0] + if isinstance(output, (list, tuple)): + if len(output) == 0: + raise ValueError("Empty multimodal encoder output.") + if all(torch.is_tensor(x) for x in output): + output = torch.cat( + [x.reshape(-1, x.shape[-1]) if x.ndim > 2 else x for x in output], + dim=0, + ) + else: + output = output[0] + elif hasattr(output, "last_hidden_state"): + output = output.last_hidden_state + elif isinstance(output, dict): + if output.get("pooler_output", None) is not None: + output = output["pooler_output"] + else: + output = next(v for v in output.values() if torch.is_tensor(v)) + if 
isinstance(output, (list, tuple)): + if len(output) == 0: + raise ValueError("Empty multimodal encoder output.") + if all(torch.is_tensor(x) for x in output): + output = torch.cat( + [ + x.reshape(-1, x.shape[-1]) if x.ndim > 2 else x + for x in output + ], + dim=0, + ) + else: + output = output[0] + + if output.ndim > 2: + output = output.reshape(-1, output.shape[-1]) + return output + + def _encode_modality_items( + self, modality_name: str, items: list[MultimodalDataItem] + ) -> torch.Tensor: + encoder = self._get_modality_encoder(modality_name) + feature_kwarg = self._mm_feature_kwarg[modality_name] + target_dtype, target_device = self._get_modality_dtype_device(modality_name) + outputs = [] + for item in items: + kwargs = self._cast_mm_value( + dict(item.model_specific_data), + dtype=target_dtype, + device=target_device, + ) + feature = self._cast_mm_value( + item.feature, + dtype=target_dtype, + device=target_device, + ) + if _encoder_accepts_feature_kwarg(encoder, feature_kwarg): + kwargs[feature_kwarg] = feature + result = encoder(**kwargs) + else: + result = encoder(feature, **kwargs) + outputs.append(self._to_tensor_output(result)) + return torch.cat(outputs, dim=0) + + def get_image_feature(self, items: list[MultimodalDataItem]) -> torch.Tensor: + return self._encode_modality_items("image", items) + + def get_video_feature(self, items: list[MultimodalDataItem]) -> torch.Tensor: + return self._encode_modality_items("video", items) + + def get_audio_feature(self, items: list[MultimodalDataItem]) -> torch.Tensor: + return self._encode_modality_items("audio", items) + + def _collect_mm_kwargs(self, forward_batch: ForwardBatch) -> dict: + """Collect multimodal tensors from the forward batch and return them + as kwargs suitable for the HF model's forward method.""" + kwargs = {} + + if getattr(forward_batch, "token_type_ids", None) is not None: + tti = forward_batch.token_type_ids + if tti.ndim == 1: + tti = tti.unsqueeze(0) + token_type_key = ( + 
"mm_token_type_ids" + if "mm_token_type_ids" + in inspect.signature(self.model.forward).parameters + else "token_type_ids" + ) + kwargs[token_type_key] = tti + + if ( + not forward_batch.forward_mode.is_decode() + and forward_batch.contains_mm_inputs() + ): + mm_inputs = forward_batch.mm_inputs + target_device = next(self.model.parameters()).device + + for batch_idx in range(len(mm_inputs or [])): + mm_input = mm_inputs[batch_idx] + if mm_input is None: + continue + for item in mm_input.mm_items or []: + for key, value in (item.model_specific_data or {}).items(): + if isinstance(value, torch.Tensor): + value = value.to(device=target_device) + if key not in kwargs: + kwargs[key] = value + elif isinstance(value, torch.Tensor) and isinstance( + kwargs[key], torch.Tensor + ): + kwargs[key] = torch.cat([kwargs[key], value], dim=0) + if item.feature is not None: + feature_key = self._mm_feature_kwarg.get( + item.modality.name.lower(), "pixel_values" + ) + feature = item.feature + if isinstance(feature, torch.Tensor): + feature = feature.to(device=target_device) + if feature_key not in kwargs: + kwargs[feature_key] = feature + elif isinstance(feature, torch.Tensor) and isinstance( + kwargs[feature_key], torch.Tensor + ): + kwargs[feature_key] = torch.cat( + [kwargs[feature_key], feature], dim=0 + ) + + return kwargs + + def _forward_hidden_states( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_embeds is not None: + return super()._forward_hidden_states( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=input_embeds, + ) + + if ( + self._uses_mrope_positions() + and getattr(forward_batch, "mrope_positions", None) is not None + ): + positions = forward_batch.mrope_positions + + mm_kwargs = self._collect_mm_kwargs(forward_batch) + + return self._run_hf_backbone( + input_ids=input_ids, + 
input_embeds=None, + positions=positions, + forward_batch=forward_batch, + **mm_kwargs, + ) + + +class TransformersForCausalLM(CausalMixin, TransformersBase): + pass + + +class TransformersMoEForCausalLM(MoEMixin, CausalMixin, TransformersBase): + pass + + +class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, TransformersBase): + pass + + +class TransformersMultiModalMoEForCausalLM( + MultiModalMixin, MoEMixin, CausalMixin, TransformersBase +): + pass + + +class TransformersEmbeddingModel(EmbeddingMixin, TransformersBase): + pass + + +class TransformersMoEEmbeddingModel(MoEMixin, EmbeddingMixin, TransformersBase): + pass + + +class TransformersMultiModalEmbeddingModel( + MultiModalMixin, EmbeddingMixin, TransformersBase +): + pass + + +class TransformersMultiModalMoEEmbeddingModel( + MultiModalMixin, MoEMixin, EmbeddingMixin, TransformersBase +): + pass + + +class TransformersForSequenceClassification(EmbeddingMixin, TransformersBase): + pass + + +class TransformersMoEForSequenceClassification( + MoEMixin, EmbeddingMixin, TransformersBase +): + pass + + +class TransformersMultiModalForSequenceClassification( + MultiModalMixin, EmbeddingMixin, TransformersBase +): + pass + + +class TransformersMultiModalMoEForSequenceClassification( + MultiModalMixin, MoEMixin, EmbeddingMixin, TransformersBase +): + pass -EntryClass = [TransformersForCausalLM] +EntryClass = [ + TransformersForCausalLM, + TransformersMoEForCausalLM, + TransformersMultiModalForCausalLM, + TransformersMultiModalMoEForCausalLM, + TransformersEmbeddingModel, + TransformersMoEEmbeddingModel, + TransformersMultiModalEmbeddingModel, + TransformersMultiModalMoEEmbeddingModel, + TransformersForSequenceClassification, + TransformersMoEForSequenceClassification, + TransformersMultiModalForSequenceClassification, + TransformersMultiModalMoEForSequenceClassification, +] diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py index 0d93cdd82e3a..13a0e70f0e88 100644 --- 
a/python/sglang/srt/models/utils.py +++ b/python/sglang/srt/models/utils.py @@ -13,6 +13,7 @@ # ============================================================================== from __future__ import annotations +import itertools from collections.abc import Iterable, Mapping from dataclasses import dataclass, field from functools import lru_cache @@ -28,13 +29,15 @@ from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.utils import get_current_device_stream_fast, is_cuda +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import get_current_device_stream_fast, is_cuda, is_hip from sglang.srt.utils.custom_op import register_custom_op if TYPE_CHECKING: from sglang.srt.layers.layernorm import RMSNorm _is_cuda = is_cuda() +_is_hip = is_hip() WeightsMapping = Mapping[str, Optional[str]] """If a key maps to a value of `None`, the corresponding weight is ignored.""" @@ -48,6 +51,13 @@ class WeightsMapper: orig_to_new_prefix: WeightsMapping = field(default_factory=dict) orig_to_new_suffix: WeightsMapping = field(default_factory=dict) + def __or__(self, other: "WeightsMapper") -> "WeightsMapper": + return WeightsMapper( + orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr}, + orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix}, + orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix}, + ) + def _map_name(self, key: str) -> Optional[str]: for substr, new_key in sorted( self.orig_to_new_substr.items(), key=lambda i: len(i[0]), reverse=True @@ -105,6 +115,161 @@ def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]: } +class AutoWeightsLoader: + ROTARY_EMBEDS_UNUSED_WEIGHTS = [ + "rotary_pos_emb.inv_freq", + "rotary_emb.inv_freq", + "rotary_emb.cos_cached", + "rotary_emb.sin_cached", + ] + + def 
__init__( + self, + module: torch.nn.Module, + *, + skip_prefixes: list[str] | None = None, + skip_substrs: list[str] | None = None, + ignore_unexpected_prefixes: list[str] | None = None, + ignore_unexpected_suffixes: list[str] | None = None, + ) -> None: + self.module = module + self.skip_prefixes = list(skip_prefixes or []) + self.skip_substrs = [ + *(skip_substrs or []), + *self.ROTARY_EMBEDS_UNUSED_WEIGHTS, + ] + self.ignore_unexpected_prefixes = list(ignore_unexpected_prefixes or []) + self.ignore_unexpected_suffixes = list(ignore_unexpected_suffixes or []) + + def _groupby_prefix( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]: + weights_by_parts = ( + (weight_name.split(".", 1), weight_data) + for weight_name, weight_data in weights + ) + for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]): + yield prefix, ( + ("" if len(parts) == 1 else parts[1], weight_data) + for parts, weight_data in group + ) + + @staticmethod + def _get_qualname(prefix: str, rest: str) -> str: + if prefix == "": + return rest + if rest == "": + return prefix + return f"{prefix}.{rest}" + + def _can_skip(self, qualname: str) -> bool: + return any(qualname.startswith(p) for p in self.skip_prefixes) or any( + sub in qualname for sub in self.skip_substrs + ) + + def _can_ignore_unexpected(self, qualname: str) -> bool: + return any( + qualname.startswith(p) for p in self.ignore_unexpected_prefixes + ) or any(qualname.endswith(s) for s in self.ignore_unexpected_suffixes) + + def _load_param( + self, + base_prefix: str, + param: torch.nn.Parameter, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> Iterable[str]: + for weight_name, weight_data in weights: + weight_qualname = self._get_qualname(base_prefix, weight_name) + if self._can_skip(weight_qualname): + continue + if weight_name != "": + if self._can_ignore_unexpected(weight_qualname): + continue + raise ValueError( + f"Attempted to 
load nested weight {weight_qualname!r} " + f"into parameter {base_prefix!r}" + ) + + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, weight_data) + yield weight_qualname + + def _load_module( + self, + base_prefix: str, + module: torch.nn.Module, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> Iterable[str]: + if module.__class__.__name__ == "PPMissingLayer": + return + + if module is not self.module: + module_load_weights = getattr(module, "load_weights", None) + if callable(module_load_weights): + loaded = module_load_weights(weights) + if loaded is not None: + yield from ( + self._get_qualname(base_prefix, loaded_name) + for loaded_name in loaded + ) + return + + child_modules = dict(module.named_children()) + child_params = dict(module.named_parameters(recurse=False)) + child_buffers = dict(module.named_buffers(recurse=False)) + for child_prefix, child_weights in self._groupby_prefix(weights): + prefix = self._get_qualname(base_prefix, child_prefix) + if child_prefix in child_modules: + if self._can_skip(prefix + "."): + continue + yield from self._load_module( + prefix, + child_modules[child_prefix], + child_weights, + ) + continue + + if child_prefix in child_params: + if self._can_skip(prefix): + continue + yield from self._load_param( + prefix, child_params[child_prefix], child_weights + ) + continue + + if child_prefix in child_buffers: + if self._can_skip(prefix): + continue + yield from self._load_param( + prefix, child_buffers[child_prefix], child_weights + ) + continue + + if self._can_skip(prefix) or self._can_skip(prefix + "."): + continue + if self._can_ignore_unexpected(prefix) or self._can_ignore_unexpected( + prefix + "." + ): + continue + raise ValueError( + f"No module or parameter named {prefix!r} in {self.module._get_name()}." 
+ ) + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + *, + mapper: WeightsMapper | None = None, + ) -> set[str]: + if mapper is not None: + weights = mapper.apply(weights) + weights = ( + (name, weight) for name, weight in weights if not self._can_skip(name) + ) + return set(self._load_module("", self.module, weights)) + + def enable_fused_set_kv_buffer(forward_batch: ForwardBatch): """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache.""" return ( @@ -113,7 +278,7 @@ def enable_fused_set_kv_buffer(forward_batch: ForwardBatch): and forward_batch.token_to_kv_pool.dtype == torch.bfloat16 and not isinstance(forward_batch.token_to_kv_pool, SWAKVPool) and not is_prefill_context_parallel_enabled() - ) + ) or (_is_hip and not is_prefill_context_parallel_enabled()) def create_fused_set_kv_buffer_arg( @@ -128,13 +293,35 @@ def create_fused_set_kv_buffer_arg( k_buffer = token_to_kv_pool.get_key_buffer(layer_id) v_buffer = token_to_kv_pool.get_value_buffer(layer_id) - assert layer.k_scale is None and layer.v_scale is None, "scale not supported" - return FusedSetKVBufferArg( - value=value, - k_buffer=k_buffer.view(k_buffer.shape[0], -1), - v_buffer=v_buffer.view(v_buffer.shape[0], -1), - cache_loc=forward_batch.out_cache_loc, - ) + + if not _is_hip: + assert layer.k_scale is None and layer.v_scale is None, "scale not supported" + return FusedSetKVBufferArg( + value=value, + k_buffer=k_buffer.view(k_buffer.shape[0], -1), + v_buffer=v_buffer.view(v_buffer.shape[0], -1), + cache_loc=forward_batch.out_cache_loc, + ) + else: + page_size = token_to_kv_pool.page_size + slot_mapping_swa = ( + token_to_kv_pool.full_to_swa_index_mapping.long() + if layer.sliding_window_size > 0 + else None + ) + return { + "v": value.view(-1, layer.tp_v_head_num, layer.v_head_dim), + "k_scale": layer.k_scale, + "v_scale": layer.v_scale, + "key_cache": k_buffer.view( + -1, page_size, layer.tp_k_head_num, layer.qk_head_dim + ), + "value_cache": v_buffer.view( + -1, 
page_size, layer.tp_v_head_num, layer.v_head_dim + ), + "slot_mapping": forward_batch.out_cache_loc, + "swa_slot_mapping": slot_mapping_swa, + } def permute_inv(perm: torch.Tensor) -> torch.Tensor: diff --git a/python/sglang/srt/models/whisper.py b/python/sglang/srt/models/whisper.py index d69fb666d2d8..d9190a2f12ed 100644 --- a/python/sglang/srt/models/whisper.py +++ b/python/sglang/srt/models/whisper.py @@ -94,70 +94,16 @@ def forward( """Input shape: Batch x Time x Channel""" if self.is_cross_attention: + # Cross-attention: KV cached during prefill, read from pool during decode. q, _ = self.q_proj(hidden_states) + q = q * self.scaling if cross_hidden_states is not None: kv, _ = self.kv_proj(cross_hidden_states) k, v = kv.split([self.kv_size, self.kv_size], dim=-1) else: - k = torch.zeros_like(q) - v = torch.zeros_like(q) - - q = q * self.scaling - num_heads = self.attn.tp_q_head_num - head_dim = self.attn.head_dim - - q = q.view(-1, num_heads, head_dim) - k = k.view(-1, num_heads, head_dim) - v = v.view(-1, num_heads, head_dim) - - q_len = q.shape[0] - kv_len = k.shape[0] - - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - attn_weights = torch.bmm(q, k.transpose(1, 2)) - - # Apply block-diagonal mask for batched cross-attention - batch_size = forward_batch.batch_size if forward_batch else 1 - if batch_size > 1 and kv_len > 0: - encoder_len_per_request = kv_len // batch_size - if encoder_len_per_request * batch_size == kv_len: - is_decode = forward_batch.forward_mode.is_decode() - if is_decode: - mask = torch.zeros( - (q_len, kv_len), device=q.device, dtype=torch.bool - ) - for i in range(batch_size): - enc_start = i * encoder_len_per_request - enc_end = (i + 1) * encoder_len_per_request - mask[i, enc_start:enc_end] = True - attn_weights = attn_weights.masked_fill( - ~mask.unsqueeze(0), float("-inf") - ) - else: - seq_lens = forward_batch.seq_lens - if seq_lens is not None and len(seq_lens) == batch_size: - seq_lens_list = 
seq_lens.tolist() - mask = torch.zeros( - (q_len, kv_len), device=q.device, dtype=torch.bool - ) - q_start = 0 - for i, dec_len in enumerate(seq_lens_list): - enc_start = i * encoder_len_per_request - enc_end = (i + 1) * encoder_len_per_request - q_end = q_start + dec_len - mask[q_start:q_end, enc_start:enc_end] = True - q_start = q_end - attn_weights = attn_weights.masked_fill( - ~mask.unsqueeze(0), float("-inf") - ) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.bmm(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(q_len, num_heads * head_dim) + k = None + v = None + attn_output = self.attn(q, k, v, forward_batch) else: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) @@ -394,6 +340,7 @@ def forward( position_ids=None, ): inputs_embeds = self.embed_tokens(input_ids) + position_ids = position_ids.clamp(max=self.max_target_positions - 1) positions = self.embed_positions(position_ids) hidden_states = inputs_embeds + positions.to(inputs_embeds.device) @@ -420,7 +367,6 @@ def __init__( ) self.logits_processor = LogitsProcessor(config) self.config = config - self._encoder_cache = {} def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -468,8 +414,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - def pad_input_ids(self, input_ids: List[int], _mm_inputs: MultimodalInputs): - return input_ids + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + # Prepend dummy encoder tokens so that prepare_encoder_info_extend + # correctly allocates encoder KV cache locations in the KV pool. + # These dummy tokens are stripped before the model forward receives input_ids. 
+ encoder_len = self.config.max_source_positions + mm_inputs.num_image_tokens = encoder_len + pad_ids = [0] * encoder_len + return pad_ids + input_ids def forward( self, @@ -479,29 +431,22 @@ def forward( **kwargs: Any, ) -> LogitsProcessorOutput: dtype = self.encoder.conv1.weight.dtype - is_decode = forward_batch.forward_mode.is_decode() - - if is_decode: - encoder_outputs = None - if forward_batch.req_pool_indices is not None: - req_indices = forward_batch.req_pool_indices.tolist() - encoder_list = [] - for req_idx in req_indices: - if req_idx in self._encoder_cache: - encoder_list.append(self._encoder_cache[req_idx]) - if encoder_list: - encoder_outputs = torch.cat(encoder_list, dim=0) - else: - encoder_list = [] + + # Run encoder for requests that haven't cached encoder output yet. + # During decode or when encoder is already cached, encoder_hidden_states + # is None and cross-attention reads KV from the pool via RadixAttention. + encoder_hidden_states = None + if not forward_batch.forward_mode.is_decode(): mm_inputs_list = forward_batch.mm_inputs if forward_batch.mm_inputs else [] - req_indices = ( - forward_batch.req_pool_indices.tolist() - if forward_batch.req_pool_indices is not None - else [] + encoder_cached_list = ( + forward_batch.encoder_cached if forward_batch.encoder_cached else [] ) - for req_idx, mm_input in zip(req_indices, mm_inputs_list): - if mm_input is None or not mm_input.mm_items: + encoder_list = [] + for i, (mm_input, cached) in enumerate( + zip(mm_inputs_list, encoder_cached_list) + ): + if cached or mm_input is None or not mm_input.mm_items: continue features = mm_input.mm_items[0].feature @@ -513,21 +458,17 @@ def forward( features.device, non_blocking=True ) - req_encoder_outputs = self.encoder( + req_encoder_output = self.encoder( features.to(dtype), encoder_position_ids, forward_batch ) - req_encoder_outputs = req_encoder_outputs.squeeze(0) - - self._encoder_cache[req_idx] = req_encoder_outputs - 
encoder_list.append(req_encoder_outputs) + req_encoder_output = req_encoder_output.squeeze(0) + encoder_list.append(req_encoder_output) if encoder_list: - encoder_outputs = torch.cat(encoder_list, dim=0) - else: - encoder_outputs = None + encoder_hidden_states = torch.cat(encoder_list, dim=0) decoder_outputs = self.decoder( - input_ids, encoder_outputs, forward_batch, positions + input_ids, encoder_hidden_states, forward_batch, positions ) logits = self.logits_processor( diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 9ce169024570..773ce8620d97 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -40,6 +40,7 @@ _is_xpu = is_xpu() SGL_USE_CUDA_IPC = envs.SGLANG_USE_CUDA_IPC_TRANSPORT.get() +_IPC_POOL_HANDLE_CACHE = envs.SGLANG_USE_IPC_POOL_HANDLE_CACHE.get() @dataclasses.dataclass @@ -136,7 +137,6 @@ def get_modality_of_token(self, token: str) -> Optional[Modality]: def get_token_id_by_modality(self, modality: Modality) -> Optional[int]: return { Modality.IMAGE: self.image_token_id, - Modality.MULTI_IMAGES: self.image_token_id, Modality.VIDEO: self.video_token_id, Modality.AUDIO: self.audio_token_id, }.get(modality) @@ -173,6 +173,7 @@ def get_combined_regex(self) -> re.Pattern: class BaseMultimodalProcessor(ABC): models = [] + gpu_image_decode = True # Enable GPU decoding by default def __init__( self, hf_config, server_args, _processor, transport_mode, *args, **kwargs @@ -357,7 +358,7 @@ def get_mm_data(self, prompt, embeddings, **kwargs): mm_items.append( MultimodalDataItem( modality=modality, - offsets=offset, + offsets=[offset], precomputed_embeddings=embedding_slice, ) ) @@ -470,8 +471,9 @@ def get_estimated_frames_list(self, image_data): return estimated_frames_list - @staticmethod + @classmethod def _load_single_item( + cls, data, modality: Modality, frame_count_limit=None, @@ -483,7 
+485,8 @@ def _load_single_item( If data is processor_output or precomputed embedding, return directly. - Static method that can be pickled for multiprocessing""" + Class method that can be pickled for multiprocessing + """ if isinstance(data, dict): data_format = data.get("format") if data_format in ( @@ -495,8 +498,13 @@ def _load_single_item( return data try: if modality == Modality.IMAGE: - img, _ = load_image(data) - if discard_alpha_channel and img.mode != "RGB": + img, _ = load_image(data, cls.gpu_image_decode) + if ( + discard_alpha_channel + and not isinstance(img, torch.Tensor) + and img.mode != "RGB" + ): + # Needed only when `img` is a PIL image img = img.convert("RGB") return img elif modality == Modality.VIDEO: @@ -537,7 +545,7 @@ def _submit_mm_data_loading_tasks_simple( type(data), ) future = self.io_executor.submit( - BaseMultimodalProcessor._load_single_item, + self.__class__._load_single_item, data, modality, None, # frame_count_limit: no consider for fast path @@ -597,7 +605,7 @@ def submit_data_loading_tasks( futures.append( self.io_executor.submit( - BaseMultimodalProcessor._load_single_item, + self.__class__._load_single_item, data, modality, frame_count_limit, @@ -989,7 +997,8 @@ def collect_mm_items_from_processor_output( self, data_dict: dict, modality: Modality = None ) -> List[MultimodalDataItem]: """ - Create mm_items directly from processor output, with one item for each modality + Create mm_items from processor output. Initially creates one item per modality; + these are later split into per-image/video items by get_new_expanded_mm_items. 
Note that the data_dict can be passed via offline engine api """ @@ -1132,6 +1141,11 @@ def process_and_combine_mm_data( mm_token_id=mm_token_id, ) + # Split bundled items into per-image/video items for better cache granularity + from sglang.srt.managers.mm_utils import get_new_expanded_mm_items + + all_collected_items = get_new_expanded_mm_items(all_collected_items) + """ solution for cuda-ipc memory-leak: 1. memory-pool: each time get a slice from memory-pool and use it as transport-data (with async lock guard) @@ -1144,7 +1158,7 @@ def process_and_combine_mm_data( # post-process for item in all_collected_items: if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda: - sync_flag, available_slice = ( + sync_flag, available_slice, byte_offset = ( self.cudaipc_mmfeature_pool.return_a_slice_tensor_with_flag( item.feature ) @@ -1157,6 +1171,13 @@ def process_and_combine_mm_data( data=available_slice, info_data=item.feature, sync_buffer_meta=sync_flag, + pool_ipc_handle=( + self.cudaipc_mmfeature_pool._pool_ipc_handle + if _IPC_POOL_HANDLE_CACHE + else None + ), + pool_byte_offset=byte_offset, + pool_device_index=self.cudaipc_mmfeature_pool._pool_device_index, ) elif not self.server_args.keep_mm_feature_on_device: item.feature = item.feature.cpu() @@ -1165,7 +1186,7 @@ def process_and_combine_mm_data( and item.precomputed_embeddings.is_cuda ): - sync_flag, available_slice = ( + sync_flag, available_slice, byte_offset = ( self.cudaipc_mmfeature_pool.return_a_slice_tensor_with_flag( item.precomputed_embeddings ) @@ -1179,6 +1200,13 @@ def process_and_combine_mm_data( data=available_slice, info_data=item.precomputed_embeddings, sync_buffer_meta=sync_flag, + pool_ipc_handle=( + self.cudaipc_mmfeature_pool._pool_ipc_handle + if _IPC_POOL_HANDLE_CACHE + else None + ), + pool_byte_offset=byte_offset, + pool_device_index=self.cudaipc_mmfeature_pool._pool_device_index, ) elif not self.server_args.keep_mm_feature_on_device: item.precomputed_embeddings = 
item.precomputed_embeddings.cpu() diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py index 955198730e92..c95d495adc0a 100644 --- a/python/sglang/srt/multimodal/processors/internvl.py +++ b/python/sglang/srt/multimodal/processors/internvl.py @@ -27,6 +27,7 @@ class InternVLProcessor(BaseMultimodalProcessor): models = [InternVLChatModel, InternS1ForConditionalGeneration] + gpu_image_decode = False # InternVL HF processor does not support tensor inputs IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] @@ -587,11 +588,21 @@ async def process_qwen_mm_data_async( items = [] if image_tensor is not None: - items.append( - MultimodalDataItem( - feature=image_tensor, modality=Modality.IMAGE, offsets=image_offsets - ) + # Split per-image for better cache granularity + assert len(num_patches_list) == len(image_offsets), ( + f"InternVL: num_patches_list ({len(num_patches_list)}) != " + f"image_offsets ({len(image_offsets)})" ) + cumulative = 0 + for i, num_patches in enumerate(num_patches_list): + items.append( + MultimodalDataItem( + feature=image_tensor[cumulative : cumulative + num_patches], + modality=Modality.IMAGE, + offsets=[image_offsets[i]], + ) + ) + cumulative += num_patches if video_tensor is not None: items.append( MultimodalDataItem( @@ -701,11 +712,21 @@ async def process_internlm2_mm_data_async( items = [] if pixel_values is not None: - items.append( - MultimodalDataItem( - feature=pixel_values, modality=Modality.IMAGE, offsets=image_offsets - ) + # Split per-image for better cache granularity + assert len(num_patches_list) == len(image_offsets), ( + f"InternVL: num_patches_list ({len(num_patches_list)}) != " + f"image_offsets ({len(image_offsets)})" ) + cumulative = 0 + for i, num_patches in enumerate(num_patches_list): + items.append( + MultimodalDataItem( + feature=pixel_values[cumulative : cumulative + num_patches], + modality=Modality.IMAGE, + 
offsets=[image_offsets[i]], + ) + ) + cumulative += num_patches return { "input_ids": input_ids, diff --git a/python/sglang/srt/multimodal/processors/kimi_k25.py b/python/sglang/srt/multimodal/processors/kimi_k25.py index d8bb9ceb3a8b..cef3e6933499 100644 --- a/python/sglang/srt/multimodal/processors/kimi_k25.py +++ b/python/sglang/srt/multimodal/processors/kimi_k25.py @@ -16,6 +16,7 @@ # Compatible with KimiVLForConditionalGeneration class KimiK2_5VLImageProcessor(SGLangBaseProcessor): models = [KimiK25ForConditionalGeneration] + gpu_image_decode = False # KimiK2.5VL HF processor does not support tensor inputs def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) diff --git a/python/sglang/srt/multimodal/processors/kimi_vl.py b/python/sglang/srt/multimodal/processors/kimi_vl.py index cd7cfe2fd3ae..b466f1b40994 100644 --- a/python/sglang/srt/multimodal/processors/kimi_vl.py +++ b/python/sglang/srt/multimodal/processors/kimi_vl.py @@ -13,6 +13,7 @@ # Compatible with KimiVLForConditionalGeneration class KimiVLImageProcessor(SGLangBaseProcessor): models = [KimiVLForConditionalGeneration] + gpu_image_decode = False # KimiVL HF processor does not support tensor inputs def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) diff --git a/python/sglang/srt/multimodal/processors/llava.py b/python/sglang/srt/multimodal/processors/llava.py index 8111f03afbad..8729e85470b8 100644 --- a/python/sglang/srt/multimodal/processors/llava.py +++ b/python/sglang/srt/multimodal/processors/llava.py @@ -1,4 +1,5 @@ import asyncio +import os from typing import Dict, List, Optional, Union import numpy as np @@ -33,6 +34,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor): LlavaQwenForCausalLM, LlavaMistralForCausalLM, ] + gpu_image_decode = False # Llava processes loaded image as PIL image explicitly def 
__init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) @@ -49,7 +51,7 @@ def _process_single_image_task( try: url = image_data.url if isinstance(image_data, ImageData) else image_data - image, image_size = load_image(url) + image, image_size = load_image(url, False) if image_size is not None: # It is a video with multiple images image_hash = hash(url) @@ -95,7 +97,7 @@ async def _process_single_image( ): if self.cpu_executor is not None: loop = asyncio.get_running_loop() - return await loop.run_in_executor( + fut = loop.run_in_executor( self.cpu_executor, LlavaImageProcessor._process_single_image_task, image_data, @@ -103,6 +105,8 @@ async def _process_single_image( grid_pinpoints, self._processor, ) + timeout = int(os.environ.get("REQUEST_TIMEOUT", "10")) + return await asyncio.wait_for(fut, timeout=timeout) else: return self._process_single_image_task( image_data, @@ -183,34 +187,39 @@ async def process_mm_data_async( pixel_values.append(pixel_v) data_hashes.append(image_h) image_sizes.append(image_s) - - if isinstance(pixel_values[0], np.ndarray): - pixel_values = np.stack(pixel_values, axis=0) else: # A single image pixel_values, image_hash, image_size = await self._process_single_image( image_data[0], aspect_ratio, grid_pinpoints ) + pixel_values = [pixel_values] image_sizes = [image_size] else: raise ValueError(f"Invalid image data: {image_data}") modality = Modality.IMAGE if isinstance(request_obj.modalities, list): - if request_obj.modalities[0] == "multi-images": - modality = Modality.MULTI_IMAGES - elif request_obj.modalities[0] == "video": + if request_obj.modalities[0] == "video": modality = Modality.VIDEO - return { - "mm_items": [ + # Create one item per image for better cache granularity + mm_items = [] + for pixel_v, image_s in zip(pixel_values, image_sizes): + # Ensure ndim=4 so the model forward takes the correct encode branch + if isinstance(pixel_v, np.ndarray) 
and pixel_v.ndim == 3: + pixel_v = np.expand_dims(pixel_v, 0) + mm_items.append( MultimodalDataItem( - feature=pixel_values, + feature=pixel_v, model_specific_data={ - "image_sizes": image_sizes, + "image_sizes": [image_s], + "image_aspect_ratio": aspect_ratio, }, modality=modality, ) - ], + ) + + return { + "mm_items": mm_items, } diff --git a/python/sglang/srt/multimodal/processors/minicpm.py b/python/sglang/srt/multimodal/processors/minicpm.py index 2a375c9dabb4..613079e04acf 100644 --- a/python/sglang/srt/multimodal/processors/minicpm.py +++ b/python/sglang/srt/multimodal/processors/minicpm.py @@ -19,6 +19,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): models = [MiniCPMV, MiniCPMO] support_dynamic_frame_expansion = True + gpu_image_decode = False # MiniCPM HF processor does not support tensor inputs def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) @@ -222,6 +223,8 @@ async def process_mm_data_async( f"{len(pixel_values)} vs. 
{len(tgt_sizes)}" ) + # Track slices per image (like vLLM's num_slices) + slices_per_image: List[int] = [] pixel_values_flat: List[torch.Tensor] = [] tgt_sizes_flat: List[torch.Tensor] = [] for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): @@ -230,6 +233,7 @@ async def process_mm_data_async( raise ValueError( "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}" ) + slices_per_image.append(len(pixel_b)) for pixel_n, tgt_n in zip(pixel_b, tgt_b): pixel_values_flat += [pixel_n] tgt_sizes_flat += [tgt_n] @@ -249,14 +253,23 @@ async def process_mm_data_async( image_offsets.extend(slice_offsets) image_offsets = sorted(image_offsets) + # Create one item per image, each with its own slices and offsets if len(pixel_values) != 0: - item = MultimodalDataItem( - feature=pixel_values, - offsets=image_offsets, - model_specific_data={"tgt_size": tgt_sizes_flat}, - modality=Modality.IMAGE, - ) - items += [item] + pv_idx = 0 + offset_idx = 0 + for num_slices in slices_per_image: + items.append( + MultimodalDataItem( + feature=pixel_values[pv_idx : pv_idx + num_slices], + offsets=image_offsets[offset_idx : offset_idx + num_slices], + model_specific_data={ + "tgt_size": tgt_sizes_flat[pv_idx : pv_idx + num_slices] + }, + modality=Modality.IMAGE, + ) + ) + pv_idx += num_slices + offset_idx += num_slices if ( "audio_features" in res diff --git a/python/sglang/srt/multimodal/processors/nano_nemotron_vl.py b/python/sglang/srt/multimodal/processors/nano_nemotron_vl.py index 83d72441f861..98986090f979 100644 --- a/python/sglang/srt/multimodal/processors/nano_nemotron_vl.py +++ b/python/sglang/srt/multimodal/processors/nano_nemotron_vl.py @@ -35,6 +35,9 @@ class NanoNemotronVLImageProcessor(BaseMultimodalProcessor): models = [NemotronH_Nano_VL_V2] + gpu_image_decode = ( + False # NanoNemotronVL processes loaded image as PIL image explicitly + ) def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs): super().__init__(hf_config, server_args, 
_image_processor, *args, **kwargs) diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index 47b1513e8fd6..ed40fc01785f 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -19,6 +19,7 @@ class PixtralProcessor(BaseMultimodalProcessor): models = [PixtralVisionModel, PixtralForConditionalGeneration] + gpu_image_decode = False # Pixtral processes loaded image as PIL image explicitly PAD_TOKEN = "" DEFAULT_IMAGE_TOKEN = "[IMG]" diff --git a/python/sglang/srt/multimodal/processors/qwen_audio.py b/python/sglang/srt/multimodal/processors/qwen_audio.py index 817b880502ca..90c2ffd456f4 100644 --- a/python/sglang/srt/multimodal/processors/qwen_audio.py +++ b/python/sglang/srt/multimodal/processors/qwen_audio.py @@ -61,7 +61,7 @@ def get_mm_data(self, prompt, embeddings, **kwargs): mm_items.append( MultimodalDataItem( modality=modality, - offsets=offset, + offsets=[offset], precomputed_embeddings=embedding_slice, ) ) diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 2e1159dcd958..95c7cd21a38d 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -234,6 +234,7 @@ async def preprocess_video( # Compatible with Qwen-VL & Qwen-Omni Series class QwenVLImageProcessor(SGLangBaseProcessor): + supports_transformers_backend = True models = [ Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, @@ -468,7 +469,7 @@ def get_mm_data(self, prompt, embeddings, **kwargs): mm_items.append( MultimodalDataItem( modality=modality, - offsets=offset, + offsets=[offset], precomputed_embeddings=embedding_slice, ) ) diff --git a/python/sglang/srt/multimodal/processors/transformers_auto.py b/python/sglang/srt/multimodal/processors/transformers_auto.py new file mode 100644 index 000000000000..b99f06616fe1 --- 
/dev/null +++ b/python/sglang/srt/multimodal/processors/transformers_auto.py @@ -0,0 +1,215 @@ +from typing import Optional + +import torch + +from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) +from sglang.srt.utils import load_image + + +def _first_attr(obj, names: tuple[str, ...], default=None): + for name in names: + value = getattr(obj, name, None) + if value is not None: + return value + return default + + +def _uses_mrope(hf_config) -> bool: + text_config = getattr(hf_config, "text_config", hf_config) + rope_scaling = getattr(text_config, "rope_scaling", None) or {} + if isinstance(rope_scaling, dict) and "mrope_section" in rope_scaling: + return True + rope_type = str(getattr(text_config, "rope_type", "")).lower() + return "mrope" in rope_type + + +class TransformersAutoMultimodalProcessor(BaseMultimodalProcessor): + """Generic multimodal processor for the Transformers backend. + + Unlike model-specific processors that rely on regex-based token matching + in the raw prompt, this processor applies the HF processor directly to + the prompt text + raw media. This handles models like Gemma3 where the + chat template uses a marker (````) that the HF processor + internally expands into placeholder tokens. 
+ """ + + models = [] + + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) + self.mm_tokens = MultimodalSpecialTokens( + image_token=getattr(_processor, "image_token", None), + video_token=getattr(_processor, "video_token", None), + audio_token=getattr(_processor, "audio_token", None), + image_token_id=_first_attr( + hf_config, + ("image_token_id", "image_token_index", "im_token_id"), + ), + video_token_id=_first_attr( + hf_config, + ("video_token_id",), + ), + audio_token_id=_first_attr( + hf_config, + ("audio_token_id",), + ), + ).build(_processor) + + self._is_mrope = _uses_mrope(hf_config) + if self._is_mrope: + vision_config = getattr(hf_config, "vision_config", None) + self._spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2) + self._tokens_per_second = getattr(vision_config, "tokens_per_second", None) + self._vision_start_token_id = _first_attr( + hf_config, ("vision_start_token_id",) + ) + self._model_type = getattr(hf_config, "model_type", "") + + def _compute_mrope_positions( + self, + input_ids: list[int], + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + ): + from sglang.srt.layers.rotary_embedding import MRotaryEmbedding + + input_ids_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self._spatial_merge_size, + image_token_id=self.mm_tokens.image_token_id, + video_token_id=self.mm_tokens.video_token_id or -1, + vision_start_token_id=self._vision_start_token_id, + model_type=self._model_type, + input_ids=input_ids_tensor, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + tokens_per_second=self._tokens_per_second, + ) + return mrope_positions.squeeze(1), mrope_position_delta + + def _load_images(self, image_data) -> list: + """Download / decode images from URLs, file 
paths, or base64.""" + if not image_data: + return [] + images = [] + for data in image_data: + img, _ = load_image(data) + if img.mode != "RGB": + img = img.convert("RGB") + images.append(img) + return images + + def _apply_hf_processor(self, text: str, images=None, videos=None): + """Run the HF processor on text + media and return the full output. + + This is the key method that makes the generic processor work for + models with non-trivial token expansion (Gemma3, PaliGemma, etc.). + The HF processor handles chat-template expansion, image token + insertion, and tokenization in one shot. + """ + kwargs = {} + if images: + kwargs["images"] = images + if videos: + kwargs["videos"] = videos + return self._processor(text=text, return_tensors="pt", **kwargs) + + def _build_mm_items( + self, processor_output: dict, input_ids: torch.Tensor + ) -> list[MultimodalDataItem]: + """Extract MultimodalDataItem objects from the HF processor output.""" + items = self.collect_mm_items_from_processor_output(processor_output) + + modality_to_token_id = { + Modality.IMAGE: self.mm_tokens.image_token_id, + Modality.MULTI_IMAGES: self.mm_tokens.image_token_id, + Modality.VIDEO: self.mm_tokens.video_token_id, + Modality.AUDIO: self.mm_tokens.audio_token_id, + } + + for item in items: + token_id = modality_to_token_id.get(item.modality) + if token_id is not None: + item.offsets = self.get_mm_items_offset(input_ids, token_id) + + return items + + async def process_mm_data_async( + self, + image_data, + audio_data, + input_text, + request_obj, + **kwargs, + ): + video_data = getattr(request_obj, "video_data", None) + if video_data is not None and not isinstance(video_data, list): + video_data = [video_data] + + # Load raw media + images = self._load_images(image_data) + # TODO: video / audio loading when needed + + # Apply HF processor — handles token expansion internally + processor_output = self._apply_hf_processor( + text=input_text, + images=images or None, + videos=video_data or 
None, + ) + + input_ids = processor_output["input_ids"].flatten() + + # Build mm_items from processor output + mm_items = self._build_mm_items(processor_output, input_ids) + + ret = { + "input_ids": input_ids.tolist(), + "mm_items": mm_items, + } + + # Propagate token_type_ids for models that need it (Gemma3, PaliGemma) + token_type_key = ( + "mm_token_type_ids" + if "mm_token_type_ids" in processor_output + else "token_type_ids" + ) + if token_type_key in processor_output: + ret["token_type_ids"] = processor_output[token_type_key].flatten().tolist() + + if self.mm_tokens.image_token_id is not None: + ret["im_token_id"] = self.mm_tokens.image_token_id + if self.mm_tokens.video_token_id is not None: + ret["video_token_id"] = self.mm_tokens.video_token_id + if self.mm_tokens.audio_token_id is not None: + ret["audio_token_id"] = self.mm_tokens.audio_token_id + + image_start_id = _first_attr( + self.hf_config, + ("image_start_token_id", "vision_start_token_id", "im_start_id"), + ) + image_end_id = _first_attr( + self.hf_config, + ("image_end_token_id", "vision_end_token_id", "im_end_id"), + ) + if image_start_id is not None: + ret["im_start_id"] = image_start_id + if image_end_id is not None: + ret["im_end_id"] = image_end_id + + # M-RoPE positions (Qwen2.5-VL, Qwen3-VL) + if self._is_mrope: + image_grid_thw = processor_output.get("image_grid_thw") + video_grid_thw = processor_output.get("video_grid_thw") + mrope_positions, mrope_position_delta = self._compute_mrope_positions( + ret["input_ids"], + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + ) + ret["mrope_positions"] = mrope_positions + ret["mrope_position_delta"] = mrope_position_delta + + return ret diff --git a/python/sglang/srt/multimodal/processors/whisper.py b/python/sglang/srt/multimodal/processors/whisper.py index 2737b2862eac..c09aa885426e 100644 --- a/python/sglang/srt/multimodal/processors/whisper.py +++ b/python/sglang/srt/multimodal/processors/whisper.py @@ -115,10 +115,9 @@ def 
__init__(self, hf_config, server_args, _processor, *args, **kwargs): # Cache tokenizer for language token lookup self._tokenizer = getattr(self._processor, "tokenizer", None) - def _extract_language_from_request(self, request_obj) -> Optional[str]: + def _pop_sampling_param(self, request_obj, key: str): sampling_params = getattr(request_obj, "sampling_params", None) or {} - language = sampling_params.pop("language", None) - return normalize_language_to_code(language) + return sampling_params.pop(key, None) def _get_language_token_id(self, language: Optional[str]) -> int: # Default to English if not specified @@ -148,27 +147,35 @@ async def process_mm_data_async( # For Whisper, ALWAYS use the proper transcription token sequence # and IGNORE any text prompt - Whisper is a pure speech-to-text model # The decoder_start_token_id and forced_decoder_ids from generation config - # set up: <|startoftranscript|> <|lang|> <|task|> [<|notimestamps|>] + # set up: <|startoftranscript|> <|lang|> <|task|> [<|notimestamps|> or <|0.00|>] - # Extract language from request and get token ID - language = self._extract_language_from_request(request_obj) + language = normalize_language_to_code( + self._pop_sampling_param(request_obj, "language") + ) language_token_id = self._get_language_token_id(language) + timestamp_granularities = self._pop_sampling_param( + request_obj, "timestamp_granularities" + ) # Build decoder input tokens - # <|startoftranscript|> + <|lang|> + <|transcribe|> + <|notimestamps|> decoder_start_token_id = getattr( self.hf_config, "decoder_start_token_id", 50258 ) transcribe_token_id = self._tokenizer.convert_tokens_to_ids("<|transcribe|>") - notimestamps_token_id = self._tokenizer.convert_tokens_to_ids( - "<|notimestamps|>" - ) + + # Use <|0.00|> to enable timestamp generation, or <|notimestamps|> to disable + if timestamp_granularities: + timestamp_token_id = self._tokenizer.convert_tokens_to_ids("<|0.00|>") + else: + timestamp_token_id = 
self._tokenizer.convert_tokens_to_ids( + "<|notimestamps|>" + ) input_ids = [ decoder_start_token_id, language_token_id, transcribe_token_id, - notimestamps_token_id, + timestamp_token_id, ] # Whisper expects input features padded to max_length (3000 frames = 30 seconds) diff --git a/python/sglang/srt/multimodal/vit_cuda_graph_runner.py b/python/sglang/srt/multimodal/vit_cuda_graph_runner.py index cfdf62915a55..8819cfdaba86 100644 --- a/python/sglang/srt/multimodal/vit_cuda_graph_runner.py +++ b/python/sglang/srt/multimodal/vit_cuda_graph_runner.py @@ -17,11 +17,13 @@ from __future__ import annotations import inspect +from contextlib import nullcontext from typing import Dict, Hashable, List, Optional, Tuple import torch import torch.nn as nn +from sglang.srt.distributed.parallel_state import get_tp_group from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.server_args import get_global_server_args @@ -139,7 +141,11 @@ def _create_graph( override_backend = get_global_server_args().mm_attention_backend - with torch.cuda.graph(graph): + tp_group = get_tp_group() + ca_comm = tp_group.ca_comm + capture_ctx = ca_comm.capture() if ca_comm is not None else nullcontext() + + with capture_ctx, torch.cuda.graph(graph): y = None deepstack_outs: List[torch.Tensor] = [] deepstack_capture_idx = 0 diff --git a/python/sglang/srt/observability/metrics_collector.py b/python/sglang/srt/observability/metrics_collector.py index 16da400d5a39..490d755ede0f 100644 --- a/python/sglang/srt/observability/metrics_collector.py +++ b/python/sglang/srt/observability/metrics_collector.py @@ -702,6 +702,30 @@ def __init__( ), labelnames=list(labels.keys()) + ["category"], ) + self.estimated_flops_per_gpu_total = Counter( + name="sglang:estimated_flops_per_gpu_total", + documentation=( + "Estimated number of floating point operations per GPU " + "(for Model FLOPs Utilization calculations)." 
+ ), + labelnames=labels.keys(), + ) + self.estimated_read_bytes_per_gpu_total = Counter( + name="sglang:estimated_read_bytes_per_gpu_total", + documentation=( + "Estimated number of bytes read from memory per GPU " + "(for Model FLOPs Utilization calculations)." + ), + labelnames=labels.keys(), + ) + self.estimated_write_bytes_per_gpu_total = Counter( + name="sglang:estimated_write_bytes_per_gpu_total", + documentation=( + "Estimated number of bytes written to memory per GPU " + "(for Model FLOPs Utilization calculations)." + ), + labelnames=labels.keys(), + ) self.dp_cooperation_realtime_tokens_total = Counter( name="sglang:dp_cooperation_realtime_tokens_total", @@ -928,6 +952,25 @@ def increment_gpu_execution_seconds( **dp_cooperation_info.to_labels(), ).inc(t) + def increment_estimated_perf( + self, + num_flops_per_gpu: float = 0.0, + num_read_bytes_per_gpu: float = 0.0, + num_write_bytes_per_gpu: float = 0.0, + ) -> None: + if num_flops_per_gpu > 0: + self.estimated_flops_per_gpu_total.labels(**self.labels).inc( + num_flops_per_gpu + ) + if num_read_bytes_per_gpu > 0: + self.estimated_read_bytes_per_gpu_total.labels(**self.labels).inc( + num_read_bytes_per_gpu + ) + if num_write_bytes_per_gpu > 0: + self.estimated_write_bytes_per_gpu_total.labels(**self.labels).inc( + num_write_bytes_per_gpu + ) + def log_stats(self, stats: SchedulerStats) -> None: self._log_gauge_queue_count(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_used_tokens, stats.num_used_tokens) diff --git a/python/sglang/srt/observability/request_metrics_exporter.py b/python/sglang/srt/observability/request_metrics_exporter.py index ca4db1a9b6ff..14ece7498a6a 100644 --- a/python/sglang/srt/observability/request_metrics_exporter.py +++ b/python/sglang/srt/observability/request_metrics_exporter.py @@ -7,6 +7,7 @@ from datetime import datetime from typing import List, Optional, Union +from sglang.srt.constants import HEALTH_CHECK_RID_PREFIX from sglang.srt.managers.io_struct 
import EmbeddingReqInput, GenerateReqInput from sglang.srt.server_args import ServerArgs @@ -128,7 +129,7 @@ async def write_record( self, obj: Union[GenerateReqInput, EmbeddingReqInput], out_dict: dict ): # Do not log health check requests, since they don't represent real user requests. - if isinstance(obj.rid, str) and "HEALTH_CHECK" in obj.rid: + if isinstance(obj.rid, str) and HEALTH_CHECK_RID_PREFIX in obj.rid: return try: diff --git a/python/sglang/srt/observability/scheduler_metrics_mixin.py b/python/sglang/srt/observability/scheduler_metrics_mixin.py index 052c23645693..ff5695ce2e8a 100644 --- a/python/sglang/srt/observability/scheduler_metrics_mixin.py +++ b/python/sglang/srt/observability/scheduler_metrics_mixin.py @@ -5,7 +5,7 @@ import time from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch from sglang.srt.disaggregation.utils import DisaggregationMode @@ -114,6 +114,7 @@ def init_metrics( self.stats = SchedulerStats() # Metrics + self.enable_mfu_metrics = False self.enable_metrics = self.server_args.enable_metrics self.is_stats_logging_rank = self.attn_tp_rank == 0 self.current_scheduler_metrics_enabled = self.enable_metrics and ( @@ -148,6 +149,12 @@ def init_metrics( enable_hierarchical_cache=self.enable_hierarchical_cache, server_args=self.server_args, ) + self.enable_mfu_metrics = bool(self.server_args.enable_mfu_metrics) + if self.enable_mfu_metrics: + self._init_estimated_perf_constants() + self._mfu_log_flops = 0.0 + self._mfu_log_read_bytes = 0.0 + self._mfu_log_write_bytes = 0.0 if ENABLE_METRICS_DEVICE_TIMER: self.forward_pass_device_timer = DeviceTimer( @@ -175,6 +182,139 @@ def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int): self.spec_num_forward_ct += bs self.num_generated_tokens += 
num_accepted_tokens + def _init_estimated_perf_constants(self: Scheduler) -> None: + model_config = self.model_config + hf_text_config = model_config.hf_text_config + + hidden_size = float(model_config.hidden_size) + num_layers = float(getattr(model_config, "num_attention_layers", 0)) + head_dim = float(getattr(model_config, "head_dim", 0)) + num_attn_heads = float(model_config.get_num_attention_heads(self.tp_size)) + num_kv_heads = float(model_config.get_num_kv_heads(self.tp_size)) + intermediate_size = getattr(hf_text_config, "intermediate_size", None) + if intermediate_size is None: + intermediate_size = getattr(hf_text_config, "ffn_hidden_size", 0) + intermediate_size = float(intermediate_size) + + dtype_num_bytes = getattr(model_config.dtype, "itemsize", None) + if dtype_num_bytes is None: + dtype_num_bytes = 2 + # Keep this estimator lightweight and consistent with current server dtype. + # KV cache quantization-aware bytes can be added in a follow-up. + act_bytes = float(dtype_num_bytes) + w_bytes = float(dtype_num_bytes) + cache_bytes = float(dtype_num_bytes) + + # Linear-layer FLOPs per token on one GPU. + attn_linear_flops = ( + 2.0 * hidden_size * head_dim * (num_attn_heads + 2.0 * num_kv_heads) + + 2.0 * hidden_size * head_dim * num_attn_heads + ) + mlp_flops = ( + 6.0 * hidden_size * intermediate_size if intermediate_size > 0 else 0.0 + ) + self._linear_flops_per_token = max( + 0.0, (attn_linear_flops + mlp_flops) * num_layers + ) + + # Attention dot-product FLOPs coefficient to multiply token-context product. + # attn_qk + attn_av = 4 * q * TC * d * L + self._attn_dot_flops_coeff = 4.0 * num_attn_heads * head_dim * num_layers + + # KV cache bytes (write one K and one V vector per generated token). + self._kv_cache_bytes_per_token = ( + 2.0 * num_layers * num_kv_heads * head_dim * cache_bytes + ) + + # Weight read bytes per token. 
+ self._weight_read_bytes_per_token = ( + hidden_size + * head_dim + * (num_attn_heads + 2.0 * num_kv_heads) + * w_bytes + * num_layers + + hidden_size * head_dim * num_attn_heads * w_bytes * num_layers + + ( + 3.0 * hidden_size * intermediate_size * w_bytes * num_layers + if intermediate_size > 0 + else 0.0 + ) + ) + + # Activation movement bytes per token (coarse approximation). + self._qkv_act_bytes_per_token = ( + hidden_size * act_bytes * num_layers + + (num_attn_heads + 2.0 * num_kv_heads) * head_dim * act_bytes * num_layers + + head_dim * num_attn_heads * act_bytes * num_layers + + hidden_size * act_bytes * num_layers + ) + self._ffn_act_bytes_per_token = ( + 3.0 * intermediate_size * act_bytes * num_layers + if intermediate_size > 0 + else 0.0 + ) + + # Prefill reads Q/K/V activations from on-device memory. + self._prefill_attn_act_read_per_token = ( + (num_attn_heads + 2.0 * num_kv_heads) * head_dim * act_bytes * num_layers + ) + + # Decode reads Q from activation memory; K/V reads are from KV cache. + self._decode_q_read_bytes_per_token = ( + num_attn_heads * head_dim * act_bytes * num_layers + ) + + def _estimate_prefill_perf( + self: Scheduler, num_tokens: int + ) -> Tuple[float, float, float]: + tokens = max(0, int(num_tokens)) + if tokens == 0: + return 0.0, 0.0, 0.0 + + # Causal prefill token-context product. 
+ context_product = tokens * (tokens + 1) / 2.0 + flops = ( + tokens * self._linear_flops_per_token + + self._attn_dot_flops_coeff * context_product + ) + + read_bytes = ( + tokens * self._weight_read_bytes_per_token + + tokens * self._qkv_act_bytes_per_token + + tokens * self._prefill_attn_act_read_per_token + ) + write_bytes = ( + tokens * self._kv_cache_bytes_per_token + + tokens * self._qkv_act_bytes_per_token + + tokens * self._ffn_act_bytes_per_token + ) + return flops, read_bytes, write_bytes + + def _estimate_decode_perf( + self: Scheduler, batch: ScheduleBatch, num_tokens: int + ) -> Tuple[float, float, float]: + tokens = max(0, int(num_tokens)) + if tokens == 0: + return 0.0, 0.0, 0.0 + + total_context = float(batch.seq_lens_cpu.sum().item()) + flops = ( + tokens * self._linear_flops_per_token + + self._attn_dot_flops_coeff * total_context + ) + read_bytes = ( + tokens * self._weight_read_bytes_per_token + + tokens * self._qkv_act_bytes_per_token + + tokens * self._decode_q_read_bytes_per_token + + total_context * self._kv_cache_bytes_per_token + ) + write_bytes = ( + tokens * self._kv_cache_bytes_per_token + + tokens * self._qkv_act_bytes_per_token + + tokens * self._ffn_act_bytes_per_token + ) + return flops, read_bytes, write_bytes + def reset_metrics(self: Scheduler): self.forward_ct_decode = 0 self.num_generated_tokens = 0 @@ -275,6 +415,11 @@ def report_prefill_stats( msg += f"{graph_backend[self.device]}: {can_run_cuda_graph}, " msg += f"input throughput (token/s): {self.last_input_throughput:.2f}" + if self.enable_mfu_metrics and gap_latency > 0: + flops, _, _ = self._estimate_prefill_perf(prefill_stats.log_input_tokens) + tflops_per_s = flops / gap_latency / 1e12 + msg += f", est. 
prefill TFLOPS/s (per GPU): {tflops_per_s:.2f}" + if self.is_stats_logging_rank: logger.info(msg) @@ -287,6 +432,15 @@ def report_prefill_stats( prefill_cache_tokens=prefill_stats.log_hit_tokens, dp_cooperation_info=dp_cooperation_info, ) + if self.enable_mfu_metrics: + flops, read_bytes, write_bytes = self._estimate_prefill_perf( + prefill_stats.log_input_tokens + ) + self.metrics_collector.increment_estimated_perf( + num_flops_per_gpu=flops, + num_read_bytes_per_gpu=read_bytes, + num_write_bytes_per_gpu=write_bytes, + ) # Basics total_tokens = prefill_stats.log_input_tokens + prefill_stats.log_hit_tokens @@ -354,11 +508,24 @@ def report_decode_stats( # Every-iteration work: realtime token counting + status logger if self.current_scheduler_metrics_enabled: + decode_tokens = batch.batch_size() + num_accepted_tokens self.metrics_collector.increment_realtime_tokens( # TODO unify this w/ the bumping logic in `Scheduler.num_generated_tokens` accumulator - decode_tokens=batch.batch_size() + num_accepted_tokens, + decode_tokens=decode_tokens, dp_cooperation_info=batch.dp_cooperation_info, ) + if self.enable_mfu_metrics: + flops, read_bytes, write_bytes = self._estimate_decode_perf( + batch, decode_tokens + ) + self.metrics_collector.increment_estimated_perf( + num_flops_per_gpu=flops, + num_read_bytes_per_gpu=read_bytes, + num_write_bytes_per_gpu=write_bytes, + ) + self._mfu_log_flops += flops + self._mfu_log_read_bytes += read_bytes + self._mfu_log_write_bytes += write_bytes if x := self.scheduler_status_logger: x.maybe_dump(batch, self.waiting_queue) @@ -490,6 +657,22 @@ def report_decode_stats( f"#queue-req: {len(self.waiting_queue)}" ) + if self.enable_mfu_metrics and gap_latency > 0: + flops_per_s = self._mfu_log_flops / gap_latency + read_bytes_per_s = self._mfu_log_read_bytes / gap_latency + write_bytes_per_s = self._mfu_log_write_bytes / gap_latency + tflops_per_s = flops_per_s / 1e12 + read_gb_per_s = read_bytes_per_s / 1e9 + write_gb_per_s = write_bytes_per_s / 
1e9 + msg += ( + f", est. decode TFLOPS/s (per GPU): {tflops_per_s:.2f}, " + f"est. read BW (GB/s per GPU): {read_gb_per_s:.2f}, " + f"est. write BW (GB/s per GPU): {write_gb_per_s:.2f}" + ) + self._mfu_log_flops = 0.0 + self._mfu_log_read_bytes = 0.0 + self._mfu_log_write_bytes = 0.0 + if self.is_stats_logging_rank: logger.info(msg) if self.current_scheduler_metrics_enabled: diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index a6867c9f8b54..c3dbb3116464 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -495,6 +495,7 @@ class ReasoningParser: "gpt-oss": GptOssDetector, "kimi": KimiDetector, "kimi_k2": KimiK2Detector, + "mimo": Qwen3Detector, "qwen3": Qwen3Detector, "qwen3-thinking": Qwen3Detector, "minimax": Qwen3Detector, diff --git a/python/sglang/srt/ray/engine.py b/python/sglang/srt/ray/engine.py index 94c26436edde..f36bdd4607b1 100644 --- a/python/sglang/srt/ray/engine.py +++ b/python/sglang/srt/ray/engine.py @@ -90,8 +90,13 @@ def _launch_scheduler_processes( server_args: ServerArgs, port_args: PortArgs, run_scheduler_process_func: Callable, - ) -> SchedulerInitResult: - """Launch schedulers as Ray actors.""" + ) -> tuple[SchedulerInitResult, None]: + """Launch schedulers as Ray actors. + + Returns: + Tuple of (RaySchedulerInitResult, None). + scheduler_procs is None since Ray uses actors instead of mp.Process. + """ if server_args.dp_size > 1: raise NotImplementedError( "Ray support for dp_size > 1 is not yet implemented. 
" @@ -183,8 +188,11 @@ def wait_for_completion(): except Exception as e: logger.error(f"Ray scheduler actor terminated with error: {e}") - return RaySchedulerInitResult( - scheduler_infos=scheduler_infos, - wait_for_completion=wait_for_completion, - scheduler_actors=scheduler_actors, + return ( + RaySchedulerInitResult( + scheduler_infos=scheduler_infos, + wait_for_completion=wait_for_completion, + scheduler_actors=scheduler_actors, + ), + None, ) diff --git a/python/sglang/srt/ray/http_server.py b/python/sglang/srt/ray/http_server.py index c0580838e195..c2acda83e248 100644 --- a/python/sglang/srt/ray/http_server.py +++ b/python/sglang/srt/ray/http_server.py @@ -44,13 +44,17 @@ def launch_server( if execute_warmup_func is None: execute_warmup_func = _execute_server_warmup - tokenizer_manager, template_manager, port_args, scheduler_init_result = ( - RayEngine._launch_subprocesses( - server_args, - init_tokenizer_manager_func=init_tokenizer_manager_func, - run_scheduler_process_func=run_scheduler_process_func, - run_detokenizer_process_func=run_detokenizer_process_func, - ) + ( + tokenizer_manager, + template_manager, + port_args, + scheduler_init_result, + subprocess_watchdog, + ) = RayEngine._launch_subprocesses( + server_args, + init_tokenizer_manager_func=init_tokenizer_manager_func, + run_scheduler_process_func=run_scheduler_process_func, + run_detokenizer_process_func=run_detokenizer_process_func, ) _setup_and_run_http_server( @@ -59,6 +63,7 @@ def launch_server( template_manager, port_args, scheduler_init_result.scheduler_infos, + subprocess_watchdog, execute_warmup_func=execute_warmup_func, launch_callback=launch_callback, ) diff --git a/python/sglang/srt/sampling/penaltylib/__init__.py b/python/sglang/srt/sampling/penaltylib/__init__.py index 26a780517ce7..9ba6d73ac68f 100644 --- a/python/sglang/srt/sampling/penaltylib/__init__.py +++ b/python/sglang/srt/sampling/penaltylib/__init__.py @@ -2,10 +2,12 @@ from sglang.srt.sampling.penaltylib.min_new_tokens 
import BatchedMinNewTokensPenalizer from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer +from sglang.srt.sampling.penaltylib.repetition_penalty import BatchedRepetitionPenalizer __all__ = [ "BatchedFrequencyPenalizer", "BatchedMinNewTokensPenalizer", "BatchedPresencePenalizer", "BatchedPenalizerOrchestrator", + "BatchedRepetitionPenalizer", ] diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index 7ef123f554f9..650c719f37ca 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -52,19 +52,56 @@ def cumulate_output_tokens(self, output_ids: torch.Tensor): for penalizer in self.penalizers.values(): penalizer.cumulate_output_tokens(output_ids=output_ids) - def apply(self, logits: torch.Tensor) -> torch.Tensor: + def apply(self, logits: torch.Tensor, repeat: Optional[int] = None): """ - Apply the penalizers to the logits. - Note that it may apply the penalizers in-place. + Apply all penalizers to the logits in-place. Args: - logits (torch.Tensor): The logits to apply the penalizers to. - - Returns: - torch.Tensor: The logits after applying the penalizers. + logits: The logits tensor to apply penalties to. + repeat: If set (speculative decoding), per-request penalties are + expanded via repeat_interleave to match the draft token layout. + Additive penalties are captured into a zeros tensor, expanded, + then added; scaling penalties are accumulated, expanded, then + applied directly. 
""" + if repeat is None: + for penalizer in self.penalizers.values(): + penalizer.apply(logits) + else: + # Additive: capture into zeros, expand, add + bs = logits.shape[0] // repeat + additive = torch.zeros( + (bs, logits.shape[1]), dtype=torch.float32, device=logits.device + ) + self.accumulate_additive_penalties(additive) + logits.add_(torch.repeat_interleave(additive, repeat, dim=0)) + # Scaling: accumulate, expand, apply + accumulated = self.accumulate_scaling_penalties() + if accumulated is not None: + from sglang.srt.sampling.penaltylib.repetition_penalty import ( + apply_scaling_penalties, + ) + + expanded = torch.repeat_interleave(accumulated, repeat, dim=0) + apply_scaling_penalties(logits, expanded) + + def accumulate_additive_penalties(self, logits: torch.Tensor): + """Apply only additive (non-multiplicative) penalizers.""" for penalizer in self.penalizers.values(): - penalizer.apply(logits) + if not penalizer.is_multiplicative: + penalizer.apply(logits) + + def accumulate_scaling_penalties(self) -> Optional[torch.Tensor]: + """Accumulate all multiplicative penalty tensors into one, or None if none active.""" + result = None + for penalizer in self.penalizers.values(): + if not penalizer._is_prepared or not penalizer.is_multiplicative: + continue + if result is None: + result = penalizer.get_scaling_penalties().clone() + else: + result *= penalizer.get_scaling_penalties() + return result def filter(self, keep_indices: torch.Tensor): """ @@ -132,6 +169,8 @@ class _BatchedPenalizer(abc.ABC): An abstract class for a batched penalizer. 
""" + is_multiplicative: bool = False + def __init__(self, orchestrator: BatchedPenalizerOrchestrator): self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = ( weakref.ref(orchestrator) @@ -227,6 +266,13 @@ def _apply(self, logits: torch.Tensor) -> torch.Tensor: """ pass + def get_scaling_penalties(self) -> torch.Tensor: + """ + Return the accumulated scaling penalty tensor for multiplicative penalizers. + Only meaningful when is_multiplicative is True. Subclasses should override. + """ + raise NotImplementedError + @abc.abstractmethod def _filter(self, keep_indices: torch.Tensor): """ diff --git a/python/sglang/srt/sampling/penaltylib/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py new file mode 100644 index 000000000000..fd03fb2b5c89 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py @@ -0,0 +1,78 @@ +import torch + +from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer +from sglang.srt.utils import get_compiler_backend + + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def apply_scaling_penalties(logits, scaling_penalties): + logits[:] = torch.where( + logits < 0, + logits * scaling_penalties, + logits / scaling_penalties, + ) + + +class BatchedRepetitionPenalizer(_BatchedPenalizer): + """ + Repetition penalizer penalizes tokens based on their presence in the generated output. 
+ """ + + is_multiplicative: bool = True + + def _is_required(self) -> bool: + return any( + req.sampling_params.repetition_penalty != 1.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_repetition_penalties = torch.ones( + (len(self.orchestrator.reqs()), self.orchestrator.vocab_size), + dtype=torch.float32, + device=self.orchestrator.device, + ) + self.repetition_penalties = ( + torch.tensor( + data=[ + req.sampling_params.repetition_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + ).unsqueeze_(1) + + def _cumulate_output_tokens(self, output_ids: torch.Tensor): + self.cumulated_repetition_penalties.scatter_( + dim=1, + index=output_ids.unsqueeze(1), + src=self.repetition_penalties, + ) + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + apply_scaling_penalties(logits, self.cumulated_repetition_penalties) + return logits + + def get_scaling_penalties(self) -> torch.Tensor: + return self.cumulated_repetition_penalties + + def _filter(self, keep_indices: torch.Tensor): + self.repetition_penalties = self.repetition_penalties[keep_indices] + self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ + keep_indices + ] + + def _merge(self, their: "BatchedRepetitionPenalizer"): + self.repetition_penalties = torch.cat( + [self.repetition_penalties, their.repetition_penalties], dim=0 + ) + self.cumulated_repetition_penalties = torch.cat( + [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties], + dim=0, + ) + + def _teardown(self) -> None: + for name in ("repetition_penalties", "cumulated_repetition_penalties"): + if hasattr(self, name): + delattr(self, name) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 914dde0f6ac1..885936b0ec95 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ 
-8,6 +8,7 @@ import sglang.srt.sampling.penaltylib as penaltylib from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor +from sglang.srt.sampling.penaltylib.repetition_penalty import apply_scaling_penalties from sglang.srt.sampling.sampling_params import TOP_K_ALL from sglang.srt.server_args import get_global_server_args @@ -46,7 +47,10 @@ class SamplingBatchInfo: # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None - acc_linear_penalties: torch.Tensor = None # Used in the overlap mode + acc_additive_penalties: Optional[torch.Tensor] = None # Used in the overlap mode + acc_scaling_penalties: Optional[torch.Tensor] = ( + None # Used in the overlap mode for repetition penalty + ) # Whether any request has custom logit processor has_custom_logit_processor: bool = False @@ -159,6 +163,7 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): penaltylib.BatchedFrequencyPenalizer, penaltylib.BatchedMinNewTokensPenalizer, penaltylib.BatchedPresencePenalizer, + penaltylib.BatchedRepetitionPenalizer, }, ) @@ -229,19 +234,29 @@ def update_regex_vocab_mask(self): def update_penalties(self): if self.penalizer_orchestrator.is_required: - self.acc_linear_penalties = torch.zeros( + self.acc_additive_penalties = torch.zeros( (len(self.temperatures), self.vocab_size), dtype=torch.float32, device=self.temperatures.device, ) - self.penalizer_orchestrator.apply(self.acc_linear_penalties) + self.penalizer_orchestrator.accumulate_additive_penalties( + self.acc_additive_penalties + ) + self.acc_scaling_penalties = ( + self.penalizer_orchestrator.accumulate_scaling_penalties() + ) else: - self.acc_linear_penalties = None + self.acc_additive_penalties = None + self.acc_scaling_penalties = None def apply_logits_bias(self, logits: torch.Tensor): - if self.acc_linear_penalties is not None: + if self.acc_additive_penalties is not None: + # Used in the overlap mode + logits.add_(self.acc_additive_penalties) + + if 
self.acc_scaling_penalties is not None: # Used in the overlap mode - logits.add_(self.acc_linear_penalties) + apply_scaling_penalties(logits, self.acc_scaling_penalties) if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required: # Used in the non-overlap mode diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e7db9bf9f579..6c5e559ce313 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -42,6 +42,7 @@ get_device_name, get_device_sm, get_int_env_var, + get_nvidia_driver_version, get_quantization_config, human_readable_int, is_blackwell_supported, @@ -67,6 +68,7 @@ ) from sglang.srt.utils.hf_transformers_utils import check_gguf_file from sglang.srt.utils.network import NetworkAddress, get_free_port, wait_port_available +from sglang.srt.utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from sglang.utils import is_in_ci logger = logging.getLogger(__name__) @@ -90,6 +92,7 @@ "remote_instance", "fastsafetensors", "private", + "runai_streamer", ] QUANTIZATION_CHOICES = [ @@ -213,6 +216,7 @@ FP4_GEMM_RUNNER_BACKEND_CHOICES = [ "auto", + "cutlass", "flashinfer_cudnn", "flashinfer_cutlass", "flashinfer_trtllm", @@ -285,7 +289,7 @@ class ServerArgs: The arguments of the server. NOTE: When you add new arguments, please make sure the order - in this class definition the same as the order in the the function + in this class definition the same as the order in the function `ServerArgs.add_cli_args`. Please follow the existing style to group the new arguments into related groups or create new groups. 
""" @@ -398,6 +402,7 @@ class ServerArgs: crash_dump_folder: Optional[str] = None show_time_cost: bool = False enable_metrics: bool = False + enable_mfu_metrics: bool = False enable_metrics_for_all_schedulers: bool = False tokenizer_metrics_custom_labels_header: str = "x-custom-labels" tokenizer_metrics_allowed_custom_labels: Optional[List[str]] = None @@ -463,6 +468,7 @@ class ServerArgs: lora_eviction_policy: str = "lru" lora_backend: str = "csgmv" max_lora_chunk_size: Optional[int] = 16 + experts_shared_outer_loras: Optional[bool] = None # Kernel backend attention_backend: Optional[str] = None @@ -500,8 +506,6 @@ class ServerArgs: speculative_draft_model_quantization: Optional[str] = None # Speculative decoding (ngram) - speculative_ngram_min_match_window_size: int = 1 - speculative_ngram_max_match_window_size: int = 12 speculative_ngram_min_bfs_breadth: int = 1 speculative_ngram_max_bfs_breadth: int = 10 speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS" @@ -517,6 +521,7 @@ class ServerArgs: moe_runner_backend: str = "auto" flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" enable_flashinfer_allreduce_fusion: bool = False + enforce_disable_flashinfer_allreduce_fusion: bool = False enable_aiter_allreduce_fusion: bool = False deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" ep_num_redundant_experts: int = 0 @@ -555,7 +560,6 @@ class ServerArgs: hicache_write_policy: str = "write_through" hicache_io_backend: str = "kernel" hicache_mem_layout: str = "layer_first" - disable_hicache_numa_detect: bool = False hicache_storage_backend: Optional[str] = None hicache_storage_prefetch_policy: str = "best_effort" hicache_storage_backend_extra_config: Optional[str] = None @@ -662,6 +666,7 @@ class ServerArgs: enable_deterministic_inference: bool = False rl_on_policy_target: Optional[str] = None enable_attn_tp_input_scattered: bool = False + gc_threshold: Optional[List[int]] = None # Context parallelism used in the long sequence 
prefill phase of DeepSeek v3.2 enable_nsa_prefill_context_parallel: bool = False nsa_prefill_cp_mode: str = "round-robin-split" @@ -713,6 +718,7 @@ class ServerArgs: "transfer_engine", "nccl", "modelexpress" ] = "nccl" remote_instance_weight_loader_start_seed_via_transfer_engine: bool = False + engine_info_bootstrap_port: int = 6789 modelexpress_config: Optional[str] = None # For PD-Multiplexing @@ -721,8 +727,6 @@ class ServerArgs: sm_group_num: int = 8 # For Multi-Modal - mm_max_concurrent_calls: int = 32 - mm_per_request_timeout: float = 10.0 enable_broadcast_mm_inputs_process: bool = False enable_prefix_mm_cache: bool = False mm_enable_dp_encoder: bool = False @@ -742,6 +746,8 @@ def __post_init__(self): Orchestrates the handling of various server arguments, ensuring proper configuration and validation. """ + self._maybe_download_model_for_runai() + # Normalize load balancing defaults early (before dummy-model short-circuit). self._handle_load_balance_method() @@ -843,6 +849,17 @@ def __post_init__(self): # Handle any other necessary validations. 
self._handle_other_validations() + def _maybe_download_model_for_runai(self): + if is_runai_obj_uri(self.model_path): + ObjectStorageModel.download_and_get_path(self.model_path) + + if ( + self.tokenizer_path is not None + and is_runai_obj_uri(self.tokenizer_path) + and self.tokenizer_path != self.model_path + ): + ObjectStorageModel.download_and_get_path(self.tokenizer_path) + def _handle_load_balance_method(self): if self.disaggregation_mode not in ("null", "prefill", "decode"): raise ValueError( @@ -1353,6 +1370,9 @@ def _generate_cuda_graph_batch_sizes(self): capture_bs = [bs for bs in capture_bs if bs <= self.cuda_graph_max_bs] + if self.cuda_graph_max_bs not in capture_bs: + capture_bs.append(self.cuda_graph_max_bs) + return capture_bs def _generate_cpu_graph_batch_sizes(self): @@ -1409,15 +1429,10 @@ def _set_default_nsa_kv_cache_dtype(self, major: int, quantization: str) -> str: ) if self.kv_cache_dtype == "auto": - # TODO: Temporarily set default dtype on B200 as bfloat16 to avoid performance regression. - # TODO: Remove this after the performance regression is fixed. (Ref: https://github.com/sgl-project/sglang/issues/21291) - if quantization == "modelopt_fp4" and major >= 10 and self.dp_size > 1: + if major >= 10: self.kv_cache_dtype = "fp8_e4m3" else: self.kv_cache_dtype = "bfloat16" - # self.kv_cache_dtype = ( - # "fp8_e4m3" if (major >= 10 and self.dp_size > 1) else "bfloat16" - # ) logger.warning( f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek DSA on SM{major} device." 
) @@ -1432,16 +1447,25 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: user_set_prefill = self.nsa_prefill_backend is not None user_set_decode = self.nsa_decode_backend is not None + # HiSparse requires flashmla_sparse for both prefill and decode + if self.enable_hisparse: + if not user_set_prefill: + self.nsa_prefill_backend = "flashmla_sparse" + if not user_set_decode: + self.nsa_decode_backend = "flashmla_sparse" + logger.warning( + f"HiSparse enabled: using flashmla_sparse NSA backends " + f"(prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend})." + ) + return + if not user_set_prefill and not user_set_decode and is_hip(): self.nsa_prefill_backend = "tilelang" self.nsa_decode_backend = "tilelang" elif kv_cache_dtype == "fp8_e4m3": - if self.dp_size == 1 and major >= 10: + if major >= 10: self.nsa_prefill_backend = "trtllm" self.nsa_decode_backend = "trtllm" - logger.warning( - "Flashmla is not supported on Blackwell device without DP attention. Set NSA prefill/decode backends to trtllm, which runs fast but loses a little accuracy." - ) else: # flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics if not user_set_prefill: @@ -1510,14 +1534,6 @@ def _handle_model_specific_adjustments(self): logger.warning( f"Set dense attention kv len threshold to model index_topk={envs.SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD.get()} for DeepSeek with DSA." ) - if self.nsa_prefill_backend == "trtllm": - # We temporarily set the threshold to 128k to avoid IMA error. Should be removed after supporting flashmla prefill impl with trtllm decode impl. - envs.SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD.set( - 128 * 1024 - ) - logger.warning( - "TRTLLM sparse MLA kernel requires MHA as prefill impl, the threshold for dense attention is overridden. This will be fixed in the future." 
- ) if self.is_attention_backend_not_set(): self.attention_backend = "nsa" logger.info("Use nsa attention backend for DeepSeek with DSA.") @@ -1711,13 +1727,6 @@ def _handle_model_specific_adjustments(self): quant_method = get_quantization_config(hf_config) is_mxfp4_quant_format = quant_method == "mxfp4" - if is_blackwell_supported(): - # workaround for https://github.com/flashinfer-ai/flashinfer/issues/2006 - if not self.enable_dp_attention and self.nnodes == 1: - self.enable_flashinfer_allreduce_fusion = True - logger.info( - "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM" - ) if not self.enable_dp_attention and self.nnodes == 1 and is_hip(): # TODO (Hubert): Put this back later # self.enable_aiter_allreduce_fusion = True @@ -1762,10 +1771,21 @@ def _handle_model_specific_adjustments(self): and is_triton_kernels_available() and self.quantization is None ): - self.moe_runner_backend = "triton_kernel" - logger.warning( - "Detected GPT-OSS model, enabling triton_kernels MOE kernel." - ) + # The triton_kernels package segfaults on Blackwell (B200) + # with NVIDIA driver >= 595. Fall back to triton backend. + if is_blackwell_supported() and get_nvidia_driver_version() >= ( + 595, + ): + self.moe_runner_backend = "triton" + logger.warning( + "Detected GPT-OSS model on Blackwell with driver >= 595, " + "using triton MOE kernel to avoid triton_kernels SIGSEGV." + ) + else: + self.moe_runner_backend = "triton_kernel" + logger.warning( + "Detected GPT-OSS model, enabling triton_kernels MOE kernel." 
+ ) if self.moe_runner_backend == "triton_kernel": assert ( @@ -2081,6 +2101,7 @@ def _handle_model_specific_adjustments(self): "Qwen3_5ForConditionalGeneration", ] and (is_sm90_supported() or is_sm100_supported()) + and self.tp_size > 1 and not self.enable_dp_attention and self.attn_cp_size <= 1 and self.nnodes == 1 @@ -2088,6 +2109,17 @@ def _handle_model_specific_adjustments(self): and self.moe_a2a_backend == "none" ): self.enable_flashinfer_allreduce_fusion = True + logger.info( + f"Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for {model_arch}" + ) + + # Apply enforce_disable_flashinfer_allreduce_fusion after all model-specific adjustments + if self.enforce_disable_flashinfer_allreduce_fusion: + self.enable_flashinfer_allreduce_fusion = False + logger.info( + "FlashInfer allreduce fusion is forcibly disabled " + "via --enforce-disable-flashinfer-allreduce-fusion." + ) def _handle_mamba_radix_cache( self, @@ -2197,6 +2229,12 @@ def _get_default_attn_backend(self, use_mla_backend: bool, model_config): 2.2 We will use Flashinfer backend on blackwell. 2.3 Otherwise, we will use triton backend. """ + # Whisper requires flashinfer for cross-attention CUDA graph support + if "WhisperForConditionalGeneration" in ( + model_config.hf_config.architectures or [] + ): + return "flashinfer" + if not use_mla_backend: # MHA architecture if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(self): @@ -2272,12 +2310,16 @@ def _handle_attention_backend_compatibility(self): self.speculative_algorithm is None ), "Speculative decoding is currently not supported with Flex Attention backend" - # Encoder-decoder models (e.g., Whisper) - if model_config.is_encoder_decoder: - logger.warning( - "Cuda graph is disabled for encoder-decoder models (e.g., Whisper)" - ) - self.disable_cuda_graph = True + # Whisper's encoder token padding conflicts with prefix caching. + # Only disable for Whisper; other encoder-decoder models (e.g., mllama) use radix cache. 
+ if ( + model_config.is_encoder_decoder + and not self.disable_radix_cache + and "WhisperForConditionalGeneration" + in (model_config.hf_config.architectures or []) + ): + logger.info("Radix cache is disabled for Whisper") + self.disable_radix_cache = True # Major NVIDIA platforms backends if ( @@ -3064,8 +3106,10 @@ def _handle_speculative_decoding(self): self.enable_mixed_chunk = False self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth if self.speculative_num_draft_tokens is None: - self.speculative_num_draft_tokens = ( - self.speculative_ngram_max_match_window_size + self.speculative_num_draft_tokens = 12 + logger.warning( + "speculative_num_draft_tokens is set to 12 by default for ngram speculative decoding. " + "You can override this by explicitly setting --speculative-num-draft-tokens." ) logger.warning( "The overlap scheduler and mixed chunked prefill are disabled because of " @@ -3101,7 +3145,9 @@ def _handle_load_format(self): "Detected Mistral native format checkpoint, setting load_format='mistral'" ) - if is_remote_url(self.model_path): + if is_runai_obj_uri(self.model_path): + self.load_format = "runai_streamer" + elif is_remote_url(self.model_path): self.load_format = "remote" if self.custom_weight_loader is None: @@ -3154,22 +3200,42 @@ def _handle_load_format(self): def _is_mistral_native_format(self) -> bool: """Detect if the model uses Mistral native format (params.json + consolidated weights). - Models like Mistral-7B-Instruct-v0.3 have BOTH params.json (native) and - config.json (HF standard). When both exist, prefer the HF format to avoid - parameter name mismatches between consolidated.safetensors (native names - like layers.0.attention.wk.weight) and HuggingFace model classes (names - like model.layers.0.self_attn.k_proj.weight). + When both params.json and config.json exist, default to HF format to + avoid weight-name mismatches (e.g. Mistral-7B-Instruct-v0.3). 
+ + Exception: models routed through ``_load_mistral_large_3_for_causal_LM`` + (mistral-large-3, mistral-small-4, leanstral) build their config from + params.json and expect native weight names, so native format is required + even when config.json is also present. """ + # Keep in sync with the name checks in + # hf_transformers_utils.py::get_config / get_tokenizer. + _MISTRAL_NATIVE_CONFIG_PATTERNS = ( + "mistral-large-3", + "mistral-small-4", + "leanstral", + ) + + def _check_format(has_params: bool, has_hf_config: bool) -> bool: + if has_params and not has_hf_config: + return True + if has_params and has_hf_config: + model_lower = str(self.model_path).lower() + if any(name in model_lower for name in _MISTRAL_NATIVE_CONFIG_PATTERNS): + return True + return False + if os.path.isdir(self.model_path): has_params = os.path.exists(os.path.join(self.model_path, "params.json")) has_hf_config = os.path.exists(os.path.join(self.model_path, "config.json")) - return has_params and not has_hf_config + return _check_format(has_params, has_hf_config) + # For hub models, check remote files try: from huggingface_hub import HfApi files = {s.rfilename for s in HfApi().model_info(self.model_path).siblings} - return "params.json" in files and "config.json" not in files + return _check_format("params.json" in files, "config.json" in files) except Exception: return False @@ -3189,6 +3255,17 @@ def _handle_pd_disaggregation(self): "Cuda graph is disabled for prefill server when piecewise cuda graph is not enabled." ) + if self.disaggregation_mode in ("prefill", "decode"): + if ( + envs.SGLANG_DISAGG_STAGING_BUFFER.get() + and self.disaggregation_transfer_backend != "mooncake" + ): + raise ValueError( + f"SGLANG_DISAGG_STAGING_BUFFER requires " + f"disaggregation_transfer_backend='mooncake', " + f"got '{self.disaggregation_transfer_backend}'." 
+ ) + def _handle_encoder_disaggregation(self): if self.enable_prefix_mm_cache and not self.encoder_only: raise ValueError( @@ -4218,6 +4295,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable log prometheus metrics.", ) + parser.add_argument( + "--enable-mfu-metrics", + action="store_true", + help="Enable estimated MFU-related prometheus metrics.", + ) parser.add_argument( "--enable-metrics-for-all-schedulers", action="store_true", @@ -4557,6 +4639,14 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=[16, 32, 64, 128], help="Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when --lora-backend is 'csgmv'. Choosing a larger value might improve performance.", ) + parser.add_argument( + "--experts-shared-outer-loras", + default=ServerArgs.experts_shared_outer_loras, + action="store_true", + help="Force shared outer LoRA mode for MoE models. " + "When set, w1/w3 lora_A and w2 lora_B are shared across experts " + "(expert_dim=1). By default this is auto-detected from adapter weights.", + ) # Kernel backend parser.add_argument( @@ -4637,9 +4727,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "'flashinfer_deepgemm' (Hopper SM90 only; uses swapAB optimization for small M dimensions in decoding), " "'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), " "'triton' (fallback, widely compatible), " - "'aiter' (ROCm only). " - "NOTE: This replaces the deprecated environment variables " - "SGLANG_ENABLE_FLASHINFER_FP8_GEMM and SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.", + "'aiter' (ROCm only). ", ) parser.add_argument( "--fp4-gemm-backend", @@ -4649,11 +4737,10 @@ def add_cli_args(parser: argparse.ArgumentParser): dest="fp4_gemm_runner_backend", help="Choose the runner backend for NVFP4 GEMM operations. 
" "Options: 'auto' (default; selects flashinfer_cudnn on SM120, flashinfer_cutlass otherwise), " - "'flashinfer_cutlass' (CUTLASS backend), " + "'cutlass' (SGLang CUTLASS kernel), " + "'flashinfer_cutlass' (FlashInfer CUTLASS backend), " "'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), " - "'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). " - "NOTE: This replaces the deprecated environment variable " - "SGLANG_FLASHINFER_FP4_GEMM_BACKEND.", + "'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). ", ) parser.add_argument( "--disable-flashinfer-autotune", @@ -4764,18 +4851,6 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # Speculative decoding (ngram) - parser.add_argument( - "--speculative-ngram-min-match-window-size", - type=int, - default=ServerArgs.speculative_ngram_min_match_window_size, - help="The minimum window size for pattern matching in ngram speculative decoding.", - ) - parser.add_argument( - "--speculative-ngram-max-match-window-size", - type=int, - default=ServerArgs.speculative_ngram_max_match_window_size, - help="The maximum window size for pattern matching in ngram speculative decoding.", - ) parser.add_argument( "--speculative-ngram-min-bfs-breadth", type=int, @@ -4850,6 +4925,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable FlashInfer allreduce fusion with Residual RMSNorm.", ) + parser.add_argument( + "--enforce-disable-flashinfer-allreduce-fusion", + action="store_true", + help="Enforce disable FlashInfer allreduce fusion.", + ) parser.add_argument( "--enable-aiter-allreduce-fusion", action="store_true", @@ -5073,11 +5153,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.hicache_mem_layout, help="The layout of host memory pool for hierarchical cache.", ) - parser.add_argument( - "--disable-hicache-numa-detect", - 
action="store_true", - help="Disable binding the process to the NUMA node closest to the active CUDA device when hierarchical cache is enabled.", - ) parser.add_argument( "--hicache-storage-backend", type=str, @@ -5546,7 +5621,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--numa-node", type=int, nargs="+", - help="Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess.", + help="Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess. If unset, will be automatically detected on NUMA systems.", ) parser.add_argument( "--enable-deterministic-inference", @@ -5605,6 +5680,12 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable fused moe triton and sum all reduce.", ) + parser.add_argument( + "--gc-threshold", + type=int, + nargs="+", + help="Set the garbage collection thresholds (the collection frequency). Accepts 1 to 3 integers.", + ) # Dynamic batch tokenizer parser.add_argument( @@ -5773,6 +5854,13 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Start seed server via transfer engine backend for remote instance weight loader.", ) + parser.add_argument( + "--engine-info-bootstrap-port", + type=int, + default=ServerArgs.engine_info_bootstrap_port, + help="Port for the engine info bootstrap server. Default is 6789. 
" + "Must be set explicitly when running multiple instances on the same node.", + ) parser.add_argument( "--modelexpress-config", type=str, @@ -5807,18 +5895,6 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # For Multi-Modal - parser.add_argument( - "--mm-max-concurrent-calls", - type=int, - default=ServerArgs.mm_max_concurrent_calls, - help="The max concurrent calls for async mm data processing.", - ) - parser.add_argument( - "--mm-per-request-timeout", - type=int, - default=ServerArgs.mm_per_request_timeout, - help="The timeout for each multi-modal request in seconds.", - ) parser.add_argument( "--enable-broadcast-mm-inputs-process", action="store_true", @@ -5892,7 +5968,7 @@ def from_cli_args(cls, args: argparse.Namespace): attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) - def url(self): + def url(self, port: Optional[int] = None): scheme = "https" if self.ssl_certfile else "http" # When binding to all interfaces, use loopback for internal requests. host = self.host @@ -5900,7 +5976,13 @@ def url(self): host = "127.0.0.1" elif host == "::": host = "::1" - return NetworkAddress(host, self.port).to_url(scheme) + return NetworkAddress(host, port if port is not None else self.port).to_url( + scheme + ) + + @property + def engine_info_bootstrap_url(self): + return self.url(port=self.engine_info_bootstrap_port) def ssl_verify(self): """Return the value for the requests library's ``verify=`` parameter. @@ -5998,11 +6080,12 @@ def check_server_args(self): }, "moe_dense_tp_size only support 1 and None currently" # Check served model name to not have colon as it is reserved for LoRA adapter syntax - assert ":" not in self.served_model_name, ( - "served_model_name cannot contain a colon (':') character. " - "The colon is reserved for the 'model:adapter' syntax used in LoRA adapter specification. 
" - f"Invalid value: '{self.served_model_name}'" - ) + if not is_runai_obj_uri(self.served_model_name): + assert ":" not in self.served_model_name, ( + "served_model_name cannot contain a colon (':') character. " + "The colon is reserved for the 'model:adapter' syntax used in LoRA adapter specification. " + f"Invalid value: '{self.served_model_name}'" + ) # Check LoRA self.check_lora_server_args() @@ -6092,6 +6175,17 @@ def check_server_args(self): assert ( self.disable_radix_cache ), "Hierarchical sparse attention currently requires --disable-radix-cache." + for attr, label in [ + ("nsa_prefill_backend", "prefill"), + ("nsa_decode_backend", "decode"), + ]: + backend = getattr(self, attr) + if backend is not None and backend != "flashmla_sparse": + raise ValueError( + f"HiSparse requires flashmla_sparse NSA {label} backend, " + f"but got --nsa-{label}-backend={backend}. " + f"Please use --nsa-{label}-backend=flashmla_sparse or omit it." + ) assert ( self.schedule_conservativeness >= 0 @@ -6121,6 +6215,12 @@ def check_server_args(self): "When enabling two batch overlap, moe_a2a_backend cannot be 'none'." ) + if self.gc_threshold: + if not (1 <= len(self.gc_threshold) <= 3): + raise ValueError( + "When setting gc_threshold, it must contain 1 to 3 integers." 
+ ) + def check_lora_server_args(self): assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive" diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py index e44a3da6b2ec..0eb6bd71cbf7 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py @@ -1,64 +1,48 @@ # -*- coding: utf-8 -*- import logging -import os from typing import List, Tuple import numpy as np -from torch.utils.cpp_extension import load -logger = logging.getLogger(__name__) +from sglang.jit_kernel.ngram_corpus import get_ngram_corpus_cls -_abs_path = os.path.dirname(os.path.abspath(__file__)) -ngram_corpus_cpp = load( - name="ngram_corpus_cpp", - sources=[ - f"{_abs_path}/ngram_corpus_binding.cpp", - f"{_abs_path}/ngram.cpp", - f"{_abs_path}/trie.cpp", - f"{_abs_path}/result.cpp", - ], - extra_cflags=["-O3", "-std=c++20"], -) +logger = logging.getLogger(__name__) class NgramCorpus: def __init__( self, max_trie_depth=18, - min_match_window_size=1, - max_match_window_size=10, min_bfs_breadth=1, max_bfs_breadth=8, draft_token_num=8, match_type="BFS", capacity=1000000, - ): - param = ngram_corpus_cpp.Param() - param.max_trie_depth = max_trie_depth - param.min_match_window_size = min_match_window_size - param.max_match_window_size = max_match_window_size - param.min_bfs_breadth = min_bfs_breadth - param.max_bfs_breadth = max_bfs_breadth - param.draft_token_num = draft_token_num - param.match_type = match_type - self._ngram = ngram_corpus_cpp.Ngram(capacity, param) - + ) -> None: + cls = get_ngram_corpus_cls() + self._obj = cls( + capacity=capacity, + max_trie_depth=max_trie_depth, + min_bfs_breadth=min_bfs_breadth, + max_bfs_breadth=max_bfs_breadth, + draft_token_num=draft_token_num, + match_type=match_type, + ) self.default_mask = np.ones((1, 1), dtype=np.int64) self.draft_token_num = draft_token_num def batch_put(self, batch_tokens: 
List[List[int]]): - self._ngram.asyncInsert(batch_tokens) + self._obj.insert(batch_tokens) def synchronize(self): - self._ngram.synchronize() + self._obj.synchronize() # type: ignore def reset(self): - self._ngram.reset() + self._obj.reset() # type: ignore def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]: - result = self._ngram.batchMatch(batch_tokens) - return np.array(result.token), np.array(result.mask) + return self._obj.match(batch_tokens) def leaf_paths_from_mask( self, tokens: List[int], tree_mask: List[List[int]] diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp deleted file mode 100644 index 8da395440293..000000000000 --- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include -#include - -#include "ngram.h" - -PYBIND11_MODULE(ngram_corpus_cpp, m) { - using namespace ngram; - namespace py = pybind11; - m.doc() = ""; - - py::class_(m, "Ngram") - .def(py::init(), py::arg("capacity"), py::arg("param")) - .def("asyncInsert", &Ngram::asyncInsert, "") - .def("batchMatch", &Ngram::batchMatch, "") - .def("reset", &Ngram::reset, "") - .def("synchronize", &Ngram::synchronize, ""); - - py::class_(m, "Param") - .def(py::init<>()) - .def_readwrite("enable", &Param::enable) - .def_readwrite("enable_router_mode", &Param::enable_router_mode) - .def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth) - .def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth) - .def_readwrite("min_match_window_size", &Param::min_match_window_size) - .def_readwrite("max_match_window_size", &Param::max_match_window_size) - .def_readwrite("max_trie_depth", &Param::max_trie_depth) - .def_readwrite("draft_token_num", &Param::draft_token_num) - .def_readwrite("match_type", &Param::match_type) - .def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size) - 
.def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num) - .def("get_draft_token_num", &Param::get_draft_token_num, "") - .def("get_min_match_window_size", &Param::get_min_match_window_size, "") - .def("parse", &Param::parse, "") - .def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "") - .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "") - .def("detail", &Param::detail, ""); - - py::class_(m, "Result") - .def(py::init<>()) - .def_readwrite("token", &Result::token) - .def_readwrite("mask", &Result::mask) - .def("truncate", &Result::truncate); -} diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 572a76f6a140..dbb91f555ecf 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -284,15 +284,15 @@ def verify( or sampling_info.logit_bias is not None ): # This is a relaxed version of penalties for speculative decoding. - linear_penalty = torch.zeros( - (bs, logits_output.next_token_logits.shape[1]), - dtype=torch.float32, - device=batch.device, - ) - sampling_info.apply_logits_bias(linear_penalty) - logits_output.next_token_logits.add_( - torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) + sampling_info.penalizer_orchestrator.apply( + logits_output.next_token_logits, repeat=self.draft_token_num ) + if sampling_info.logit_bias is not None: + logits_output.next_token_logits.add_( + torch.repeat_interleave( + sampling_info.logit_bias, self.draft_token_num, dim=0 + ) + ) # Apply grammar mask if vocab_mask is not None: diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index 1224fbd33e36..1f348f9b1e08 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -234,14 +234,12 @@ def prepare_for_v2_verify( # Set mamba_track_indices for mamba prefix-cache state tracking if 
get_global_server_args().enable_mamba_extra_buffer(): - batch.mamba_track_indices = torch.tensor( + batch.mamba_track_indices = torch.stack( [ req.mamba_ping_pong_track_buffer[req.mamba_next_track_idx] for req in batch.reqs - ], - dtype=torch.int64, - device=device, - ) + ] + ).to(torch.int64) batch.mamba_track_mask = None batch.mamba_track_seqlens = None diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 248e7015c8ae..0ed93e198881 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -12,9 +12,9 @@ from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import ( EAGLEDraftNpuGraphRunner, ) -from sglang.srt.layers.attention.triton_backend import TritonMultiStepDraftBackend +from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.attention.trtllm_mla_backend import ( - TRTLLMMLAMultiStepDraftBackend, + TRTLLMMLABackend, ) from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -294,8 +294,8 @@ def init_cuda_graphs(self): ) supports_cuda_draft_extend_graph = _is_cuda and ( - isinstance(self.draft_attn_backend, TritonMultiStepDraftBackend) - or isinstance(self.draft_attn_backend, TRTLLMMLAMultiStepDraftBackend) + isinstance(self.draft_extend_attn_backend, TritonAttnBackend) + or isinstance(self.draft_extend_attn_backend, TRTLLMMLABackend) ) # Capture extend # TODO: support draft extend cuda graph for more attention backends diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py index c18cf79658d7..7aafe9870769 100644 --- a/python/sglang/srt/speculative/ngram_info.py +++ b/python/sglang/srt/speculative/ngram_info.py @@ -398,17 +398,20 @@ def verify( ) # Apply penalty - if sampling_info.penalizer_orchestrator.is_required: + if ( + 
sampling_info.penalizer_orchestrator.is_required + or sampling_info.logit_bias is not None + ): # This is a relaxed version of penalties for speculative decoding. - linear_penalty = torch.zeros( - (bs, logits_output.next_token_logits.shape[1]), - dtype=torch.float32, - device=self.device, - ) - sampling_info.apply_logits_bias(linear_penalty) - logits_output.next_token_logits.add_( - torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) + sampling_info.penalizer_orchestrator.apply( + logits_output.next_token_logits, repeat=self.draft_token_num ) + if sampling_info.logit_bias is not None: + logits_output.next_token_logits.add_( + torch.repeat_interleave( + sampling_info.logit_bias, self.draft_token_num, dim=0 + ) + ) # Apply grammar mask if vocab_mask is not None: diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index 04a38cefbb83..8c108915c939 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -41,9 +41,6 @@ def __init__( self.page_size = server_args.page_size self.draft_token_num: int = server_args.speculative_num_draft_tokens self.max_trie_depth: int = server_args.speculative_ngram_max_trie_depth - self.max_match_window_size: int = ( - server_args.speculative_ngram_max_match_window_size - ) self.max_batch_size = target_worker.max_running_requests self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda" @@ -51,8 +48,6 @@ def __init__( self._init_preallocated_tensors() self.ngram_corpus = NgramCorpus( - min_match_window_size=server_args.speculative_ngram_min_match_window_size, - max_match_window_size=server_args.speculative_ngram_max_match_window_size, min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth, max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth, match_type=server_args.speculative_ngram_match_type, @@ -131,7 +126,7 @@ def _prepare_draft_tokens( batch_tokens = [] for req in batch.reqs: check_token = 
self._efficient_concat_last_n( - req.origin_input_ids, req.output_ids, self.max_match_window_size + req.origin_input_ids, req.output_ids, self.max_trie_depth ) batch_tokens.append(check_token) req_drafts, mask = self.ngram_corpus.batch_get(batch_tokens) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index a325822ed8fe..7d87210369ad 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -20,6 +20,7 @@ import builtins import ctypes import functools +import gc import importlib import inspect import io @@ -86,6 +87,7 @@ from torch import nn from torch.library import Library from torch.utils._contextlib import _DecoratorContextManager +from torchvision.io import decode_jpeg from typing_extensions import Literal from sglang.srt.environ import envs @@ -125,7 +127,7 @@ def is_hip() -> bool: @lru_cache(maxsize=1) def is_cuda(): - return torch.cuda.is_available() and torch.version.cuda + return torch.cuda.is_available() and torch.version.cuda is not None @lru_cache(maxsize=1) @@ -363,7 +365,7 @@ def get_int_env_var(name: str, default: int = 0) -> int: def support_triton(backend: str) -> bool: - return backend not in ["torch_native", "intel_amx"] + return backend not in ["torch_native", "intel_amx", "ascend"] _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var( @@ -763,64 +765,109 @@ class ImageData: max_dynamic_patch: Optional[int] = None +image_extension_names = (".png", ".jpg", ".jpeg", ".webp", ".gif") + + +def is_jpeg_with_cuda(image_bytes: bytes = b"", gpu_image_decode: bool = True) -> bool: + """ + Check three conditions: + 1. whether CUDA is available. + 2. whether input is recognized as JPEG. + 3. whether GPU image decode is enabled (some models such as CPM forcibly disable this). 
+ """ + if not is_cuda() or not gpu_image_decode: + return False + if image_bytes != b"": + return image_bytes.startswith(b"\xff\xd8") and image_bytes.endswith(b"\xff\xd9") + return False + + +def _load_image( + image_bytes: bytes = b"", + image_file: str = "", + gpu_image_decode: bool = True, +) -> Union[torch.Tensor, Image.Image]: + """ + Try to decode JPEG with nvJPEG on GPU and return a torch device tensor, + otherwise fallback to decode with PIL on CPU and return a PIL Image. + Keep the fallback path since nvJPEG may fail on some JPEG images that are not strictly compliant with the standard, while PIL is more tolerant. + """ + if image_file != "": + image_bytes = get_image_bytes(image_file) + if is_jpeg_with_cuda(image_bytes, gpu_image_decode): + try: + encoded_image = torch.frombuffer(image_bytes, dtype=torch.uint8) + image_tensor = decode_jpeg(encoded_image, device="cuda") + return image_tensor + except Exception as e: + logger.warning( + f"Failed to decode JPEG on GPU, falling back to CPU. Error: {e}" + ) + return Image.open(BytesIO(image_bytes)) + + def load_image( image_file: Union[Image.Image, str, ImageData, bytes], -) -> tuple[Image.Image, tuple[int, int]]: + gpu_image_decode: bool = True, +) -> tuple[Union[torch.Tensor, Image.Image], Optional[tuple[int, int]]]: + """ + Load image from multiple input formats, including: + ImageData, PIL Image, bytes, URL, file path, or base64 string. 
+ """ if isinstance(image_file, ImageData): image_file = image_file.url - image = image_size = None + image = None + image_size: Optional[tuple[int, int]] = None if isinstance(image_file, Image.Image): image = image_file image_size = (image.width, image.height) elif isinstance(image_file, bytes): - image = Image.open(BytesIO(image_file)) - elif image_file.startswith("http://") or image_file.startswith("https://"): - timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) - response = requests.get(image_file, stream=True, timeout=timeout) - try: - response.raise_for_status() - image = Image.open(response.raw) - image.load() # Force loading to avoid issues after closing the stream - finally: - response.close() - elif image_file.startswith("file://"): - image_file = unquote(urlparse(image_file).path) - image = Image.open(image_file) - elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): - image = Image.open(image_file) - elif image_file.startswith("data:"): - image_file = image_file.split(",")[1] - image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True))) - elif isinstance(image_file, str): - image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True))) + image = _load_image(image_bytes=image_file, gpu_image_decode=gpu_image_decode) + elif isinstance(image_file, str) and image_file.startswith(("http://", "https://")): + image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode) + elif isinstance(image_file, str) and image_file.startswith("file://"): + image = _load_image( + image_file=unquote(urlparse(image_file).path), + gpu_image_decode=gpu_image_decode, + ) + elif isinstance(image_file, str) and image_file.lower().endswith( + image_extension_names + ): + image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode) + elif isinstance(image_file, str) and image_file.startswith("data:"): + image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode) + elif isinstance( + 
image_file, str + ): # Other formats, try to decode as base64 by default + image = _load_image(image_file=image_file, gpu_image_decode=gpu_image_decode) else: raise ValueError(f"Invalid image: {image_file}") - return image, image_size -def get_image_bytes(image_file: Union[str, bytes]): +def get_image_bytes(image_file: Union[str, bytes]) -> bytes: + """Normalize various image inputs into raw bytes.""" if isinstance(image_file, bytes): return image_file - elif image_file.startswith("http://") or image_file.startswith("https://"): + if image_file.startswith(("http://", "https://")): timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) response = requests.get(image_file, timeout=timeout) - return response.content - elif image_file.startswith("file://"): - image_file = unquote(urlparse(image_file).path) - with open(image_file, "rb") as f: - return f.read() - elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): + try: + response.raise_for_status() + result = response.content + finally: + response.close() + return result + if image_file.startswith(("file://", "/")): with open(image_file, "rb") as f: return f.read() - elif image_file.startswith("data:"): - image_file = image_file.split(",")[1] + if isinstance(image_file, str) and image_file.startswith("data:"): + _, encoded = image_file.split(",", 1) + return pybase64.b64decode(encoded, validate=True) + if isinstance(image_file, str): return pybase64.b64decode(image_file, validate=True) - elif isinstance(image_file, str): - return pybase64.b64decode(image_file, validate=True) - else: - raise NotImplementedError(f"Invalid image: {image_file}") + raise NotImplementedError(f"Invalid image: {image_file}") def _normalize_video_input( @@ -976,7 +1023,7 @@ def check_pkg_version_at_least(pkg: str, min_version: str) -> bool: Args: pkg: Package name (distribution name, e.g., "flashinfer-python") - min_version: Minimum version required (e.g., "0.6.6") + min_version: Minimum version required (e.g., "0.6.7") Returns: 
True if package is installed and version >= min_version, False otherwise @@ -2379,6 +2426,8 @@ def launch_dummy_health_check_server(host, port, enable_metrics): import uvicorn from fastapi import FastAPI, Response + from sglang.srt.utils.network import NetworkAddress + app = FastAPI() @app.get("/ping") @@ -2421,14 +2470,16 @@ def run_server(): logger.error(f"Dummy health check server failed to start: {e}") raise finally: - logger.info(f"Dummy health check server stopped at {host}:{port}") + logger.info( + f"Dummy health check server stopped at {NetworkAddress(host, port).to_host_port_str()}" + ) thread = threading.Thread( target=run_server, daemon=True, name="health-check-server" ) thread.start() logger.info( - f"Dummy health check server started in background thread at {host}:{port}" + f"Dummy health check server started in background thread at {NetworkAddress(host, port).to_host_port_str()}" ) @@ -2945,8 +2996,6 @@ def gc_callback(phase, info): def freeze_gc(context: str): - import gc - g0_before, g1_before, g2_before = gc_object_counts() gc.freeze() g0_after, g1_after, g2_after = gc_object_counts() @@ -2961,8 +3010,6 @@ def freeze_gc(context: str): def configure_gc_logger(): logger.info("Enable GC Logger") - import gc - gc_start_time = {} def gc_callback(phase, info): @@ -3352,6 +3399,41 @@ def is_triton_kernels_available() -> bool: return importlib.util.find_spec("triton_kernels") is not None +@lru_cache(maxsize=1) +def get_nvidia_driver_version() -> tuple: + """Return the NVIDIA driver version as a tuple of ints, e.g. (595, 58, 3). + Returns (0,) on failure.""" + version_str = get_nvidia_driver_version_str() + if version_str is None: + return (0,) + try: + return tuple(int(x) for x in version_str.split(".")) + except ValueError: + return (0,) + + +@lru_cache(maxsize=1) +def get_nvidia_driver_version_str() -> str: + """Return the NVIDIA driver version string, e.g. '595.58.03'. 
+ Returns None on failure.""" + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=driver_version", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + check=True, + timeout=10, + ) + version_str = result.stdout.strip().split("\n")[0].strip() + return version_str if version_str else None + except (subprocess.CalledProcessError, FileNotFoundError, ValueError): + return None + + def check_cuda_result(raw_output): import cuda.bindings.runtime as cuda_rt @@ -3408,30 +3490,6 @@ def get_device_sm_nvidia_smi(): return (0, 0) # Default/fallback value -def get_libnuma(): - libnuma = None - - for libnuma_so in ["libnuma.so", "libnuma.so.1"]: - try: - libnuma = ctypes.CDLL(libnuma_so) - except OSError as e: - logger.error(f"{e}") - libnuma = None - if libnuma is not None: - break - return libnuma - - -def numa_bind_to_node(node: int): - libnuma = get_libnuma() - - if libnuma is None or libnuma.numa_available() < 0: - logger.error("numa not available on this system, skip bind action") - else: - libnuma.numa_run_on_node(ctypes.c_int(node)) - libnuma.numa_set_preferred(ctypes.c_int(node)) - - def json_list_type(value): try: return orjson.loads(value) @@ -3729,140 +3787,3 @@ def get_or_create_event_loop(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) return loop - - -def get_numa_node_count() -> int: - """ - Get the number of NUMA nodes available on the system. - Must be called after is_numa_available() is True. - Returns: - int: The number of NUMA nodes. - """ - libnuma = get_libnuma() - return libnuma.numa_max_node() + 1 - - -def is_numa_available() -> bool: - try: - libnuma = get_libnuma() - return libnuma.numa_available() >= 0 - except Exception: - return False - - -def get_system_nvgpu_count() -> int: - """ - Get the total number of GPUs in the system (not affected by CUDA_VISIBLE_DEVICES). - - Returns: - int: The total number of physical GPUs. 
- """ - result = subprocess.run( - ["nvidia-smi", "--list-gpus"], - capture_output=True, - text=True, - check=True, - ) - gpu_lines = [ - line - for line in result.stdout.strip().split("\n") - if line.strip().startswith("GPU") - ] - return len(gpu_lines) - - -@lru_cache(maxsize=1) -def get_device_numa_node_cuda(gpu_id: int = 0) -> int: - """ - Retrieve the NUMA node ID of the CPU socket closest to the gpu_id. - - First tries to query nvidia-smi topology. If it returns a single NUMA ID, uses that directly. - If it returns multiple NUMA IDs (comma/dash separated), falls back to distributing GPUs - evenly across NUMA nodes based on GPU ID intervals. - - For example, with 8 GPUs and 2 NUMA nodes: GPUs 0-3 -> node 0, GPUs 4-7 -> node 1. - - Returns: - int: The NUMA node ID (e.g., 0, 1). - - Raises: - RuntimeError: If device information cannot be retrieved. - """ - - physical_device_id = get_physical_device_id(gpu_id) - - # Query NUMA topology from nvidia-smi - result = subprocess.run( - ["nvidia-smi", "topo", "-C", "-i", str(physical_device_id)], - capture_output=True, - text=True, - check=True, - ) - - output_line = result.stdout.strip() - prefix = "NUMA IDs of closest CPU:" - - if output_line.startswith(prefix): - numa_id_str = output_line[len(prefix) :].strip() - if numa_id_str.isdigit(): - return int(numa_id_str) - - # Fall back: distribute GPUs evenly across NUMA nodes - numa_count = get_numa_node_count() - gpu_count = get_system_nvgpu_count() - - if gpu_count >= numa_count: - gpus_per_numa = gpu_count // numa_count # >= 1 - numa_node = physical_device_id // gpus_per_numa # 0 ~ numa_count - 1 - else: - logger.warning( - f"GPU count {gpu_count} is less than NUMA count {numa_count}. Using first NUMA node." 
- ) - numa_node = 0 - - return numa_node - - -def get_numa_node(gpu_id): - numa_node = None - try: - device = get_device() - if device == "cuda": - numa_node = get_device_numa_node_cuda(gpu_id) - else: - logger.info(f"Now only supports NVIDIA devices") - except Exception as e: - logger.error(f"Error: {e}") - - return numa_node - - -@lru_cache(maxsize=1) -def get_current_device_numa_node_cuda() -> int: - """ - Retrieve the NUMA node ID of the CPU socket closest to the currently active CUDA device. - """ - - logical_device_id = torch.cuda.current_device() - numa_node = get_device_numa_node_cuda(logical_device_id) - - return numa_node - - -def nvgpu_available() -> bool: - if not torch.cuda.is_available(): - return False - if torch.version.cuda is None: - return False - return True - - -def bind_to_closest_numa_node_cuda(): - """ - Bind the current process to the NUMA node closest to the active CUDA device. - - Uses `numa` library calls via ctypes to set the CPU affinity of the process. - """ - if is_numa_available() and nvgpu_available(): - node_id = get_current_device_numa_node_cuda() - numa_bind_to_node(node_id) diff --git a/python/sglang/srt/utils/cuda_ipc_transport_utils.py b/python/sglang/srt/utils/cuda_ipc_transport_utils.py index 6d76242aef32..9c4abf014c5e 100644 --- a/python/sglang/srt/utils/cuda_ipc_transport_utils.py +++ b/python/sglang/srt/utils/cuda_ipc_transport_utils.py @@ -3,7 +3,7 @@ import threading import time from multiprocessing import shared_memory -from typing import Tuple +from typing import Any, Tuple import numpy as np import torch @@ -22,6 +22,49 @@ SHM_LOCK_FILE = "/tmp/shm_wr_lock.lock" +# Cache for pool-level IPC handles on the consumer side. +# Key: the pool CUDA IPC handle tuple. Value: opened UntypedStorage. 
+_pool_storage_cache: dict = {} +_pool_cache_lock = threading.Lock() + + +def _normalize_pool_cache_key(pool_handle, pool_device_index: int) -> tuple[Any, ...]: + normalized_handle = ( + pool_handle if isinstance(pool_handle, tuple) else tuple(pool_handle) + ) + return (pool_device_index, normalized_handle) + + +def _open_pooled_storage_uncached(pool_handle): + return torch.UntypedStorage._new_shared_cuda(*pool_handle) + + +def _pool_handle_cache_get_or_open(cache_key, pool_handle): + storage = _pool_storage_cache.get(cache_key) + if storage is None: + with _pool_cache_lock: + storage = _pool_storage_cache.get(cache_key) + if storage is None: + storage = _open_pooled_storage_uncached(pool_handle) + _pool_storage_cache[cache_key] = storage + return storage + + +def _pool_handle_cache_set(cache_key, storage): + with _pool_cache_lock: + _pool_storage_cache[cache_key] = storage + + +def _pool_handle_cache_invalidate(cache_key): + with _pool_cache_lock: + _pool_storage_cache.pop(cache_key, None) + + +def _pool_handle_cache_clear(): + with _pool_cache_lock: + _pool_storage_cache.clear() + + class ShmSyncBuffer: def __init__(self, byte_size: int = 4): self.buffer = shared_memory.SharedMemory(create=True, size=byte_size) @@ -80,6 +123,9 @@ def __init__(self, memory_size, recycle_interval): self.memory_pool = torch.empty( memory_size, dtype=torch.int8, device="cuda" ).contiguous() + storage = self.memory_pool.untyped_storage() + self._pool_ipc_handle = storage._share_cuda_() + self._pool_device_index = self.memory_pool.device.index self.sync_flag_list = [] @@ -181,8 +227,9 @@ def return_a_slice_tensor_with_flag(self, src_tensor: torch.Tensor): return ( available_chunk.sync_flag.meta_data, self.memory_pool[available_chunk.start : available_chunk.end], + available_chunk.start, ) - return None, None + return None, None, None def recycle_chunks(self): @@ -229,6 +276,9 @@ def __init__( data: torch.Tensor, info_data: torch.Tensor, sync_buffer_meta, + pool_ipc_handle=None, + 
pool_byte_offset: int = 0, + pool_device_index: int = 0, ): if (not isinstance(data, torch.Tensor)) or ( @@ -238,7 +288,24 @@ def __init__( f"Input 'data' must be a torch.Tensor, but got {type(data)}" ) - self.proxy_state = self.get_proxy_state(data, info_data) + if pool_ipc_handle is not None: + self.proxy_state = { + "ipc_extra": { + "pool_handle": pool_ipc_handle, + "pool_byte_offset": pool_byte_offset, + "pool_device_index": pool_device_index, + "shape": data.shape, + "dtype": data.dtype, + "stride": data.stride(), + "storage_offset": 0, + "nbytes": data.numel() * data.element_size(), + "recons_shape": info_data.shape, + "recons_dtype": info_data.dtype, + }, + "tensor_data": None, + } + else: + self.proxy_state = self.get_proxy_state(data, info_data) self.reconstruct_tensor = None self.sync_data_meta = sync_buffer_meta self.sync_buffer = None @@ -283,6 +350,62 @@ def get_proxy_state(self, data, info_data): return state + def _reconstruct_from_ipc_extra(self, ipc_extra, *, use_cache: bool): + shape = ipc_extra["shape"] + dtype = ipc_extra["dtype"] + stride = ipc_extra["stride"] + target_device = torch.device(f"cuda:{ipc_extra['pool_device_index']}") + cache_key = _normalize_pool_cache_key( + ipc_extra["pool_handle"], ipc_extra["pool_device_index"] + ) + + with torch.cuda.device(target_device): + if use_cache: + storage = _pool_handle_cache_get_or_open( + cache_key, ipc_extra["pool_handle"] + ) + storage_to_cache = None + else: + storage = _open_pooled_storage_uncached(ipc_extra["pool_handle"]) + storage_to_cache = storage + slice_storage = storage[ + ipc_extra["pool_byte_offset"] : ipc_extra["pool_byte_offset"] + + ipc_extra["nbytes"] + ] + slice_tensor = torch.empty(0, dtype=dtype, device=target_device).set_( + slice_storage, + storage_offset=ipc_extra["storage_offset"], + size=shape, + stride=stride, + ) + + return slice_tensor, target_device, cache_key, storage_to_cache + + def _copy_slice_tensor_to_target( + self, + slice_tensor: torch.Tensor, + 
rebuild_device: torch.device, + recons_shape, + recons_dtype, + ): + with torch.cuda.device(rebuild_device): + reconstructed_tensor = torch.empty( + recons_shape, dtype=recons_dtype, device=rebuild_device + ).contiguous() + reconstructed_tensor.view(torch.int8).view(-1).copy_(slice_tensor) + + open(SHM_LOCK_FILE, "a").close() + # write the shm_sync_buffer with a file lock + with open(SHM_LOCK_FILE, "w+") as f: + fcntl.flock(f, fcntl.LOCK_EX) + sync_flag = self.get_sync_flag + sync_flag += 1 + fcntl.flock(f, fcntl.LOCK_UN) + + self.close_shm() + + return reconstructed_tensor + def reconstruct_on_target_device(self, rebuild_device_idx): rebuild_device = torch.device(f"cuda:{rebuild_device_idx}") if ( @@ -293,52 +416,58 @@ def reconstruct_on_target_device(self, rebuild_device_idx): if self.proxy_state["ipc_extra"]: ipc_extra = self.proxy_state["ipc_extra"] - ( - handle, - shape, - dtype, - stride, - source_device_index, - s_offset, - recons_shape, - recons_dtype, - ) = ( - ipc_extra["handle"], - ipc_extra["shape"], - ipc_extra["dtype"], - ipc_extra["stride"], - ipc_extra["device_index"], - ipc_extra["storage_offset"], - ipc_extra["recons_shape"], - ipc_extra["recons_dtype"], + recons_shape = ipc_extra["recons_shape"] + recons_dtype = ipc_extra["recons_dtype"] + + if "pool_handle" in ipc_extra: + try: + ( + slice_tensor, + _target_device, + cache_key, + storage_to_cache, + ) = self._reconstruct_from_ipc_extra(ipc_extra, use_cache=True) + except Exception as e: + cache_key = _normalize_pool_cache_key( + ipc_extra["pool_handle"], ipc_extra["pool_device_index"] + ) + logger.info( + "Failed to deserialize from cached pooled CUDA IPC handle (%s). 
" + "Invalidating cache entry and retrying uncached.", + e, + ) + _pool_handle_cache_invalidate(cache_key) + ( + slice_tensor, + _target_device, + _cache_key, + storage_to_cache, + ) = self._reconstruct_from_ipc_extra(ipc_extra, use_cache=False) + if storage_to_cache is not None: + _pool_handle_cache_set(cache_key, storage_to_cache) + else: + # Non-pooled path: open handle directly (original behavior) + try: + storage = torch.UntypedStorage._new_shared_cuda( + *ipc_extra["handle"] + ) + target_device = torch.device(f"cuda:{ipc_extra['device_index']}") + with torch.cuda.device(target_device): + slice_tensor = torch.empty( + 0, dtype=ipc_extra["dtype"], device=target_device + ).set_( + storage, + storage_offset=ipc_extra["storage_offset"], + size=ipc_extra["shape"], + stride=ipc_extra["stride"], + ) + except Exception as e: + logger.info("Failed to deserialize from CUDA IPC handle (%s).", e) + raise + + reconstructed_tensor = self._copy_slice_tensor_to_target( + slice_tensor, rebuild_device, recons_shape, recons_dtype ) - - try: - target_device = torch.device(f"cuda:{source_device_index}") - with torch.cuda.device(target_device): - storage = torch.UntypedStorage._new_shared_cuda(*handle) - slice_tensor = torch.empty( - 0, dtype=dtype, device=target_device - ).set_(storage, storage_offset=s_offset, size=shape, stride=stride) - - reconstructed_tensor = torch.empty( - recons_shape, dtype=recons_dtype, device=rebuild_device - ).contiguous() - reconstructed_tensor.view(torch.int8).view(-1).copy_(slice_tensor) - - open(SHM_LOCK_FILE, "a").close() - # write the shm_sync_buffer with a file lock - with open(SHM_LOCK_FILE, "w+") as f: - fcntl.flock(f, fcntl.LOCK_EX) - sync_flag = self.get_sync_flag - sync_flag += 1 - fcntl.flock(f, fcntl.LOCK_UN) - - self.close_shm() - - except Exception as e: - logger.info(f"Error: Failed to deserialize from CUDA IPC handle ({e}).") - raise e elif isinstance(self.proxy_state["tensor_data"], torch.Tensor): reconstructed_tensor = 
self.proxy_state["tensor_data"].to( rebuild_device, non_blocking=True diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 318b2e18649f..b1b1631bb9ee 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -27,6 +27,7 @@ from huggingface_hub import snapshot_download from sglang.srt.utils import get_bool_env_var +from sglang.srt.utils.runai_utils import ObjectStorageModel, is_runai_obj_uri # Compatibility shim: flash-attn-4 registers a bare ``flash_attn`` namespace # that makes ``is_flash_attn_2_available()`` return True, but lacks the v2 API @@ -488,6 +489,9 @@ def get_config( kwargs["gguf_file"] = model model = Path(model).parent + if is_runai_obj_uri(model): + model = ObjectStorageModel.get_path(model) + if is_remote_url(model): # BaseConnector implements __del__() to clean up the local dir. # Since config files need to exist all the time, so we DO NOT use @@ -712,6 +716,57 @@ def filter(self, record: logging.LogRecord) -> bool: return "Calling super().encode with" not in record.getMessage() +_is_base_mistral_patched = False + +# transformers version where _patch_mistral_regex calls model_info() on every tokenizer load +_TRANSFORMERS_PATCHED_VERSION = "5.3.0" + + +def _patch_is_base_mistral_in_ci(): + """Patch transformers' _patch_mistral_regex to avoid HF API calls in CI. + + transformers defines is_base_mistral as a local function inside + _patch_mistral_regex, so it cannot be patched via module attribute. + Instead we replace the entire _patch_mistral_regex classmethod with a + version that simply returns the tokenizer unchanged. + + In CI this prevents exhausting the 3000 req/5min HF API rate limit. 
+ """ + global _is_base_mistral_patched + if _is_base_mistral_patched: + return + + from sglang.srt.environ import envs + + if not envs.SGLANG_IS_IN_CI.get(): + return + + import transformers + + if transformers.__version__ != _TRANSFORMERS_PATCHED_VERSION: + logger.warning( + "transformers version changed to %s (expected %s), " + "_patch_mistral_regex patch skipped — may need update if 429 errors recur", + transformers.__version__, + _TRANSFORMERS_PATCHED_VERSION, + ) + _is_base_mistral_patched = True # don't warn repeatedly + return + + from transformers import PreTrainedTokenizerFast + + if hasattr(PreTrainedTokenizerFast, "_patch_mistral_regex"): + + @classmethod + def _noop_patch_mistral_regex(cls, tokenizer, *args, **kwargs): + return tokenizer + + PreTrainedTokenizerFast._patch_mistral_regex = _noop_patch_mistral_regex + logger.info("CI: patched _patch_mistral_regex to skip HF API calls") + + _is_base_mistral_patched = True + + def get_tokenizer( tokenizer_name: str, *args, @@ -747,6 +802,9 @@ def get_tokenizer( kwargs["gguf_file"] = tokenizer_name tokenizer_name = Path(tokenizer_name).parent + if is_runai_obj_uri(tokenizer_name): + tokenizer_name = ObjectStorageModel.get_path(tokenizer_name) + if is_remote_url(tokenizer_name): # BaseConnector implements __del__() to clean up the local dir. 
# Since config files need to exist all the time, so we DO NOT use @@ -755,6 +813,8 @@ def get_tokenizer( client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) tokenizer_name = client.get_local_dir() + _patch_is_base_mistral_in_ci() + try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, diff --git a/python/sglang/srt/utils/numa_utils.py b/python/sglang/srt/utils/numa_utils.py index 2c934af0d567..cabba3bc499c 100644 --- a/python/sglang/srt/utils/numa_utils.py +++ b/python/sglang/srt/utils/numa_utils.py @@ -1,26 +1,30 @@ +import ctypes +import glob import logging +import math import multiprocessing import os import random +import shutil import time from contextlib import contextmanager from pathlib import Path +from typing import Optional + +import psutil from sglang.srt.environ import envs from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_numa_node +from sglang.srt.utils import is_cuda + +_is_cuda = is_cuda() logger = logging.getLogger(__name__) @contextmanager def configure_subprocess(server_args: ServerArgs, gpu_id: int): - numa_node = None - if (numa_nodes := server_args.numa_node) is not None: - numa_node = numa_nodes[gpu_id] - elif envs.SGLANG_AUTO_NUMA_BIND.get(): - numa_node = get_numa_node(gpu_id) - + numa_node = get_numa_node_if_available(server_args, gpu_id) if numa_node is not None and envs.SGLANG_NUMA_BIND_V2.get(): numactl_args = f"--cpunodebind={numa_node} --membind={numa_node}" executable, debug_str = _create_numactl_executable(numactl_args=numactl_args) @@ -58,3 +62,157 @@ def _mp_set_executable(executable: str, debug_str: str): ), f"{multiprocessing.spawn.get_executable()=}" multiprocessing.spawn.set_executable(old_executable) logger.info(f"mp.set_executable revert to {old_executable}") + + +def get_numa_node_if_available(server_args: ServerArgs, gpu_id: int) -> Optional[int]: + """ + Returns the NUMA node for the given GPU id. 
If it is not set in the server_args, it will try to query the NUMA node for the GPU. + If the NUMA node is not available, has already been configured externally, or the user lacks permission to set NUMA affinity, it will return None. + + Args: + server_args: The server arguments. + gpu_id: The GPU id. + + Returns: + The NUMA node for the given GPU id or None if it is not available. + """ + if server_args.numa_node is not None: + return server_args.numa_node[gpu_id] + if _is_numa_available(): + queried_numa_node = _query_numa_node_for_gpu(gpu_id) + if len(queried_numa_node) == 0: + return None + if len(queried_numa_node) > 1: + # get_numa_node_for_gpu could return multiple nodes, we use the first one for now. + # I don't think there any hardware configs that would have more than one. + logger.warning( + f"Multiple NUMA nodes found for GPU {gpu_id}: {queried_numa_node}. Using the first one." + ) + return queried_numa_node[0] + return None + + +def get_libnuma(): + libnuma = None + + for libnuma_so in ["libnuma.so", "libnuma.so.1"]: + try: + libnuma = ctypes.CDLL(libnuma_so) + except OSError as e: + logger.debug(f"{e}") + libnuma = None + if libnuma is not None: + break + return libnuma + + +def numa_bind_to_node(node: int): + libnuma = get_libnuma() + + if libnuma is None or libnuma.numa_available() < 0: + logger.warning("numa not available on this system, skip bind action") + else: + libnuma.numa_run_on_node(ctypes.c_int(node)) + libnuma.numa_set_preferred(ctypes.c_int(node)) + + +def _can_set_mempolicy() -> bool: + """Check if the process has permission to use NUMA memory policy syscalls.""" + try: + libnuma = get_libnuma() + if libnuma is None or libnuma.numa_available() < 0: + return False + mode = ctypes.c_int() + ret = libnuma.get_mempolicy( + ctypes.byref(mode), None, ctypes.c_ulong(0), None, ctypes.c_ulong(0) + ) + return ret == 0 + except Exception: + return False + + +def _is_numa_available() -> bool: + """ + Check if NUMA is available and not already 
configured externally. + """ + if not _is_cuda: + return False + + # Check if this is a numa system. + if not os.path.isdir("/sys/devices/system/node/node1"): + return False + + # Check if affinity is already constrained + pid = os.getpid() + process = psutil.Process(pid) + cpu_affinity = process.cpu_affinity() + all_cpus = list(range(psutil.cpu_count())) + constrained_affinity = cpu_affinity != all_cpus + if constrained_affinity: + logger.warning( + "NUMA affinity is already constrained for process, skipping NUMA node configuration for GPU. Remove your constraints to allow automatic configuration." + ) + return False + + if not shutil.which("numactl") and envs.SGLANG_NUMA_BIND_V2.get(): + logger.debug( + "numactl command not found, skipping NUMA node configuration for GPU. Install numactl (e.g., apt-get install numactl) to enable automatic NUMA binding." + ) + return False + + if not _can_set_mempolicy(): + logger.warning( + "User lacks permission to set NUMA affinity, skipping NUMA node configuration for GPU. If using docker, try adding --cap-add SYS_NICE to your docker run command." + ) + return False + + return True + + +def _query_numa_node_for_gpu(device_id: int): + """ + Get the NUMA node affinity list for a GPU device. + + Args: + device_id: GPU device index. + Returns: + List of NUMA node IDs that have affinity with the device. 
+ """ + try: + import pynvml + except ModuleNotFoundError: + logger.warning("pynvml not installed, skipping NUMA node configuration for GPU") + return [] + + try: + pynvml.nvmlInit() + + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + numa_node_count = len(glob.glob("/sys/devices/system/node/node[0-9]*")) + + c_ulong_bits = ctypes.sizeof(ctypes.c_ulong) * 8 + node_set_size = max(1, math.ceil(numa_node_count / c_ulong_bits)) + node_set = pynvml.nvmlDeviceGetMemoryAffinity( + handle, + node_set_size, + pynvml.NVML_AFFINITY_SCOPE_NODE, + ) + + # Decode the bitmask into a list of NUMA node IDs + numa_nodes = [] + for node_id in range(numa_node_count): + mask_array_index = node_id // c_ulong_bits + mask_bit_index = node_id % c_ulong_bits + if node_set[mask_array_index] & (1 << mask_bit_index): + numa_nodes.append(node_id) + return numa_nodes + except pynvml.NVMLError as e: + logger.warning( + f"NVML error querying memory affinity for GPU {device_id}: {e}, skipping NUMA node configuration for GPU" + ) + return [] + finally: + try: + pynvml.nvmlShutdown() + except Exception: + pass # Ignore shutdown errors diff --git a/python/sglang/srt/utils/request_logger.py b/python/sglang/srt/utils/request_logger.py index e20a19a91625..4ce4585ad072 100644 --- a/python/sglang/srt/utils/request_logger.py +++ b/python/sglang/srt/utils/request_logger.py @@ -162,7 +162,6 @@ def log_finished_request( self, obj: Union["GenerateReqInput", "EmbeddingReqInput"], out: Any, - is_multimodal_gen: bool = False, request: Optional["fastapi.Request"] = None, ) -> None: if not self.log_requests: @@ -181,20 +180,15 @@ def log_finished_request( } if headers: log_data["headers"] = headers - if not is_multimodal_gen: - log_data["out"] = _transform_data_for_logging( - out, max_length, out_skip_names - ) + log_data["out"] = _transform_data_for_logging( + out, max_length, out_skip_names + ) log_json(self.targets, "request.finished", log_data) else: obj_str = _dataclass_to_string_truncated( obj, 
max_length, skip_names=skip_names ) - out_str = ( - "" - if is_multimodal_gen - else f", out={_dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}" - ) + out_str = f", out={_dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}" headers_str = f", headers={headers}" if headers else "" self._log(f"Finish: obj={obj_str}{headers_str}{out_str}") diff --git a/python/sglang/srt/utils/runai_utils.py b/python/sglang/srt/utils/runai_utils.py new file mode 100644 index 000000000000..0424a6371bde --- /dev/null +++ b/python/sglang/srt/utils/runai_utils.py @@ -0,0 +1,134 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/runai_utils.py + +import hashlib +import logging +import os +from pathlib import Path + +logger = logging.getLogger(__name__) + +SUPPORTED_SCHEMES = ["s3://", "gs://", "az://"] + +# Design Pattern: Single Metadata Download Before Process Launch + +# 1. Engine entrypoint (engine.py) or server arguments post init (server_args.py): +# - Downloads config/tokenizer metadata ONCE before launching subprocesses +# - This happens in the main process, avoiding multi-process coordination +# +# 2. ModelConfig/HF Utils (model_config.py, hf_transformers_utils.py): +# - Use ObjectStorageModel.get_path() to retrieve the cached local path +# - NO re-download - just path resolution +# +# 3. RunaiModelStreamerLoader (loader.py): +# - Calls list_safetensors() which operates directly on the object storage URI +# - Streams weights lazily during model loading + +# This avoids file locks, race conditions, and duplicate downloads + + +def get_cache_dir() -> str: + # Expand user path (~) to ensure absolute paths for locking + path = os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/") + return os.path.expanduser(path) + + +def list_safetensors(path: str = "") -> list[str]: + """ + List full file names from object path and filter by allow pattern. 
+ + Args: + path: The object storage path to list from. + + Returns: + list[str]: List of full object storage paths allowed by the pattern + """ + from runai_model_streamer import list_safetensors as runai_list_safetensors + + return runai_list_safetensors(path) + + +def is_runai_obj_uri(model_or_path: str | Path) -> bool: + # Cast to str to handle pathlib.Path inputs which lack string methods (like .lower) + return str(model_or_path).lower().startswith(tuple(SUPPORTED_SCHEMES)) + + +class ObjectStorageModel: + """ + Model loader that uses Runai Model Streamer to load a model. + + Supports object storage (S3, GCS) with lazy weight streaming. + + Configuration (via load_config.model_loader_extra_config): + - distributed (bool): Enable distributed streaming + - concurrency (int): Number of concurrent downloads + - memory_limit (int): Memory limit for streaming buffer + + Note: Metadata files must be pre-downloaded via + ObjectStorageModel.download_and_get_path() before instantiation. + + Attributes: + dir: The temporary created directory. + """ + + def __init__(self, url: str) -> None: + self.dir = ObjectStorageModel.get_path(url) + + from runai_model_streamer import ObjectStorageModel as RunaiObjectStorageModel + + self._runai_obj = RunaiObjectStorageModel(model_path=url, dst=self.dir) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._runai_obj.__exit__(exc_type, exc_val, exc_tb) + + def pull_files( + self, + allow_pattern: list[str] | None = None, + ignore_pattern: list[str] | None = None, + ) -> None: + """Pull files from object storage into the local cache directory. + + Args: + allow_pattern: File patterns to include (e.g. ["*.json"]). + ignore_pattern: File patterns to exclude. 
+ """ + self._runai_obj.pull_files(allow_pattern, ignore_pattern) + + @classmethod + def download_and_get_path(cls, model_path: str) -> str: + """ + Downloads the model metadata (excluding heavy weights) and returns + the local directory path. Safe for concurrent usage by multiple processes + """ + with cls(url=model_path) as downloader: + downloader.pull_files( + ignore_pattern=[ + "*.pt", + "*.safetensors", + "*.bin", + "*.tensors", + "*.pth", + ], + ) + cache_dir = downloader.dir + logger.info(f"Runai Model : {cache_dir}, metadata ready.") + return cache_dir + + @classmethod + def get_path(cls, model_path: str) -> str: + """ + Returns the local directory path. + """ + model_hash = hashlib.sha256(str(model_path).encode()).hexdigest()[:16] + base_dir = get_cache_dir() + + # Ensure base cache dir exists + os.makedirs(os.path.join(base_dir, "model_streamer"), exist_ok=True) + + return os.path.join( + base_dir, + "model_streamer", + model_hash, + ) diff --git a/python/sglang/srt/utils/watchdog.py b/python/sglang/srt/utils/watchdog.py index 651243d5d122..092ff887f1cd 100644 --- a/python/sglang/srt/utils/watchdog.py +++ b/python/sglang/srt/utils/watchdog.py @@ -1,12 +1,14 @@ from __future__ import annotations import logging +import os import signal import sys import threading import time from contextlib import contextmanager -from typing import Callable, Optional +from multiprocessing import Process +from typing import Callable, List, Optional import psutil @@ -159,3 +161,63 @@ def _watchdog_once(self): # Wait for some time so that the parent process can print the error. time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) + + +class SubprocessWatchdog: + """Monitors subprocess liveness and triggers SIGQUIT when a crash is detected. + + When a subprocess crashes (e.g., NCCL timeout causing C++ std::terminate()), + Python exception handlers never run, leaving the main process as a zombie + service. 
This watchdog polls subprocess liveness in a daemon thread and + sends SIGQUIT to trigger proper cleanup. + + See: https://github.com/sgl-project/sglang/issues/18421 + """ + + def __init__( + self, + processes: List[Process], + process_names: Optional[List[str]] = None, + interval: float = 1.0, + ): + self._processes = processes + self._names = process_names or [f"process_{i}" for i in range(len(processes))] + self._interval = interval + self._stop_event = threading.Event() + self._thread: Optional[threading.Thread] = None + + def start(self) -> None: + if self._thread is not None or not self._processes: + return + self._thread = threading.Thread( + target=self._monitor_loop, daemon=True, name="subprocess-watchdog" + ) + self._thread.start() + + def stop(self) -> None: + self._stop_event.set() + if self._thread is not None: + self._thread.join(timeout=self._interval * 2) + self._thread = None + + def _monitor_loop(self) -> None: + try: + while not self._stop_event.wait(self._interval): + if self._check_processes(): + return + except Exception as e: + logger.error(f"SubprocessWatchdog thread crashed: {e}", exc_info=True) + + def _check_processes(self) -> bool: + for proc, name in zip(self._processes, self._names): + if proc.is_alive() or proc.exitcode == 0: + continue + + logger.error( + f"Subprocess {name} (pid={proc.pid}) crashed " + f"with exit code {proc.exitcode}. " + f"Triggering SIGQUIT for cleanup..." + ) + os.kill(os.getpid(), signal.SIGQUIT) + return True + return False diff --git a/python/sglang/test/accuracy_test_runner.py b/python/sglang/test/accuracy_test_runner.py index 0cf007220701..cc780622c197 100644 --- a/python/sglang/test/accuracy_test_runner.py +++ b/python/sglang/test/accuracy_test_runner.py @@ -150,22 +150,127 @@ def _run_simple_eval( kill_process_tree(process.pid) -def _run_few_shot_eval( +# Cached uv venv for NeMo Skills (persists across variants within a process). 
+_nemo_venv_dir: Optional[str] = None +_nemo_data_prepared: set = set() + + +def _get_nemo_venv() -> Tuple[str, dict]: + """Get or create a uv venv with nemo_skills installed. + + Returns (venv_python_path, env_dict) reusable across calls. + """ + import os + import subprocess + import tempfile + + global _nemo_venv_dir + + if _nemo_venv_dir is not None: + venv_python = f"{_nemo_venv_dir}/venv/bin/python" + env = { + **dict(os.environ), + "NEMO_SKILLS_DISABLE_UNCOMMITTED_CHANGES_CHECK": "1", + "OPENAI_API_KEY": "dummy", + "VIRTUAL_ENV": f"{_nemo_venv_dir}/venv", + "PATH": f"{_nemo_venv_dir}/venv/bin:" + os.environ.get("PATH", ""), + } + return venv_python, env + + _nemo_venv_dir = tempfile.mkdtemp(prefix="nemo_skills_") + print(f"Creating NeMo Skills venv in {_nemo_venv_dir}...") + + # Create venv + result = subprocess.run( + ["uv", "venv", f"{_nemo_venv_dir}/venv", "--python", "3.12"], + capture_output=True, + text=True, + ) + if result.returncode != 0: + subprocess.run( + ["uv", "venv", f"{_nemo_venv_dir}/venv"], + capture_output=True, + text=True, + ) + + # Install nemo_skills + print("Installing nemo_skills...") + pip_result = subprocess.run( + [ + "uv", + "pip", + "install", + "--python", + f"{_nemo_venv_dir}/venv/bin/python", + "git+https://github.com/NVIDIA/NeMo-Skills.git", + ], + capture_output=True, + text=True, + timeout=300, + ) + if pip_result.returncode != 0: + raise RuntimeError(f"Failed to install nemo_skills: {pip_result.stderr[-500:]}") + + print("NeMo Skills installed successfully") + return _get_nemo_venv() + + +def _ensure_nemo_data_prepared( + venv_python: str, env: dict, dataset: str +) -> Tuple[bool, Optional[str]]: + """Prepare NeMo Skills dataset data if not already done. + + Uses the venv python so data lands inside the venv's nemo_skills package. 
+ """ + import subprocess + + if dataset in _nemo_data_prepared: + return True, None + + print(f"Preparing {dataset} data (this may take a few minutes for VLM datasets)...") + result = subprocess.run( + [venv_python, "-m", "nemo_skills.dataset.prepare", dataset], + text=True, + timeout=600, + env=env, + ) + if result.returncode != 0: + return False, f"Failed to prepare {dataset} data (exit {result.returncode})" + + _nemo_data_prepared.add(dataset) + return True, None + + +def _run_nemo_skills_eval( model: ModelLaunchSettings, base_url: str, - num_questions: Optional[int] = None, - num_shots: int = 8, - max_tokens: int = 512, + dataset: str, + max_tokens: Optional[int] = None, + repeat: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, ) -> Tuple[bool, Optional[str], Optional[dict]]: - """Run evaluation using few_shot backend (few_shot_gsm8k.py). + """Run evaluation using NeMo Skills (ns eval) for benchmarks like mmmu-pro. + + Uses an isolated uv venv (shared across variants) so nemo_skills dependencies + don't interfere with the system python / sglang server. 
Returns: Tuple of (success, error_message, metrics_dict) """ - from sglang.test.few_shot_gsm8k import run_eval as run_few_shot_eval + import subprocess + import tempfile process = None try: + # Get or create the shared venv (once per process) + venv_python, env = _get_nemo_venv() + + # Prepare dataset (once per process, cached) + ok, err = _ensure_nemo_data_prepared(venv_python, env, dataset) + if not ok: + return False, err, None + process = popen_launch_server( model.model_path, base_url, @@ -174,27 +279,154 @@ def _run_few_shot_eval( env=model.env, ) - args = SimpleNamespace( - num_shots=num_shots, - data_path=None, - num_questions=num_questions or 200, - max_new_tokens=max_tokens, - parallel=128, - host="http://127.0.0.1", - port=int(base_url.split(":")[-1]), - ) - - metrics = run_few_shot_eval(args) + port = int(base_url.split(":")[-1]) + server_address = f"http://127.0.0.1:{port}/v1" + repeat_val = repeat or 1 + max_tokens_val = max_tokens or 32768 + benchmark_spec = f"{dataset}:{repeat_val}" + + # Build ns eval command using venv python + # Note: nemo_skills.pipeline.eval requires the "eval" subcommand + output_dir = tempfile.mkdtemp(prefix="ns_eval_output_") + cmd = [ + venv_python, + "-m", + "nemo_skills.pipeline.eval", + "eval", + f"--benchmarks={benchmark_spec}", + "--server_type=sglang", + f"--model={model.model_path}", + f"--server_address={server_address}", + f"--output_dir={output_dir}", + f"++inference.tokens_to_generate={max_tokens_val}", + ] - # Normalize metrics format (few_shot returns "accuracy", simple_eval returns "score") - if "accuracy" in metrics and "score" not in metrics: - metrics["score"] = metrics["accuracy"] + if temperature is not None: + cmd.append(f"++inference.temperature={temperature}") + if top_p is not None: + cmd.append(f"++inference.top_p={top_p}") + + # Add VLM-specific config + if dataset in ("mmmu-pro", "mmmu_pro"): + cmd.append("++prompt_config=vlm/mmmu-pro") + cmd.append("++max_concurrent_requests=512") + 
cmd.append("++max_samples=500") + + print(f"Running: {' '.join(cmd)}") + eval_result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=7200, + env=env, + ) - return True, None, metrics + print(eval_result.stdout[-2000:] if eval_result.stdout else "(no stdout)") + if eval_result.stderr: + print(eval_result.stderr[-1000:]) + + if eval_result.returncode != 0: + return ( + False, + f"ns eval failed (exit {eval_result.returncode}): {eval_result.stderr[-500:]}", + None, + ) + + # Parse results + summarize_result = subprocess.run( + [ + venv_python, + "-m", + "nemo_skills.pipeline.summarize_results", + f"{output_dir}/eval-results", + ], + capture_output=True, + text=True, + timeout=60, + env=env, + ) + output = summarize_result.stdout + "\n" + eval_result.stdout + print(f"Summary: {summarize_result.stdout[:1000]}") + + # Parse accuracy from output (format varies, look for common patterns) + import re + + score = None + for line in output.split("\n"): + match = re.search(r"(?:accuracy|score)[:\s]+([0-9.]+)", line, re.IGNORECASE) + if match: + score = float(match.group(1)) + + if score is None: + # Try to find it in eval-results directory + import glob + import json + + for result_file in glob.glob( + f"{output_dir}/eval-results/**/*.json", recursive=True + ): + try: + with open(result_file) as f: + data = json.load(f) + if isinstance(data, dict): + score = ( + data.get("accuracy") + or data.get("score") + or data.get("mean_score") + ) + if score is not None: + break + except (json.JSONDecodeError, KeyError): + continue + + if score is None: + # Last resort: compute accuracy directly from JSONL output + import glob + import json + + for jsonl_file in sorted( + glob.glob(f"{output_dir}/eval-results/**/*.jsonl*", recursive=True) + ): + correct = 0 + total = 0 + try: + with open(jsonl_file) as f: + for line in f: + line = line.strip() + if not line: + continue + entry = json.loads(line) + expected = entry.get("expected_answer", "") + generation = 
entry.get("generation", "") + # Extract "Answer: X" from the end of generation + answer_match = re.search( + r"Answer:\s*([A-J])", generation, re.IGNORECASE + ) + if answer_match: + predicted = answer_match.group(1).upper() + if predicted == expected.upper(): + correct += 1 + total += 1 + except (json.JSONDecodeError, KeyError, OSError): + continue + if total > 0: + score = correct / total + print( + f"Computed accuracy from {jsonl_file}: " + f"{correct}/{total} = {score:.4f}" + ) + break + + if score is None: + return False, "Could not parse accuracy from ns eval output", None + + return True, None, {"score": score} + + except subprocess.TimeoutExpired: + return False, "NeMo Skills eval timed out", None except Exception as e: - return False, f"Few-shot evaluation exception: {str(e)}", None - + return False, f"NeMo Skills eval exception: {str(e)}", None finally: if process: kill_process_tree(process.pid) @@ -224,18 +456,17 @@ def run_accuracy_test( print(f"{'='*60}\n") # Run evaluation based on dataset type - # Use few_shot_eval for gsm8k by default for backward compatibility. - # Use simple_eval when any extended params are set that few_shot_eval doesn't support. - has_extended_params = any( - getattr(params, field) is not None - for field in ("thinking_mode", "temperature", "top_p", "top_k", "repeat") - ) - if params.dataset == "gsm8k" and not has_extended_params: - success, error, metrics = _run_few_shot_eval( + # - NeMo Skills: mmmu-pro (and other VLM evals needing ns eval) + # - simple_eval: everything else (gsm8k, gpqa, mmlu, mmmu, etc.) 
+ if params.dataset in ("mmmu-pro", "mmmu_pro"): + success, error, metrics = _run_nemo_skills_eval( model=model, base_url=base_url, - num_questions=params.num_examples, - max_tokens=params.max_tokens or 512, + dataset="mmmu-pro", + max_tokens=params.max_tokens, + repeat=params.repeat or 1, + temperature=params.temperature, + top_p=params.top_p, ) else: success, error, metrics = _run_simple_eval( diff --git a/python/sglang/test/ascend/test_ascend_utils.py b/python/sglang/test/ascend/test_ascend_utils.py index 43a078287793..681299a30da1 100644 --- a/python/sglang/test/ascend/test_ascend_utils.py +++ b/python/sglang/test/ascend/test_ascend_utils.py @@ -117,9 +117,18 @@ QWEN3_235B_A22B_W8A8_WEIGHTS_PATH = os.path.join( MODEL_WEIGHTS_DIR, "vllm-ascend/Qwen3-235B-A22B-W8A8" ) +QWEN3_30B_A3B_GPTQ_2507_INT4_WEIGHTS_PATH = os.path.join( + MODEL_WEIGHTS_DIR, "Qwen/Qwen3-30B-A3B-GPTQ-Int4" +) +QWEN3_30B_A3B_INSTRUCT_2507_INT4_AUTOROUND_WEIGHTS_PATH = os.path.join( + MODEL_WEIGHTS_DIR, "Intel/Qwen3-30B-A3B-Instruct-2507-int4-AutoRound" +) QWEN3_30B_A3B_INSTRUCT_2507_WEIGHTS_PATH = os.path.join( MODEL_WEIGHTS_DIR, "Qwen/Qwen3-30B-A3B-Instruct-2507" ) +QWEN3_8B_INT4_AUTOROUND_WEIGHTS_PATH = os.path.join( + MODEL_WEIGHTS_DIR, "Intel/Qwen3-8B-int4-AutoRound" +) QWEN3_8B_WEIGHTS_PATH = os.path.join(MODEL_WEIGHTS_DIR, "Qwen/Qwen3-8B") QWEN3_8B_EAGLE3_WEIGHTS_PATH = os.path.join(MODEL_WEIGHTS_DIR, "Qwen/Qwen3-8B_eagle3") QWEN3_32B_WEIGHTS_PATH = os.path.join(MODEL_WEIGHTS_DIR, "Qwen/Qwen3-32B") @@ -133,15 +142,6 @@ QWEN3_32B_W8A8_MINDIE_WEIGHTS_PATH = os.path.join( MODEL_WEIGHTS_DIR, "aleoyang/Qwen3-32B-w8a8-MindIE" ) -QWEN3_235B_A22B_W8A8_WEIGHTS_PATH = os.path.join( - MODEL_WEIGHTS_DIR, "vllm-ascend/Qwen3-235B-A22B-W8A8" -) -QWEN3_CODER_480B_A35B_INSTRUCT_W8A8_QUAROT_WEIGHTS_PATH = os.path.join( - MODEL_WEIGHTS_DIR, "Qwen3-Coder-480B-A35B-Instruct-w8a8-QuaRot" -) -QWEN3_NEXT_80B_A3B_INSTRUCT_WEIGHTS_PATH = os.path.join( - MODEL_WEIGHTS_DIR, "Qwen/Qwen3-Next-80B-A3B-Instruct" -) 
QWQ_32B_W8A8_WEIGHTS_PATH = os.path.join(MODEL_WEIGHTS_DIR, "vllm-ascend/QWQ-32B-W8A8") SMOLLM_1_7B_WEIGHTS_PATH = os.path.join(MODEL_WEIGHTS_DIR, "HuggingFaceTB/SmolLM-1.7B") STABLELM_2_1_6B_WEIGHTS_PATH = os.path.join( diff --git a/python/sglang/test/bench_one_batch_server_internal.py b/python/sglang/test/bench_one_batch_server_internal.py index 4585340da084..39e7ea4376ad 100644 --- a/python/sglang/test/bench_one_batch_server_internal.py +++ b/python/sglang/test/bench_one_batch_server_internal.py @@ -609,7 +609,7 @@ def run_one_case( last_gen_throughput = -1 acc_length = -1 else: - response = requests.get(url + "/get_server_info", timeout=DEFAULT_TIMEOUT) + response = requests.get(url + "/server_info", timeout=DEFAULT_TIMEOUT) response.raise_for_status() server_info = response.json() internal_state = server_info.get("internal_states", [{}]) @@ -793,7 +793,7 @@ def run_benchmark_internal( skip_max_running_requests_threshold = float("inf") else: model_name = None - response = requests.get(base_url + "/get_server_info", timeout=DEFAULT_TIMEOUT) + response = requests.get(base_url + "/server_info", timeout=DEFAULT_TIMEOUT) response.raise_for_status() server_info = response.json() if "tokenizer_path" in server_info: diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py index 5d3992bf6bb4..e3631419b18d 100644 --- a/python/sglang/test/few_shot_gsm8k.py +++ b/python/sglang/test/few_shot_gsm8k.py @@ -1,6 +1,11 @@ """ Run few-shot GSM-8K evaluation. +.. deprecated:: + This module is deprecated. Use ``sglang.test.run_eval`` with + ``eval_name="gsm8k"`` instead, which routes through the unified + Chat API evaluation framework with dump_metric support. + Usage: python3 -m sglang.test.few_shot_gsm8k --num-questions 200 """ @@ -9,6 +14,7 @@ import ast import re import time +import warnings import numpy as np @@ -50,6 +56,12 @@ def get_answer_value(answer_str): def run_eval(args): + warnings.warn( + "sglang.test.few_shot_gsm8k is deprecated. 
" + "Use sglang.test.run_eval with eval_name='gsm8k' instead.", + DeprecationWarning, + stacklevel=2, + ) # Select backend set_default_backend(RuntimeEndpoint(normalize_base_url(args.host, args.port))) diff --git a/python/sglang/test/few_shot_gsm8k_engine.py b/python/sglang/test/few_shot_gsm8k_engine.py index 13a30be1c9ee..d06e15d35f78 100644 --- a/python/sglang/test/few_shot_gsm8k_engine.py +++ b/python/sglang/test/few_shot_gsm8k_engine.py @@ -1,8 +1,16 @@ +""" +.. deprecated:: + This module is deprecated. Use ``sglang.test.run_eval`` with + ``eval_name="gsm8k"`` instead, which routes through the unified + Chat API evaluation framework with dump_metric support. +""" + import argparse import ast import asyncio import re import time +import warnings from typing import Optional import numpy as np @@ -49,6 +57,12 @@ async def concurrent_generate(engine, prompts, sampling_param): def run_eval(args): + warnings.warn( + "sglang.test.few_shot_gsm8k_engine is deprecated. " + "Use sglang.test.run_eval with eval_name='gsm8k' instead.", + DeprecationWarning, + stacklevel=2, + ) # Select backend engine = sgl.Engine(model_path=args.model_path, log_level="error") diff --git a/python/sglang/test/kits/cache_hit_kit.py b/python/sglang/test/kits/cache_hit_kit.py index 81895eff07c7..5e1c9172c29e 100644 --- a/python/sglang/test/kits/cache_hit_kit.py +++ b/python/sglang/test/kits/cache_hit_kit.py @@ -221,7 +221,7 @@ async def _send_one(payload): def _get_page_size(base_url: str) -> int: """Query server for page_size used by radix cache.""" try: - resp = requests.get(f"{base_url}/get_server_info", timeout=10) + resp = requests.get(f"{base_url}/server_info", timeout=10) resp.raise_for_status() info = resp.json() return info.get("page_size", 1) diff --git a/python/sglang/test/kits/eval_accuracy_kit.py b/python/sglang/test/kits/eval_accuracy_kit.py index cab5d48601eb..25bf58151c9f 100644 --- a/python/sglang/test/kits/eval_accuracy_kit.py +++ b/python/sglang/test/kits/eval_accuracy_kit.py 
@@ -3,7 +3,6 @@ import requests -from sglang.test.few_shot_gsm8k import run_eval as run_eval_gsm8k from sglang.test.run_eval import run_eval from sglang.test.test_utils import is_in_amd_ci, is_in_ci, write_github_step_summary @@ -19,17 +18,20 @@ def _check_accept_length(test_case, base_url, threshold): class GSM8KMixin: - """Mixin for few-shot GSM8K evaluation. + """Mixin for GSM8K evaluation via OpenAI Chat API. Required attributes on the test class: base_url: str gsm8k_accuracy_thres: float + + Optional attributes: + model: str (if not set, auto-detected from server) """ gsm8k_accuracy_thres: float = _THRESHOLD_NOT_SET gsm8k_accept_length_thres: Optional[float] = None gsm8k_num_questions: int = 200 - gsm8k_parallel: int = 128 + gsm8k_num_threads: int = 128 def test_gsm8k(self): assert ( @@ -39,17 +41,21 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=self.gsm8k_num_questions, - max_new_tokens=512, - parallel=self.gsm8k_parallel, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=self.gsm8k_num_questions, + num_threads=self.gsm8k_num_threads, ) - metrics = run_eval_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - self.assertGreaterEqual(metrics["accuracy"], self.gsm8k_accuracy_thres) + + if is_in_ci(): + write_github_step_summary(f"### test_gsm8k\n{metrics['score']=:.4f}\n") + + self.assertGreaterEqual(metrics["score"], self.gsm8k_accuracy_thres) if self.gsm8k_accept_length_thres is not None: _check_accept_length(self, self.base_url, self.gsm8k_accept_length_thres) diff --git a/python/sglang/test/kl_test_utils.py b/python/sglang/test/kl_test_utils.py index 116f0ad7ee40..b3c90caaec97 100644 --- a/python/sglang/test/kl_test_utils.py +++ b/python/sglang/test/kl_test_utils.py @@ -208,7 +208,7 @@ def 
test_input_output_logprobs_match_helper( def test_input_output_logprobs_match_prefill_cache_hit_helper( base_url, ACC_THRESHOLDS, model_name, max_samples=None, max_new_tokens=8192 ): - server_info = requests.get(base_url + "/get_server_info").json() + server_info = requests.get(base_url + "/server_info").json() if server_info["disable_radix_cache"]: print("Radix cache is disabled, skipping test") return @@ -261,7 +261,7 @@ def test_input_output_logprobs_match_prefill_cache_hit_helper( def test_input_output_logprobs_match_decode_cache_hit_helper( base_url, ACC_THRESHOLDS, model_name, max_samples=None, max_new_tokens=8192 ): - server_info = requests.get(base_url + "/get_server_info").json() + server_info = requests.get(base_url + "/server_info").json() if server_info["disable_radix_cache"]: print("Radix cache is disabled, skipping test") return diff --git a/python/sglang/test/lora_utils.py b/python/sglang/test/lora_utils.py index 6a9b05190002..9de8d1d6e300 100644 --- a/python/sglang/test/lora_utils.py +++ b/python/sglang/test/lora_utils.py @@ -379,6 +379,7 @@ def run_lora_test_one_by_one( disable_radix_cache: bool = False, mem_fraction_static: float = 0.88, test_tag: str = "", + attention_backend: Optional[str] = None, ): """ Input a batch of prompts, and run lora tests one by one with several generate requests @@ -428,6 +429,7 @@ def run_lora_test_one_by_one( disable_cuda_graph=disable_cuda_graph, disable_radix_cache=disable_radix_cache, mem_fraction_static=mem_fraction_static, + attention_backend=attention_backend, ) as srt_runner: srt_outputs = srt_runner.forward( prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names @@ -439,6 +441,7 @@ def run_lora_test_one_by_one( model_type="generation", tp_size=model_case.tp_size, mem_fraction_static=mem_fraction_static, + attention_backend=attention_backend, ) as srt_runner: srt_no_lora_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens) @@ -765,8 +768,6 @@ def 
run_lora_multiple_batch_on_model_cases( else { "speculative_algorithm": "NGRAM", "speculative_num_draft_tokens": 5, - "speculative_ngram_min_match_window_size": 2, - "speculative_ngram_max_match_window_size": 15, } ) srt_runner = SRTRunner( diff --git a/python/sglang/test/nightly_utils.py b/python/sglang/test/nightly_utils.py index ac69fabb7010..2a9d01f2e8ef 100644 --- a/python/sglang/test/nightly_utils.py +++ b/python/sglang/test/nightly_utils.py @@ -324,7 +324,7 @@ def _get_spec_accept_length(self) -> Optional[float]: The average speculative decoding accept length, or None if not available. """ try: - response = requests.get(f"{self.base_url}/get_server_info", timeout=10) + response = requests.get(f"{self.base_url}/server_info", timeout=10) if response.status_code == 200: server_info = response.json() internal_states = server_info.get("internal_states", []) diff --git a/python/sglang/test/run_combined_tests.py b/python/sglang/test/run_combined_tests.py index c1ac82657d47..fa419b3ff542 100644 --- a/python/sglang/test/run_combined_tests.py +++ b/python/sglang/test/run_combined_tests.py @@ -104,6 +104,7 @@ def run_combined_tests( model_result = { "model": model.model_path, + "variant": model.variant, "perf_result": None, "accuracy_result": None, "tool_call_result": None, @@ -243,8 +244,9 @@ def run_combined_tests( failed_test_str = ", ".join(failed_tests) if failed_tests else "unknown" error_str = "; ".join(str(e) for e in r["errors"]) + variant_str = f" [{r['variant']}]" if r.get("variant") else "" failure_lines.append( - f" Model {i + 1} ({r['model']}): {failed_test_str} - {error_str}" + f" Model {i + 1} ({r['model']}{variant_str}): {failed_test_str} - {error_str}" ) failure_summary = "\n".join(failure_lines) diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index 11ba1dc5cf00..d872966e7497 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -10,6 +10,7 @@ from sglang.test.simple_eval_common import ( 
ChatCompletionSampler, + CompletionSampler, Eval, make_report, set_ulimit, @@ -19,10 +20,10 @@ def get_thinking_kwargs(args): thinking_mode = getattr(args, "thinking_mode", None) if thinking_mode in THINKING_MODE_CHOICES: - if thinking_mode == "deepseek-v3": + if thinking_mode in ["deepseek-v3", "kimi-k2"]: thinking_param = "thinking" else: - # Qwen3 + # All models other than dpsk v3/kimi_k2 thinking_param = "enable_thinking" return {thinking_param: True} return {} @@ -60,16 +61,29 @@ def run_eval_once(args, base_url: str, eval_obj: Eval) -> dict: if value is not None: extra_body[param_name] = value - sampler = ChatCompletionSampler( - model=args.model, + common_kwargs = dict( + model=getattr(args, "model", None), max_tokens=getattr(args, "max_tokens", 2048), top_p=getattr(args, "top_p", 1.0), base_url=base_url, temperature=getattr(args, "temperature", 0.0), - reasoning_effort=getattr(args, "reasoning_effort", None), - extra_body=extra_body if extra_body else None, ) + api_mode = getattr(args, "api", "chat") + if api_mode == "completion": + # Default stop tokens for completion API (matches few_shot_gsm8k behavior) + stop = getattr(args, "stop", ["Question", "Assistant:", "<|separator|>"]) + sampler = CompletionSampler( + **common_kwargs, + stop=stop, + ) + else: + sampler = ChatCompletionSampler( + **common_kwargs, + reasoning_effort=getattr(args, "reasoning_effort", None), + extra_body=extra_body if extra_body else None, + ) + # Run eval tic = time.perf_counter() result = eval_obj(sampler) @@ -134,7 +148,7 @@ def run_eval(args): categories = args.categories.split(",") if args.categories else None eval_obj = LongBenchV2Eval( - model=args.model, + model=getattr(args, "model", None), data_source=data_source, num_examples=args.num_examples, num_threads=args.num_threads, @@ -170,9 +184,16 @@ def run_eval(args): if getattr(args, "repeat", 1) == 1: result, latency, sampler = run_eval_once(args, base_url, eval_obj) metrics = result.metrics | {"score": result.score} + 
metrics["latency"] = latency print(f"Total latency: {latency:.3f} s") print(f"Score: {metrics['score']:.3f}") + # Compute output throughput from accumulated completion tokens + total_completion_tokens = sum(sampler._completion_tokens) + if total_completion_tokens > 0 and latency > 0: + metrics["output_throughput"] = total_completion_tokens / latency + print(f"Output throughput: {metrics['output_throughput']:.3f} token/s") + # Report metrics to unified collection framework dump_metric( f"{args.eval_name}_score", @@ -195,19 +216,31 @@ def run_eval(args): ] scores_repeat = [] + latencies = [] + total_completion_tokens = 0 for f in futures: result, latency, sampler = f.result() scores_repeat.append(result.score) + latencies.append(latency) + total_completion_tokens += sum(sampler._completion_tokens) mean_score = sum(scores_repeat) / len(scores_repeat) + mean_latency = sum(latencies) / len(latencies) + total_latency = sum(latencies) scores_repeat = [f"{s:.3f}" for s in scores_repeat] print("=" * 20) print(f"Repeat: {args.repeat}, mean: {mean_score:.3f}") print(f"Scores: {scores_repeat}") + print(f"Mean latency: {mean_latency:.3f} s") print("=" * 20) metrics = result.metrics | {"scores": scores_repeat} metrics = metrics | {"mean_score": mean_score} + metrics["latency"] = mean_latency + + if total_completion_tokens > 0 and total_latency > 0: + metrics["output_throughput"] = total_completion_tokens / total_latency + print(f"Output throughput: {metrics['output_throughput']:.3f} token/s") # Report metrics to unified collection framework dump_metric( @@ -239,7 +272,7 @@ def run_eval(args): return metrics -THINKING_MODE_CHOICES = ["deepseek-v3", "qwen3"] +THINKING_MODE_CHOICES = ["deepseek-v3", "qwen-3", "glm-45", "kimi-k2"] if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -266,6 +299,13 @@ def run_eval(args): "--repeat", type=int, default=1, help="repeat the evaluation n times" ) parser.add_argument("--eval-name", type=str, default="mmlu") + 
parser.add_argument( + "--api", + type=str, + default="chat", + choices=["chat", "completion"], + help="API mode: 'chat' for /v1/chat/completions, 'completion' for /v1/completions", + ) parser.add_argument("--num-examples", type=int) parser.add_argument("--num-threads", type=int, default=512) parser.add_argument("--max-tokens", type=int, default=2048) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index fa737338d7f1..adbfcaf41d72 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -15,6 +15,7 @@ import json import multiprocessing as mp import os +import queue as queue_mod from dataclasses import dataclass from typing import Any, List, Optional, Tuple, Union @@ -411,7 +412,16 @@ def forward( self.in_queue.put( (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob) ) - return self.out_queue.get() + while True: + try: + return self.out_queue.get(timeout=5) + except queue_mod.Empty: + if not self.model_proc.is_alive() and self.out_queue.empty(): + exitcode = self.model_proc.exitcode + raise RuntimeError( + f"HFRunner subprocess died with exit code {exitcode} " + f"before producing output" + ) def terminate(self): self.model_proc.terminate() @@ -564,8 +574,6 @@ def __init__( speculative_num_steps: Optional[int] = None, speculative_eagle_topk: Optional[int] = None, speculative_num_draft_tokens: Optional[int] = None, - speculative_ngram_min_match_window_size: Optional[int] = None, - speculative_ngram_max_match_window_size: Optional[int] = None, disable_overlap_schedule: bool = False, disable_custom_all_reduce: bool = False, torchao_config: Optional[str] = None, @@ -596,12 +604,7 @@ def __init__( spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens elif speculative_algorithm == "NGRAM": spec_kwargs["speculative_algorithm"] = speculative_algorithm - spec_kwargs["speculative_ngram_min_match_window_size"] = ( - speculative_ngram_min_match_window_size - ) - 
spec_kwargs["speculative_ngram_max_match_window_size"] = ( - speculative_ngram_max_match_window_size - ) + spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens self.engine = Engine( model_path=model_path, diff --git a/python/sglang/test/server_fixtures/disaggregation_fixture.py b/python/sglang/test/server_fixtures/disaggregation_fixture.py index 53baed4d5a63..eda4004e9994 100644 --- a/python/sglang/test/server_fixtures/disaggregation_fixture.py +++ b/python/sglang/test/server_fixtures/disaggregation_fixture.py @@ -32,6 +32,7 @@ def setUpClass(cls): cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + cls.base_url = cls.lb_url print( f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=} {cls.bootstrap_port=}" ) diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py index 6e9733eb7221..b9e4057fa74c 100644 --- a/python/sglang/test/simple_eval_common.py +++ b/python/sglang/test/simple_eval_common.py @@ -109,6 +109,7 @@ def __init__( self.reasoning_effort = reasoning_effort self.extra_body = extra_body self.image_format = "url" + self._completion_tokens: list[int] = [] print( f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=} {self.extra_body=}" ) @@ -151,6 +152,8 @@ def __call__(self, message_list: MessageList) -> str: reasoning_effort=self.reasoning_effort, extra_body=self.extra_body, ) + if response.usage and response.usage.completion_tokens is not None: + self._completion_tokens.append(response.usage.completion_tokens) return response.choices[0].message.content or "" # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU except openai.BadRequestError as e: @@ -169,6 +172,75 @@ def __call__(self, message_list: MessageList) -> str: return "" +class 
CompletionSampler(SamplerBase): + """ + Sample from OpenAI's completion API (non-chat). + Sends raw text prompts without chat template wrapping. + """ + + def __init__( + self, + base_url: str = None, + model: Optional[str] = None, + temperature: float = 0.0, + top_p: float = 1.0, + max_tokens: int = 2048, + stop: Optional[List[str]] = None, + ): + self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient()) + + if model is None: + model = self.client.models.list().data[0].id + + self.model = model + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + self.stop = stop + self._completion_tokens: list[int] = [] + print( + f"CompletionSampler initialized with {self.model=} {self.temperature=} {self.max_tokens=} {self.stop=}" + ) + + def _pack_message(self, role: str, content: Any): + return {"role": str(role), "content": content} + + def __call__(self, message_list: MessageList) -> str: + # Extract raw text from message list (eval objects pack prompt as a single user message) + prompt = "\n".join( + msg["content"] + for msg in message_list + if isinstance(msg.get("content"), str) + ) + trial = 0 + while trial < 6: + try: + response = self.client.completions.create( + model=self.model, + prompt=prompt, + temperature=self.temperature, + top_p=self.top_p, + max_tokens=self.max_tokens, + stop=self.stop, + ) + if response.usage and response.usage.completion_tokens is not None: + self._completion_tokens.append(response.usage.completion_tokens) + return response.choices[0].text or "" + except openai.BadRequestError as e: + print("Bad Request Error", e) + return "" + except Exception as e: + exception_backoff = 2**trial + print( + f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", + e, + ) + time.sleep(exception_backoff) + trial += 1 + print(f"All retry attempts exhausted for request. 
Returning empty response.") + return "" + + QUERY_TEMPLATE_MULTICHOICE = """ Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. @@ -450,7 +522,7 @@ def make_report_from_example_htmls(htmls: List[str]): def download_dataset(path, url): print(f"Downloading dataset {path} from {url}") try: - response = requests.get(url, stream=True) + response = requests.get(url, stream=True, timeout=30) response.raise_for_status() total_size = int(response.headers.get("content-length", 0)) diff --git a/python/sglang/test/simple_eval_gpqa.py b/python/sglang/test/simple_eval_gpqa.py index b39366ef5df8..3ad37a604432 100644 --- a/python/sglang/test/simple_eval_gpqa.py +++ b/python/sglang/test/simple_eval_gpqa.py @@ -32,7 +32,10 @@ def __init__( num_threads: int, n_repeats: int = 1, ): - df = pandas.read_csv(filename) + if "://" in filename: + df = pandas.read_csv(filename, storage_options={"timeout": 30}) + else: + df = pandas.read_csv(filename) examples = [row.to_dict() for _, row in df.iterrows()] rng = random.Random(0) if num_examples: diff --git a/python/sglang/test/simple_eval_math.py b/python/sglang/test/simple_eval_math.py index 37d4b120b930..6cb5658bbdbb 100644 --- a/python/sglang/test/simple_eval_math.py +++ b/python/sglang/test/simple_eval_math.py @@ -40,7 +40,10 @@ def __init__( num_examples: Optional[int], num_threads: int, ): - df = pandas.read_csv(filename) + if "://" in filename: + df = pandas.read_csv(filename, storage_options={"timeout": 30}) + else: + df = pandas.read_csv(filename) examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: examples = random.Random(0).sample(examples, num_examples) diff --git a/python/sglang/test/simple_eval_mgsm.py b/python/sglang/test/simple_eval_mgsm.py index 0b0b72a20f72..03098b95be50 100644 --- a/python/sglang/test/simple_eval_mgsm.py +++ 
b/python/sglang/test/simple_eval_mgsm.py @@ -115,7 +115,7 @@ def score_mgsm(target: str, prediction: str) -> bool: def get_lang_examples(lang: str) -> list[dict[str, str]]: fpath = LANG_TO_FPATH[lang] examples = [] - with urllib.request.urlopen(fpath) as f: + with urllib.request.urlopen(fpath, timeout=30) as f: for line in f.read().decode("utf-8").splitlines(): inputs, targets = line.strip().split("\t") if "." in targets: diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py index a68dbb935a21..281da9e802f0 100644 --- a/python/sglang/test/simple_eval_mmlu.py +++ b/python/sglang/test/simple_eval_mmlu.py @@ -86,7 +86,10 @@ class MMLUEval(Eval): def __init__(self, filename: str, num_examples: Optional[int], num_threads: int): - df = pandas.read_csv(filename) + if "://" in filename: + df = pandas.read_csv(filename, storage_options={"timeout": 30}) + else: + df = pandas.read_csv(filename) examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: examples = random.Random(0).sample(examples, num_examples) diff --git a/python/sglang/test/simple_eval_mmmu_vlm.py b/python/sglang/test/simple_eval_mmmu_vlm.py index f647340ea4be..e05885e9d739 100644 --- a/python/sglang/test/simple_eval_mmmu_vlm.py +++ b/python/sglang/test/simple_eval_mmmu_vlm.py @@ -148,12 +148,20 @@ def _key(idx): options = None # Build final textual prompt; include choices if MC - prompt_text = f"Question: {question}\n\n" + prompt_text = f"{question}\n" if options: letters = [chr(ord("A") + i) for i in range(len(options))] for letter, opt in zip(letters, options): - prompt_text += f"{letter}) {opt}\n" - prompt_text += "\nAnswer: " + prompt_text += f"{letter}. {opt}\n" + prompt_text += ( + "\nAnswer the following multiple-choice question. " + "The last line of your response should be of the " + "following format: 'Answer: $LETTER' (without quotes) " + "where LETTER is one of the options. " + "Think step by step before answering." 
+ ) + else: + prompt_text += "\nAnswer: " samples.append( { @@ -330,6 +338,14 @@ def _parse_multi_choice_response( response: str, all_choices: List[str], index2ans: dict ) -> str: # loosely adapted from benchmark mmmu eval + + # First, look for explicit "Answer: X" pattern (last occurrence) + answer_matches = re.findall(r"[Aa]nswer\s*:\s*\*?\*?\s*\(?([A-Z])\)?", response) + if answer_matches: + candidate = answer_matches[-1] + if candidate in all_choices: + return candidate + for char in [",", ".", "!", "?", ";", ":", "'"]: response = response.strip(char) response = " " + response + " " diff --git a/python/sglang/test/test_mm_utils.py b/python/sglang/test/test_mm_utils.py new file mode 100644 index 000000000000..bc8fc63de4d7 --- /dev/null +++ b/python/sglang/test/test_mm_utils.py @@ -0,0 +1,50 @@ +import unittest +from unittest.mock import Mock, patch + +import torch + +from sglang.srt.managers import mm_utils, schedule_batch +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) + + +def _make_proxy_with_reconstruct_result(tensor: torch.Tensor): + proxy = mm_utils.CudaIpcTensorTransportProxy.__new__( + mm_utils.CudaIpcTensorTransportProxy + ) + proxy.reconstruct_on_target_device = Mock(return_value=tensor) + return proxy + + +class TestMultimodalInputsFromDict(unittest.TestCase): + def test_materialize_proxy(self): + feature_tensor = torch.tensor([[7.0], [8.0]], dtype=torch.float32) + proxy_feature = _make_proxy_with_reconstruct_result(feature_tensor) + mm_item = MultimodalDataItem( + modality=Modality.IMAGE, + offsets=[(0, 1), (1, 2)], + feature=proxy_feature, + model_specific_data={"image_grid_thw": [[1, 1, 1], [1, 1, 1]]}, + ) + + with patch.object( + schedule_batch.torch.cuda, "is_available", return_value=True + ), patch.object( + schedule_batch.torch.cuda, "current_device", return_value=0 + ), patch.object( + schedule_batch.envs.SGLANG_MM_BUFFER_SIZE_MB, "get", return_value=0 + ): + mm_inputs = 
MultimodalInputs.from_dict({"mm_items": [mm_item]}) + + # Splitting happens at the processor layer, not in from_dict. + # from_dict just reconstructs and passes through. + self.assertEqual(len(mm_inputs.mm_items), 1) + self.assertTrue(torch.equal(mm_inputs.mm_items[0].feature, feature_tensor)) + proxy_feature.reconstruct_on_target_device.assert_called_once_with(0) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 44f4e010ace6..a360358c7acd 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -2056,6 +2056,44 @@ def _distributed_worker(rank, world_size, backend, port, func, result_queue, kwa dist.destroy_process_group() +def maybe_stub_sgl_kernel(): + """Stub sgl_kernel if it cannot be imported (e.g. no GPU). + + Must be called before any import that transitively depends on sgl_kernel. + On machines with a working sgl_kernel this is a no-op. + """ + try: + import sgl_kernel # noqa: F401 + + return + except (ImportError, OSError): + pass + + import importlib.abc + import importlib.machinery + + class _SglKernelLoader(importlib.abc.Loader): + def create_module(self, spec): + return None + + def exec_module(self, module): + from unittest.mock import MagicMock + + module.__getattr__ = lambda name: MagicMock() + + class _SglKernelFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + if fullname == "sgl_kernel" or fullname.startswith("sgl_kernel."): + return importlib.machinery.ModuleSpec( + fullname, + _SglKernelLoader(), + is_package=True, + ) + return None + + sys.meta_path.insert(0, _SglKernelFinder()) + + class CustomTestCase(unittest.TestCase): def __init_subclass__(cls, **kwargs): @@ -2068,9 +2106,11 @@ def __init_subclass__(cls, **kwargs): if getattr(setup, "_safe_setup_wrapped", False): return - def safe_setUpClass(klass, _orig=setup): + orig_func = setup.__func__ + + def safe_setUpClass(klass): 
try: - _orig.__func__(klass) + orig_func(klass) except Exception: # Best-effort cleanup; suppress teardown errors so the # original setUpClass exception propagates clearly. diff --git a/python/sglang/test/vlm_utils.py b/python/sglang/test/vlm_utils.py index 24ced63eb541..a1897be0cad9 100644 --- a/python/sglang/test/vlm_utils.py +++ b/python/sglang/test/vlm_utils.py @@ -77,7 +77,7 @@ def get_or_download_file(self, url: str) -> str: os.makedirs(cache_dir, exist_ok=True) if not os.path.exists(file_path): - response = requests.get(url) + response = requests.get(url, timeout=30) response.raise_for_status() with open(file_path, "wb") as f: diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 7fac74d65d84..7cc322a23117 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -31,6 +31,56 @@ logger = logging.getLogger(__name__) +KNOWN_NON_DIFFUSERS_DIFFUSION_MODEL_PATTERNS: dict[str, str] = { + "hunyuan3d": "Hunyuan3D2Pipeline", + "flux.2-dev-nvfp4": "Flux2NvfpPipeline", +} + + +def load_diffusion_overlay_registry_from_env() -> dict[str, dict[str, Any]]: + raw_value = os.getenv("SGLANG_DIFFUSION_MODEL_OVERLAY_REGISTRY", "").strip() + if not raw_value: + return {} + + if raw_value.startswith("{"): + payload = json.loads(raw_value) + else: + with open(os.path.expanduser(raw_value), encoding="utf-8") as f: + payload = json.load(f) + + if not isinstance(payload, dict): + return {} + + normalized: dict[str, dict[str, Any]] = {} + for source_model_id, spec in payload.items(): + if isinstance(spec, str): + normalized[source_model_id] = {"overlay_repo_id": spec} + elif isinstance(spec, dict) and spec.get("overlay_repo_id"): + normalized[source_model_id] = dict(spec) + return normalized + + +def has_diffusion_overlay_registry_match( + model_path: str, registry: dict[str, dict[str, Any]] | None = None +) -> bool: + registry = ( + load_diffusion_overlay_registry_from_env() if registry is None else registry + ) + if model_path in registry: + return True + if 
not os.path.exists(model_path): + return False + base_name = os.path.basename(os.path.normpath(model_path)) + return any(base_name == key.rsplit("/", 1)[-1] for key in registry) + + +def is_known_non_diffusers_diffusion_model(model_path: str) -> bool: + model_path_lower = model_path.lower() + return any( + pattern in model_path_lower + for pattern in KNOWN_NON_DIFFUSERS_DIFFUSION_MODEL_PATTERNS + ) + def execute_once(func): has_run = None @@ -204,7 +254,16 @@ def encode_image_base64(image_path: Union[str, bytes]): elif isinstance(image_path, bytes): return pybase64.b64encode(image_path).decode("utf-8") else: - # image_path is PIL.WebPImagePlugin.WebPImageFile + import torch + + if isinstance(image_path, torch.Tensor): + # Convert GPU-decoded image tensor (C, H, W) uint8 to PIL Image + from PIL import Image + + tensor = image_path.cpu() if image_path.device.type != "cpu" else image_path + image_path = Image.fromarray(tensor.permute(1, 2, 0).numpy()) + + # image_path is a PIL Image image = image_path buffered = BytesIO() image.save(buffered, format="PNG") diff --git a/scripts/ci/check_workflow_job_names.py b/scripts/ci/check_workflow_job_names.py new file mode 100755 index 000000000000..75e2009ea1d5 --- /dev/null +++ b/scripts/ci/check_workflow_job_names.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +"""Check that required status check job names are unique across workflows. + +Duplicate job names on the same commit allow a passing job in one workflow +to satisfy a required status check meant for a different workflow, bypassing +branch protection. + +See: https://github.com/sgl-project/sglang/pull/20208 for an example where +pr-test-npu.yml's "pr-test-finish" job (which passed) caused GitHub to treat +the required "pr-test-finish" check (from pr-test.yml, which failed) as met. +""" + +import glob +import sys +from collections import defaultdict + +import yaml + +# Job names used as required status checks in branch protection. 
+# These MUST be unique across all workflow files. +PROTECTED_JOB_NAMES = { + "pr-test-finish", + "lint", +} + + +def main() -> int: + workflows = sorted(glob.glob(".github/workflows/*.yml")) + job_to_files: dict[str, list[str]] = defaultdict(list) + + for wf in workflows: + with open(wf) as f: + data = yaml.safe_load(f) + if not data or "jobs" not in data: + continue + for job in data["jobs"]: + if job in PROTECTED_JOB_NAMES: + job_to_files[job].append(wf) + + duplicates = {job: files for job, files in job_to_files.items() if len(files) > 1} + + if not duplicates: + return 0 + + print("ERROR: Required status check job names must be unique across workflows.") + print("Duplicates allow branch protection bypass via auto-merge.\n") + for job, files in sorted(duplicates.items()): + print(f" Job '{job}' appears in:") + for f in files: + print(f" - {f}") + print() + + print("Fix: rename the job in non-primary workflows to avoid collision.") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/ci/cuda/cache_nvidia_wheels.sh b/scripts/ci/cuda/cache_nvidia_wheels.sh new file mode 100755 index 000000000000..2a0f8dbb9e65 --- /dev/null +++ b/scripts/ci/cuda/cache_nvidia_wheels.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Cache and pre-install nvidia wheels that torch pins. +# +# pypi.nvidia.com returns Cache-Control: no-store, so pip re-downloads +# cudnn (~707 MB) and nvshmem (~125 MB) on every CI run. This script +# caches the wheels locally and installs them so that the subsequent +# `pip install -e "python[dev]"` sees "Requirement already satisfied". +# +# Integrity: uses `unzip -t` to detect partial/corrupt downloads. 
+# +# Usage: source scripts/ci/cuda/cache_nvidia_wheels.sh + +NVIDIA_WHEEL_CACHE="/root/.cache/nvidia-wheels" +mkdir -p "$NVIDIA_WHEEL_CACHE" + +for url in \ + "https://pypi.nvidia.com/nvidia-cudnn-cu12/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl" \ + "https://pypi.nvidia.com/nvidia-nvshmem-cu12/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"; do + whl="$NVIDIA_WHEEL_CACHE/$(basename "$url")" + [ -f "$whl" ] && unzip -tq "$whl" &>/dev/null || curl -fL -o "$whl" "$url" +done + +pip install --no-deps "$NVIDIA_WHEEL_CACHE"/nvidia_cudnn_cu12-*.whl \ + "$NVIDIA_WHEEL_CACHE"/nvidia_nvshmem_cu12-*.whl 2>/dev/null || true diff --git a/scripts/ci/cuda/ci_download_flashinfer_jit_cache.sh b/scripts/ci/cuda/ci_download_flashinfer_jit_cache.sh index bca94477c679..ab41dcef22c0 100755 --- a/scripts/ci/cuda/ci_download_flashinfer_jit_cache.sh +++ b/scripts/ci/cuda/ci_download_flashinfer_jit_cache.sh @@ -25,8 +25,6 @@ if [ "$FLASHINFER_JIT_CACHE_INSTALLED" = false ]; then FLASHINFER_CACHE_DIR="${HOME}/.cache/flashinfer-wheels" mkdir -p "${FLASHINFER_CACHE_DIR}" - find "${FLASHINFER_CACHE_DIR}" -name "flashinfer_jit_cache-*.whl" ! -name "flashinfer_jit_cache-${FLASHINFER_PYTHON_REQUIRED}*" -type f -delete 2>/dev/null || true - FLASHINFER_WHEEL_PATTERN="flashinfer_jit_cache-${FLASHINFER_PYTHON_REQUIRED}*.whl" CACHED_WHEEL=$(find "${FLASHINFER_CACHE_DIR}" -name "${FLASHINFER_WHEEL_PATTERN}" -type f 2>/dev/null | head -n 1) diff --git a/scripts/ci/cuda/ci_install_dependency.sh b/scripts/ci/cuda/ci_install_dependency.sh index d2b495fafafb..5bfbea04ffeb 100755 --- a/scripts/ci/cuda/ci_install_dependency.sh +++ b/scripts/ci/cuda/ci_install_dependency.sh @@ -24,6 +24,11 @@ set -euxo pipefail # ------------------------------------------------------------------------------ # Set up environment variables CU_VERSION="cu129" + +# Nvidia package versions we override (torch pins older versions). 
+# Used both as pip constraints during install and for post-install verification. +NVIDIA_CUDNN_VERSION="9.16.0.29" +NVIDIA_NVSHMEM_VERSION="3.4.5" OPTIONAL_DEPS="${1:-}" SECONDS=0 @@ -213,11 +218,12 @@ mark_step_done "Uninstall Flashinfer" # Install main package # ------------------------------------------------------------------------------ # Install the main package -EXTRAS="dev" +EXTRAS="dev,runai,tracing" if [ -n "$OPTIONAL_DEPS" ]; then - EXTRAS="dev,${OPTIONAL_DEPS}" + EXTRAS="dev,runai,tracing,${OPTIONAL_DEPS}" fi echo "Installing python extras: [${EXTRAS}]" +source "$(dirname "$0")/cache_nvidia_wheels.sh" $PIP_CMD install -e "python[${EXTRAS}]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX mark_step_done "Install main package" @@ -298,7 +304,7 @@ if [ "$CU_VERSION" = "cu130" ]; then else NVRTC_SPEC="nvidia-cuda-nvrtc-cu12" fi -$PIP_CMD install mooncake-transfer-engine==0.3.10 "${NVRTC_SPEC}" py-spy scipy huggingface_hub[hf_xet] pytest $PIP_INSTALL_SUFFIX +$PIP_CMD install mooncake-transfer-engine==0.3.10.post1 "${NVRTC_SPEC}" py-spy scipy huggingface_hub[hf_xet] pytest $PIP_INSTALL_SUFFIX # Install other test dependencies if [ "$IS_BLACKWELL" != "1" ]; then @@ -331,18 +337,18 @@ fi # Fix dependencies: DeepEP depends on nvshmem 3.4.5 — skip reinstall when already correct (avoids pip races / wasted work) INSTALLED_NVSHMEM=$(pip show nvidia-nvshmem-cu12 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "") -if [ "$INSTALLED_NVSHMEM" = "3.4.5" ]; then - echo "nvidia-nvshmem-cu12==3.4.5 already installed, skipping reinstall" +if [ "$INSTALLED_NVSHMEM" = "$NVIDIA_NVSHMEM_VERSION" ]; then + echo "nvidia-nvshmem-cu12==${NVIDIA_NVSHMEM_VERSION} already installed, skipping reinstall" else - $PIP_CMD install nvidia-nvshmem-cu12==3.4.5 $PIP_INSTALL_SUFFIX + $PIP_CMD install nvidia-nvshmem-cu12==${NVIDIA_NVSHMEM_VERSION} $PIP_INSTALL_SUFFIX fi # Fix dependencies: Cudnn with version less than 9.16.0.29 will cause 
performance regression on Conv3D kernel INSTALLED_CUDNN=$(pip show nvidia-cudnn-cu12 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "") -if [ "$INSTALLED_CUDNN" = "9.16.0.29" ]; then - echo "nvidia-cudnn-cu12==9.16.0.29 already installed, skipping reinstall" +if [ "$INSTALLED_CUDNN" = "$NVIDIA_CUDNN_VERSION" ]; then + echo "nvidia-cudnn-cu12==${NVIDIA_CUDNN_VERSION} already installed, skipping reinstall" else - $PIP_CMD install nvidia-cudnn-cu12==9.16.0.29 $PIP_INSTALL_SUFFIX + $PIP_CMD install nvidia-cudnn-cu12==${NVIDIA_CUDNN_VERSION} $PIP_INSTALL_SUFFIX fi mark_step_done "Fix other dependencies" @@ -352,6 +358,13 @@ mark_step_done "Fix other dependencies" # can delete the .pth file without reliably recreating it (pip race condition). $PIP_CMD install "nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --no-deps --force-reinstall $PIP_INSTALL_SUFFIX || true + +# Install human-eval +pip install "setuptools==70.0.0" +git clone https://github.com/merrymercy/human-eval.git +cd human-eval +pip install -e . 
--no-build-isolation + # ------------------------------------------------------------------------------ # Prepare runner # ------------------------------------------------------------------------------ diff --git a/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py b/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py index 183611ef044d..bb223fbe6dde 100644 --- a/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py +++ b/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py @@ -26,7 +26,7 @@ CI_DATA_REPO_NAME = "sglang-ci-data" CI_DATA_BRANCH = "main" HISTORY_PREFIX = "diffusion-comparisons" -MAX_HISTORY_RUNS = 7 +MAX_HISTORY_RUNS = 14 # Base URL for chart images pushed to sglang-ci-data CHARTS_RAW_BASE_URL = ( @@ -344,7 +344,7 @@ def generate_dashboard( # ---- Section 2: SGLang Performance Trend ---- if history: - lines.append("\n## SGLang Performance Trend (Last 7 Runs)\n") + lines.append(f"\n## SGLang Performance Trend (Last {len(history) + 1} Runs)\n") # Build header header = "| Date | Commit |" @@ -491,9 +491,16 @@ def _chart_label(run: dict) -> str: ax.set_xticklabels(labels, fontsize=7) ax.set_ylabel("Latency (s)") ax.set_title(f"Latency Trend -- {cid}", fontsize=11, fontweight="bold") - ax.legend(loc="upper right", fontsize=8) + ax.legend(loc="lower right", fontsize=8, framealpha=0.8) ax.grid(True, alpha=0.3) - ax.set_ylim(bottom=0) + all_vals = sg_vals + [v for v in vl_vals if v is not None] + y_min = min(all_vals) + y_max = max(all_vals) + y_range = y_max - y_min if y_max > y_min else max(y_max * 0.1, 0.1) + ax.set_ylim( + bottom=max(0, y_min - y_range * 0.3), + top=y_max + y_range * 0.3, + ) filename = f"latency_{_sanitize_filename(cid)}.png" chart_path = os.path.join(charts_dir, filename) diff --git a/scripts/ci/utils/query_job_status.py b/scripts/ci/utils/query_job_status.py index 9a5e8a128903..39028af2d041 100755 --- a/scripts/ci/utils/query_job_status.py +++ b/scripts/ci/utils/query_job_status.py @@ -1,11 +1,17 @@ 
#!/usr/bin/env python3 """ -Query GitHub Actions job status for specific jobs. +Query GitHub Actions job status for specific jobs or generate runner fleet reports. Usage: + # Per-job reports (original mode) python scripts/ci/utils/query_job_status.py --job "stage-c-test-large-8-gpu-amd-mi35x" python scripts/ci/utils/query_job_status.py --job "stage-c-test-large-8-gpu-amd-mi35x" --hours 48 - python scripts/ci/utils/query_job_status.py --job "AMD" --workflow pr-test-amd.yml + python scripts/ci/utils/query_job_status.py --job "stage-c-test-large-8-gpu-amd-mi35x" --workflow "pr-test-amd.yml" --input-data-file actions-job-snapshot.json --summary + + # Runner fleet report (cross-workflow runner analytics) + python scripts/ci/utils/query_job_status.py --runner-report --workflow "pr-test-amd.yml,nightly-test-amd.yml" --hours 24 + python scripts/ci/utils/query_job_status.py --runner-report --workflow "pr-test-amd.yml,nightly-test-amd.yml,pr-test-amd-rocm720.yml,nightly-test-amd-rocm720.yml" --summary + python scripts/ci/utils/query_job_status.py --workflow "pr-test-amd.yml,nightly-test-amd.yml,pr-test-amd-rocm720.yml,nightly-test-amd-rocm720.yml" --dump-data-file actions-job-snapshot.json Requirements: pip install tabulate @@ -17,7 +23,7 @@ import subprocess import sys from datetime import datetime, timedelta, timezone -from typing import Optional +from typing import Any, Optional try: from tabulate import tabulate @@ -76,6 +82,77 @@ def run_gh_command(args: list[str]) -> dict: return json.loads(result.stdout) +def is_rate_limit_error(error: str) -> bool: + """Check whether an API error was caused by GitHub rate limiting.""" + return "rate limit exceeded" in error.lower() + + +def _new_workflow_fetch_stats(workflow: str) -> dict[str, Any]: + """Create an empty metadata bucket for a workflow snapshot.""" + return { + "workflow": workflow, + "total_runs_seen": 0, + "runs_with_jobs": 0, + "skipped_runs": 0, + "skipped_runs_rate_limit": 0, + "jobs_collected": 0, + } + + +def 
_new_fetch_metadata(repo: str, workflows: list[str], hours: int) -> dict[str, Any]: + """Create the fetch metadata container stored alongside snapshot jobs.""" + return { + "repo": repo, + "hours": hours, + "requested_workflows": workflows, + "total_runs_seen": 0, + "runs_with_jobs": 0, + "jobs_collected": 0, + "skipped_runs": [], + "workflow_fetch_failures": [], + "workflow_stats": { + workflow: _new_workflow_fetch_stats(workflow) for workflow in workflows + }, + } + + +def _record_workflow_fetch_failure( + fetch_metadata: dict[str, Any], workflow: str, error: str +) -> None: + """Record a workflow-level failure while listing workflow runs.""" + fetch_metadata["workflow_fetch_failures"].append( + { + "workflow": workflow, + "error": error.strip(), + "reason": "rate_limit" if is_rate_limit_error(error) else "api_error", + } + ) + + +def _record_skipped_run( + fetch_metadata: dict[str, Any], workflow: str, run: dict, error: str +) -> None: + """Record a run whose jobs could not be fetched.""" + workflow_stats = fetch_metadata["workflow_stats"].setdefault( + workflow, _new_workflow_fetch_stats(workflow) + ) + workflow_stats["skipped_runs"] += 1 + if is_rate_limit_error(error): + workflow_stats["skipped_runs_rate_limit"] += 1 + + fetch_metadata["skipped_runs"].append( + { + "workflow": workflow, + "run_id": run["id"], + "created_at": run.get("created_at", ""), + "status": run.get("status", "unknown"), + "conclusion": run.get("conclusion") or "-", + "reason": "rate_limit" if is_rate_limit_error(error) else "api_error", + "error": error.strip(), + } + ) + + def parse_time(time_str: str) -> Optional[datetime]: """Parse ISO timestamp to datetime.""" if not time_str: @@ -150,66 +227,131 @@ def get_pr_number_from_run(run: dict) -> Optional[int]: return None -def query_jobs( - repo: str, +def _job_name_matches_filter(job_name: str, job_filter: str) -> bool: + """Check whether a job name matches the report filter prefix.""" + job_name_lower = job_name.lower() + filter_lower = 
job_filter.lower() + if not job_name_lower.startswith(filter_lower): + return False + if len(job_name_lower) > len(filter_lower): + next_char = job_name_lower[len(filter_lower)] + if next_char not in (" ", "("): + return False + return True + + +def filter_jobs( + jobs: list[dict], job_filter: str, workflow: str = None, - hours: int = 24, status_filter: str = None, ) -> list[dict]: - """Query jobs matching the filter.""" + """Filter a prefetched job list for a specific report target.""" + results = [] + for job in jobs: + if workflow and job.get("workflow") != workflow: + continue + if not _job_name_matches_filter(job.get("job_name", ""), job_filter): + continue + if status_filter and job.get("status") != status_filter: + continue + results.append(job) + return results + + +def save_snapshot(path: str, snapshot: dict[str, Any]) -> None: + """Persist a prefetched Actions snapshot to disk.""" + with open(path, "w") as f: + json.dump(snapshot, f, indent=2) + + +def load_snapshot(path: str) -> dict[str, Any]: + """Load a previously saved Actions snapshot from disk.""" + with open(path) as f: + snapshot = json.load(f) + if "jobs" not in snapshot: + raise ValueError(f"Snapshot file {path} is missing the 'jobs' field") + return snapshot - print(f"Fetching workflow runs from last {hours} hours...", file=sys.stderr) - runs = get_workflow_runs(repo, workflow, hours) - print(f"Found {len(runs)} workflow runs", file=sys.stderr) + +def fetch_all_jobs_snapshot( + repo: str, + workflows: list[str], + hours: int = 24, +) -> dict[str, Any]: + """Fetch jobs once and store enough metadata to detect incomplete data.""" + fetch_metadata = _new_fetch_metadata(repo, workflows, hours) + all_runs = [] + + for workflow in workflows: + print(f"Fetching runs for {workflow}...", file=sys.stderr) + try: + runs = get_workflow_runs(repo, workflow, hours) + except Exception as e: + error = str(e) + print( + f"Warning: Failed to list runs for workflow {workflow}: {error}", + file=sys.stderr, + ) + 
_record_workflow_fetch_failure(fetch_metadata, workflow, error) + continue + + print(f" Found {len(runs)} runs for {workflow}", file=sys.stderr) + fetch_metadata["workflow_stats"][workflow]["total_runs_seen"] = len(runs) + for run in runs: + run["_workflow"] = workflow + all_runs.extend(runs) + + seen_run_ids = set() + unique_runs = [] + for run in all_runs: + if run["id"] not in seen_run_ids: + seen_run_ids.add(run["id"]) + unique_runs.append(run) + + fetch_metadata["total_runs_seen"] = len(unique_runs) + print(f"Total unique workflow runs: {len(unique_runs)}", file=sys.stderr) results = [] - total_runs = len(runs) + total_runs = len(unique_runs) - for i, run in enumerate(runs): + for i, run in enumerate(unique_runs): if (i + 1) % 20 == 0: print(f"Processing run {i+1}/{total_runs}...", file=sys.stderr) + workflow_name = run.get("_workflow", "-") try: jobs = get_jobs_for_run(repo, run["id"]) except Exception as e: + error = str(e) print( - f"Warning: Failed to get jobs for run {run['id']}: {e}", file=sys.stderr + f"Warning: Failed to get jobs for run {run['id']}: {error}", + file=sys.stderr, ) + _record_skipped_run(fetch_metadata, workflow_name, run, error) continue + workflow_stats = fetch_metadata["workflow_stats"].setdefault( + workflow_name, _new_workflow_fetch_stats(workflow_name) + ) + workflow_stats["runs_with_jobs"] += 1 + fetch_metadata["runs_with_jobs"] += 1 + pr_number = get_pr_number_from_run(run) branch = run.get("head_branch", "") run_status = run.get("status", "unknown") run_conclusion = run.get("conclusion") or "-" + jobs_added = 0 for job in jobs: job_name = job.get("name", "") - - # Filter by job name - # Use prefix matching to avoid e.g. 
"stage-c-test-large-8-gpu-amd" - # also matching "stage-c-test-large-8-gpu-amd-mi35x" - job_name_lower = job_name.lower() - filter_lower = job_filter.lower() - if not job_name_lower.startswith(filter_lower): - continue - # If there are characters after the filter, ensure it's not a - # continuation of the base job name (e.g., "-mi35x") - if len(job_name_lower) > len(filter_lower): - next_char = job_name_lower[len(filter_lower)] - if next_char not in (" ", "("): - continue - - # Filter by status if specified - if status_filter and job.get("status") != status_filter: - continue - job_status = job.get("status", "unknown") runner_name = job.get("runner_name") or "-" + labels = job.get("labels", []) + + if len(labels) == 1 and labels[0] == "ubuntu-latest": + continue - # Detect stuck/ghost jobs: - # - Job is in_progress but no runner assigned - # - Job is in_progress but workflow run is cancelled/completed is_stuck = False if job_status == "in_progress": if runner_name == "-": @@ -229,6 +371,8 @@ def query_jobs( "started_at": job.get("started_at", ""), "completed_at": job.get("completed_at", ""), "runner_name": runner_name, + "labels": labels, + "runner_group_name": job.get("runner_group_name") or "-", "run_id": run["id"], "run_status": run_status, "run_conclusion": run_conclusion, @@ -236,10 +380,49 @@ def query_jobs( "branch": branch, "html_url": job.get("html_url", ""), "is_stuck": is_stuck, + "workflow": workflow_name, } ) + jobs_added += 1 - return results + workflow_stats["jobs_collected"] += jobs_added + + fetch_metadata["jobs_collected"] = len(results) + return { + "snapshot_version": 1, + "repo": repo, + "hours": hours, + "workflows": workflows, + "generated_at": datetime.now(timezone.utc).isoformat(), + "jobs": results, + "fetch_metadata": fetch_metadata, + } + + +def query_jobs( + repo: str, + job_filter: str, + workflow: str = None, + hours: int = 24, + status_filter: str = None, +) -> list[dict]: + """Query jobs matching the filter.""" + snapshot = 
fetch_all_jobs_snapshot(repo, [workflow], hours) + return filter_jobs(snapshot["jobs"], job_filter, workflow, status_filter) + + +def query_all_jobs( + repo: str, + workflows: list[str], + hours: int = 24, +) -> list[dict]: + """Query all jobs across multiple workflows for fleet-level analysis. + + Unlike query_jobs(), this does NOT filter by job name and collects + everything in a single pass -- ideal for runner-centric analytics. + Jobs on ubuntu-latest are excluded since those are utility jobs. + """ + return fetch_all_jobs_snapshot(repo, workflows, hours)["jobs"] def calculate_duration(started_at: str, completed_at: str) -> str: @@ -320,6 +503,253 @@ def calculate_queue_time( return f"{minutes}m{seconds}s" +# --------------------------------------------------------------------------- +# Runner fleet analytics functions +# --------------------------------------------------------------------------- + + +def _format_duration_seconds(seconds: Optional[float]) -> str: + """Format seconds into human-readable duration string.""" + if seconds is None or seconds < 0: + return "-" + total_seconds = int(seconds) + minutes = total_seconds // 60 + secs = total_seconds % 60 + if minutes >= 60: + hours = minutes // 60 + minutes = minutes % 60 + return f"{hours}h{minutes}m" + return f"{minutes}m{secs}s" + + +def _get_runner_label(job: dict) -> str: + """Extract the primary runner label from a job's labels list.""" + labels = job.get("labels", []) + if not labels: + return "unknown" + for label in labels: + if label.startswith("linux-mi"): + return label + return labels[0] + + +def _percentile(data: list[float], p: int) -> Optional[float]: + """Return a percentile from an already sorted or unsorted numeric list.""" + if not data: + return None + sorted_data = sorted(data) + idx = min(int(len(sorted_data) * p / 100), len(sorted_data) - 1) + return sorted_data[idx] + + +def _average(data: list[float]) -> Optional[float]: + """Return the average of a numeric list when samples 
exist.""" + if not data: + return None + return sum(data) / len(data) + + +def _queue_time_seconds(job: dict) -> Optional[float]: + """Extract queue time in seconds for a job if both timestamps exist.""" + created = parse_time(job.get("created_at", "")) + started = parse_time(job.get("started_at", "")) + if not (created and started): + return None + + queue_seconds = (started - created).total_seconds() + if queue_seconds < 0: + return None + return queue_seconds + + +def _build_queue_distribution(queue_times: list[float]) -> dict[str, Any]: + """Build queue time buckets and percentile stats for one sample set.""" + if not queue_times: + return {"buckets": [], "p50": None, "p90": None, "p99": None, "total": 0} + + sorted_queue_times = sorted(queue_times) + bucket_defs = [ + ("< 1 min", 0, 60), + ("1-5 min", 60, 300), + ("5-15 min", 300, 900), + ("15-30 min", 900, 1800), + ("30-60 min", 1800, 3600), + ("> 60 min", 3600, float("inf")), + ] + + total = len(sorted_queue_times) + buckets = [] + for label, lo, hi in bucket_defs: + count = sum(1 for qt in sorted_queue_times if lo <= qt < hi) + pct = count / total * 100 if total > 0 else 0 + buckets.append({"range": label, "count": count, "percentage": round(pct, 1)}) + + return { + "buckets": buckets, + "p50": _percentile(sorted_queue_times, 50), + "p90": _percentile(sorted_queue_times, 90), + "p99": _percentile(sorted_queue_times, 99), + "total": total, + } + + +def analyze_concurrency(jobs: list[dict], report_time: datetime = None) -> dict: + """Analyze concurrent runner usage per runner label. + + Uses an event-sweep algorithm: for each job that ran, create +1 event + at started_at and -1 event at completed_at, then sweep through sorted + events tracking the concurrent count. 
+ """ + if report_time is None: + report_time = datetime.now(timezone.utc) + + label_jobs: dict[str, list[dict]] = {} + for job in jobs: + label = _get_runner_label(job) + label_jobs.setdefault(label, []).append(job) + + results = {} + for label in sorted(label_jobs): + pool_jobs = label_jobs[label] + events: list[tuple[datetime, int]] = [] + queue_times: list[float] = [] + durations: list[float] = [] + + for job in pool_jobs: + started = parse_time(job.get("started_at", "")) + completed = parse_time(job.get("completed_at", "")) + + if started and completed: + events.append((started, +1)) + events.append((completed, -1)) + durations.append((completed - started).total_seconds()) + elif started: + events.append((started, +1)) + events.append((report_time, -1)) + durations.append((report_time - started).total_seconds()) + + qt = _queue_time_seconds(job) + if qt is not None: + queue_times.append(qt) + + if not events: + results[label] = { + "peak": 0, + "avg_concurrent": 0.0, + "total_jobs": len(pool_jobs), + "avg_queue_seconds": _average(queue_times), + "p50_queue_seconds": _percentile(queue_times, 50), + "p99_queue_seconds": _percentile(queue_times, 99), + "avg_duration_seconds": _average(durations), + } + continue + + events.sort(key=lambda x: (x[0], x[1])) + concurrent = 0 + peak = 0 + time_weighted_sum = 0.0 + total_time = 0.0 + prev_time = events[0][0] + + for ts, delta in events: + if prev_time and concurrent > 0: + dt = (ts - prev_time).total_seconds() + time_weighted_sum += concurrent * dt + total_time += dt + concurrent += delta + peak = max(peak, concurrent) + prev_time = ts + + avg_concurrent = time_weighted_sum / total_time if total_time > 0 else 0 + avg_queue = _average(queue_times) + avg_duration = _average(durations) + + results[label] = { + "peak": peak, + "avg_concurrent": round(avg_concurrent, 1), + "total_jobs": len(pool_jobs), + "avg_queue_seconds": avg_queue, + "p50_queue_seconds": _percentile(queue_times, 50), + "p99_queue_seconds": 
_percentile(queue_times, 99), + "avg_duration_seconds": avg_duration, + } + + return results + + +def analyze_busy_periods(jobs: list[dict]) -> list[dict]: + """Analyze job activity by hour of day (UTC). + + Buckets jobs by the UTC hour they started and computes avg queue time. + Classifies each hour as Quiet / Moderate / Busy / Peak relative to the + busiest hour. + """ + hourly: dict[int, dict] = { + h: {"jobs_started": 0, "queue_times": []} for h in range(24) + } + + for job in jobs: + started = parse_time(job.get("started_at", "")) + created = parse_time(job.get("created_at", "")) + + if started: + hour = started.astimezone(timezone.utc).hour + hourly[hour]["jobs_started"] += 1 + + if created: + qt = (started - created).total_seconds() + if qt >= 0: + hourly[hour]["queue_times"].append(qt) + + max_jobs = max((v["jobs_started"] for v in hourly.values()), default=1) or 1 + + results = [] + for hour in range(24): + data = hourly[hour] + avg_queue = ( + sum(data["queue_times"]) / len(data["queue_times"]) + if data["queue_times"] + else 0 + ) + ratio = data["jobs_started"] / max_jobs + if ratio >= 0.75: + load = "Peak" + elif ratio >= 0.5: + load = "Busy" + elif ratio >= 0.25: + load = "Moderate" + else: + load = "Quiet" + + results.append( + { + "hour": hour, + "hour_label": f"{hour:02d}:00-{(hour + 1) % 24:02d}:00", + "jobs_started": data["jobs_started"], + "avg_queue_seconds": avg_queue, + "load": load, + } + ) + + return results + + +def analyze_queue_distribution(jobs: list[dict]) -> dict: + """Analyze queue time distribution per runner label.""" + queue_times_by_label: dict[str, list[float]] = {} + for job in jobs: + queue_seconds = _queue_time_seconds(job) + if queue_seconds is None: + continue + label = _get_runner_label(job) + queue_times_by_label.setdefault(label, []).append(queue_seconds) + + return { + label: _build_queue_distribution(queue_times) + for label, queue_times in sorted(queue_times_by_label.items()) + } + + def process_results( results: 
list[dict], repo: str, report_time: datetime = None ) -> dict: @@ -433,6 +863,104 @@ def process_results( } +def summarize_fetch_metadata( + fetch_metadata: Optional[dict[str, Any]], workflows: list[str] = None +) -> Optional[dict[str, Any]]: + """Summarize snapshot completeness for the workflows relevant to a report.""" + if not fetch_metadata: + return None + + workflow_filter = ( + set(workflows) + if workflows + else set(fetch_metadata.get("requested_workflows", [])) + ) + workflow_stats = fetch_metadata.get("workflow_stats", {}) + if not workflow_filter: + workflow_filter = set(workflow_stats) + + relevant_stats = [ + workflow_stats[workflow] + for workflow in workflow_filter + if workflow in workflow_stats + ] + relevant_skipped_runs = [ + run + for run in fetch_metadata.get("skipped_runs", []) + if run.get("workflow") in workflow_filter + ] + relevant_workflow_failures = [ + failure + for failure in fetch_metadata.get("workflow_fetch_failures", []) + if failure.get("workflow") in workflow_filter + ] + + skipped_run_rate_limit = sum( + 1 for run in relevant_skipped_runs if run.get("reason") == "rate_limit" + ) + workflow_failure_rate_limit = sum( + 1 + for failure in relevant_workflow_failures + if failure.get("reason") == "rate_limit" + ) + + return { + "known_runs": sum(stat.get("total_runs_seen", 0) for stat in relevant_stats), + "runs_with_jobs": sum(stat.get("runs_with_jobs", 0) for stat in relevant_stats), + "jobs_collected": sum(stat.get("jobs_collected", 0) for stat in relevant_stats), + "skipped_runs": relevant_skipped_runs, + "workflow_failures": relevant_workflow_failures, + "skipped_run_rate_limit": skipped_run_rate_limit, + "workflow_failure_rate_limit": workflow_failure_rate_limit, + "incomplete": bool(relevant_skipped_runs or relevant_workflow_failures), + } + + +def append_fetch_metadata_notice( + lines: list[str], + fetch_metadata: Optional[dict[str, Any]], + workflows: list[str] = None, +) -> None: + """Append a markdown notice when the 
report is based on incomplete data.""" + summary = summarize_fetch_metadata(fetch_metadata, workflows) + if not summary or not summary["incomplete"]: + return + + skipped_runs = summary["skipped_runs"] + workflow_failures = summary["workflow_failures"] + other_skipped = len(skipped_runs) - summary["skipped_run_rate_limit"] + other_workflow_failures = ( + len(workflow_failures) - summary["workflow_failure_rate_limit"] + ) + + lines.append( + "> **Data completeness:** Incomplete. GitHub API rate limit and/or fetch errors prevented a full dataset." + ) + if summary["known_runs"] > 0: + lines.append( + f"> Successfully fetched jobs for **{summary['runs_with_jobs']}/{summary['known_runs']}** known runs in scope. Missing runs: **{len(skipped_runs)}** (rate limit: {summary['skipped_run_rate_limit']}, other API errors: {other_skipped})." + ) + + if workflow_failures: + workflow_names = ", ".join( + f"`{failure['workflow']}`" for failure in workflow_failures + ) + lines.append( + f"> Could not list workflow runs for {workflow_names}. Missing run count is unknown for those workflows (rate limit: {summary['workflow_failure_rate_limit']}, other API errors: {other_workflow_failures})." + ) + + if skipped_runs: + skipped_ids = ", ".join(f"`{run['run_id']}`" for run in skipped_runs[:10]) + remaining = len(skipped_runs) - 10 + suffix = f", and {remaining} more" if remaining > 0 else "" + lines.append(f"> Missing run IDs: {skipped_ids}{suffix}.") + + lines.append( + "> Missing job counts inside skipped runs are unknown because GitHub did not return those run job lists." 
+ ) + lines.append("") + + def print_table( results: list[dict], repo: str, generated_time: str, report_time: datetime = None ): @@ -621,6 +1149,8 @@ def format_markdown( hours: int, generated_time: str, report_time: datetime = None, + fetch_metadata: dict[str, Any] = None, + workflow: str = None, ) -> str: """Format results as markdown for GitHub Actions summary.""" lines = [] @@ -634,6 +1164,9 @@ def format_markdown( lines.append("") lines.append("> **Note:** All times are displayed in UTC") lines.append("") + append_fetch_metadata_notice( + lines, fetch_metadata, [workflow] if workflow else None + ) if not results: lines.append("> No jobs found matching the filter.") @@ -795,11 +1328,192 @@ def format_markdown( return "\n".join(lines) -def main(): - # Check gh CLI availability before proceeding - if not check_gh_cli_available(): - sys.exit(1) +def format_runner_report_markdown( + jobs: list[dict], + workflows: list[str], + hours: int, + generated_time: str, + report_time: datetime = None, + fetch_metadata: dict[str, Any] = None, +) -> str: + """Format runner fleet analytics as markdown for GitHub Actions summary.""" + if report_time is None: + report_time = datetime.now(timezone.utc) + + lines: list[str] = [] + # Header + lines.append("# CI Runner Fleet Report") + lines.append("") + lines.append(f"**Workflows:** {', '.join(f'`{w}`' for w in workflows)}") + lines.append(f"**Time window:** Last {hours} hours") + lines.append(f"**Generated:** {generated_time} UTC") + lines.append(f"**Total jobs analyzed:** {len(jobs)}") + lines.append("") + lines.append("> All times are in UTC. 
Jobs on `ubuntu-latest` are excluded.") + lines.append("") + append_fetch_metadata_notice(lines, fetch_metadata, workflows) + + if not jobs: + lines.append("> No self-hosted runner jobs found in the time window.") + return "\n".join(lines) + + # --- Fleet Overview --- + unique_labels = {_get_runner_label(j) for j in jobs} + completed_jobs = [j for j in jobs if j.get("status") == "completed"] + lines.append("## Fleet Overview") + lines.append("") + lines.append("| Metric | Value |") + lines.append("|--------|-------|") + lines.append(f"| Total runner labels seen | {len(unique_labels)} |") + lines.append(f"| Total jobs analyzed | {len(jobs)} |") + lines.append(f"| Completed jobs | {len(completed_jobs)} |") + lines.append(f"| Time window | {hours}h |") + lines.append("") + + # --- Concurrency by Runner Label --- + concurrency = analyze_concurrency(jobs, report_time) + if concurrency: + lines.append("## Concurrency by Runner Label") + lines.append("") + lines.append( + "| Runner Label | Peak Concurrent | Avg Concurrent | Total Jobs | Avg Queue | P50 Queue | P99 Queue | Avg Duration |" + ) + lines.append( + "|-------------|----------------|---------------|-----------|-----------|-----------|-----------|-------------|" + ) + for label in sorted(concurrency, key=lambda k: -concurrency[k]["peak"]): + c = concurrency[label] + lines.append( + f"| `{label}` | **{c['peak']}** | {c['avg_concurrent']} " + f"| {c['total_jobs']} " + f"| {_format_duration_seconds(c['avg_queue_seconds'])} " + f"| {_format_duration_seconds(c['p50_queue_seconds'])} " + f"| {_format_duration_seconds(c['p99_queue_seconds'])} " + f"| {_format_duration_seconds(c['avg_duration_seconds'])} |" + ) + lines.append("") + + # --- Busy Periods --- + busy_periods = analyze_busy_periods(jobs) + if busy_periods: + lines.append("## Busy Periods (UTC)") + lines.append("") + lines.append("| Hour (UTC) | Jobs Started | Avg Queue Time | Load |") + lines.append("|-----------|-------------|---------------|------|") + for 
bp in busy_periods: + if bp["jobs_started"] == 0: + continue + load_display = ( + f"**{bp['load']}**" if bp["load"] in ("Peak", "Busy") else bp["load"] + ) + lines.append( + f"| {bp['hour_label']} | {bp['jobs_started']} " + f"| {_format_duration_seconds(bp['avg_queue_seconds'])} " + f"| {load_display} |" + ) + lines.append("") + + peak_hours = [bp for bp in busy_periods if bp["load"] == "Peak"] + quiet_hours = [ + bp + for bp in busy_periods + if bp["load"] == "Quiet" and bp["jobs_started"] > 0 + ] + if peak_hours: + labels = ", ".join(bp["hour_label"] for bp in peak_hours) + lines.append(f"> **Peak hours:** {labels}") + lines.append("") + if quiet_hours: + labels = ", ".join(bp["hour_label"] for bp in quiet_hours) + lines.append(f"> **Quiet hours:** {labels}") + lines.append("") + + # --- Queue Time Distribution --- + queue_dist = analyze_queue_distribution(jobs) + if queue_dist: + lines.append("## Queue Time Distribution by Runner Label") + lines.append("") + for label in sorted(queue_dist, key=lambda k: -queue_dist[k]["total"]): + dist = queue_dist[label] + lines.append(f"### `{label}`") + lines.append("") + lines.append( + f"> **Samples:** {dist['total']} | **P50:** {_format_duration_seconds(dist['p50'])} | **P90:** {_format_duration_seconds(dist['p90'])} | **P99:** {_format_duration_seconds(dist['p99'])}" + ) + lines.append("") + lines.append("| Queue Time Range | Count | Percentage |") + lines.append("|-----------------|-------|------------|") + for b in dist["buckets"]: + bar = "#" * int(b["percentage"] / 3) + lines.append( + f"| {b['range']} | {b['count']} | {b['percentage']}% {bar} |" + ) + lines.append("") + + # --- Failed Jobs Detail (collapsible) --- + failed_jobs = [ + j + for j in jobs + if j.get("conclusion") == "failure" and not j.get("is_stuck", False) + ] + if failed_jobs: + lines.append("
") + lines.append( + f"Failed Jobs ({len(failed_jobs)} total) - Click to expand" + ) + lines.append("") + lines.append( + "| Job Name | Runner | Workflow | Queue | Duration | PR/Branch | Link |" + ) + lines.append( + "|----------|--------|---------|-------|----------|-----------|------|" + ) + for j in sorted(failed_jobs, key=lambda x: x["created_at"], reverse=True): + queue = calculate_queue_time( + j["created_at"], j["started_at"], j["status"], report_time + ) + dur = calculate_duration(j["started_at"], j["completed_at"]) + pr_info = ( + f"PR#{j['pr_number']}" if j.get("pr_number") else j.get("branch", "-") + ) + url = j.get("html_url", "") + wf = j.get("workflow", "-") + lines.append( + f"| `{j['job_name']}` | `{j['runner_name']}` | `{wf}` " + f"| {queue} | {dur} | {pr_info} | [View]({url}) |" + ) + lines.append("") + lines.append("
") + lines.append("") + + # --- Stuck Jobs --- + stuck_jobs = [j for j in jobs if j.get("is_stuck", False)] + if stuck_jobs: + lines.append("## Stuck/Ghost Jobs") + lines.append("") + lines.append( + "> Jobs showing `in_progress` but have no runner assigned or workflow run is cancelled" + ) + lines.append("") + lines.append( + "| Job Name | Job Status | Run Status | Runner | Workflow | Link |" + ) + lines.append("|----------|-----------|-----------|--------|---------|------|") + for j in sorted(stuck_jobs, key=lambda x: x["created_at"], reverse=True): + run_info = f"{j.get('run_status', '-')}/{j.get('run_conclusion', '-')}" + url = j.get("html_url", "") + wf = j.get("workflow", "-") + lines.append( + f"| `{j['job_name']}` | {j['status']} | {run_info} " + f"| `{j['runner_name']}` | `{wf}` | [View]({url}) |" + ) + lines.append("") + + return "\n".join(lines) + + +def main(): # Capture the time when the command is run (both datetime and formatted string) report_time = datetime.now(timezone.utc) report_generated_time = report_time.strftime("%Y-%m-%d %H:%M:%S") @@ -812,13 +1526,14 @@ def main(): ) parser.add_argument( "--job", - required=True, - help="Job name filter (e.g., 'stage-c-test-large-8-gpu-amd-mi35x')", + required=False, + default=None, + help="Job name filter (required unless --runner-report is used)", ) parser.add_argument( "--workflow", default="pr-test-amd.yml", - help="Workflow file name (default: pr-test-amd.yml)", + help="Workflow file name, or comma-separated list for --runner-report (default: pr-test-amd.yml)", ) parser.add_argument( "--hours", @@ -847,20 +1562,117 @@ def main(): type=str, help="Write output to file", ) + parser.add_argument( + "--runner-report", + action="store_true", + help="Generate runner fleet analytics report across all jobs (no --job filter needed)", + ) + parser.add_argument( + "--input-data-file", + type=str, + help="Load a prefetched Actions snapshot JSON instead of calling gh api", + ) + parser.add_argument( + 
"--dump-data-file", + type=str, + help="Fetch Actions data once and save it as a snapshot JSON file", + ) args = parser.parse_args() - results = query_jobs( - args.repo, - args.job, - args.workflow, - args.hours, - args.status, - ) + if args.input_data_file and args.dump_data_file: + parser.error("--input-data-file and --dump-data-file cannot be used together") + + if not args.runner_report and not args.job and not args.dump_data_file: + parser.error( + "--job is required unless --runner-report or --dump-data-file is specified" + ) + + workflows = [w.strip() for w in args.workflow.split(",") if w.strip()] + + if not args.input_data_file and not check_gh_cli_available(): + sys.exit(1) + + snapshot = None + repo = args.repo + fetch_metadata = None + + if args.input_data_file: + snapshot = load_snapshot(args.input_data_file) + repo = snapshot.get("repo", args.repo) + fetch_metadata = snapshot.get("fetch_metadata") + + if args.dump_data_file: + snapshot = fetch_all_jobs_snapshot(repo, workflows, args.hours) + save_snapshot(args.dump_data_file, snapshot) + summary = summarize_fetch_metadata(snapshot.get("fetch_metadata"), workflows) + print(f"Snapshot written to {args.dump_data_file}", file=sys.stderr) + if summary and summary["incomplete"]: + print( + "Warning: Snapshot is incomplete due to rate limit/API fetch failures.", + file=sys.stderr, + ) + if summary["known_runs"] > 0: + print( + f"Known runs fetched successfully: {summary['runs_with_jobs']}/{summary['known_runs']}", + file=sys.stderr, + ) + print( + f"Skipped runs with unknown job counts: {len(summary['skipped_runs'])}", + file=sys.stderr, + ) + return + + # --- Runner fleet report mode --- + if args.runner_report: + if snapshot is None: + snapshot = fetch_all_jobs_snapshot(repo, workflows, args.hours) + fetch_metadata = snapshot.get("fetch_metadata") + + jobs = [ + job for job in snapshot["jobs"] if job.get("workflow") in set(workflows) + ] + + md_content = format_runner_report_markdown( + jobs, + workflows, 
+ args.hours, + report_generated_time, + report_time, + fetch_metadata, + ) + + print(md_content) + + if args.output_file: + with open(args.output_file, "w") as f: + f.write(md_content) + print(f"\nOutput written to {args.output_file}", file=sys.stderr) + + if args.summary: + summary_file = os.environ.get("GITHUB_STEP_SUMMARY") + if summary_file: + with open(summary_file, "a") as f: + f.write(md_content) + f.write("\n") + print("Summary written to GITHUB_STEP_SUMMARY", file=sys.stderr) + else: + print( + "Warning: GITHUB_STEP_SUMMARY not set, markdown printed above.", + file=sys.stderr, + ) + return + + # --- Original per-job report mode --- + if snapshot is None: + snapshot = fetch_all_jobs_snapshot(repo, [args.workflow], args.hours) + fetch_metadata = snapshot.get("fetch_metadata") + + results = filter_jobs(snapshot["jobs"], args.job, args.workflow, args.status) output_content = None if args.output == "table": - print_table(results, args.repo, report_generated_time, report_time) + print_table(results, repo, report_generated_time, report_time) elif args.output == "csv": lines = [ "job_name,status,is_stuck,conclusion,created_at,started_at,queue_time,duration,runner,run_status,run_conclusion,pr_number,branch,url" @@ -877,7 +1689,6 @@ def main(): output_content = "\n".join(lines) print(output_content) elif args.output == "json": - # Add calculated fields to JSON output for consistency json_results = [] for r in sorted(results, key=lambda x: x["created_at"], reverse=True): r_copy = r.copy() @@ -892,27 +1703,39 @@ def main(): print(output_content) elif args.output == "markdown": output_content = format_markdown( - results, args.repo, args.job, args.hours, report_generated_time, report_time + results, + repo, + args.job, + args.hours, + report_generated_time, + report_time, + fetch_metadata, + args.workflow, ) print(output_content) - # Write to file if specified if args.output_file and output_content: with open(args.output_file, "w") as f: f.write(output_content) 
print(f"\nOutput written to {args.output_file}", file=sys.stderr) - # Write to GITHUB_STEP_SUMMARY if requested if args.summary: md_content = format_markdown( - results, args.repo, args.job, args.hours, report_generated_time, report_time + results, + repo, + args.job, + args.hours, + report_generated_time, + report_time, + fetch_metadata, + args.workflow, ) summary_file = os.environ.get("GITHUB_STEP_SUMMARY") if summary_file: with open(summary_file, "a") as f: f.write(md_content) f.write("\n") - print(f"Summary written to GITHUB_STEP_SUMMARY", file=sys.stderr) + print("Summary written to GITHUB_STEP_SUMMARY", file=sys.stderr) else: print( "Warning: GITHUB_STEP_SUMMARY not set, printing markdown instead:", diff --git a/scripts/ci/utils/slash_command_handler.py b/scripts/ci/utils/slash_command_handler.py index 6875d88fc587..9e9d2bc3cc3a 100644 --- a/scripts/ci/utils/slash_command_handler.py +++ b/scripts/ci/utils/slash_command_handler.py @@ -45,7 +45,7 @@ def find_workflow_run_url( The workflow run URL if found, None otherwise. """ # Build expected display_title based on workflow's run-name. - # rerun-ut includes test_command: "[rerun-ut] []" + # rerun-test includes test_command: "[rerun-test] []" # Other workflows: "[stage-name] []" suffix = f" {test_command}" if test_command else "" if pr_head_sha: @@ -271,6 +271,9 @@ def handle_rerun_stage( "stage-c-test-deepep-8-gpu-h200", "multimodal-gen-test-1-gpu", "multimodal-gen-test-2-gpu", + "multimodal-gen-component-accuracy-1-gpu", + "multimodal-gen-component-accuracy-2-gpu", + "multimodal-gen-test-1-b200", ] # Valid AMD stage names that support target_stage @@ -464,7 +467,7 @@ def resolve_test_file(file_part): return None, ( f"Ambiguous filename `{file_part}` — matched {len(matches)} files:\n\n" f"{match_list}\n\n" - f"Please provide the full path, e.g. `/rerun-ut {matches[0]}`" + f"Please provide the full path, e.g. 
`/rerun-test {matches[0]}`" ) return matches[0][len("test/") :], None @@ -485,7 +488,9 @@ def detect_cuda_suite(file_path_from_test): content = f.read() match = re.search( - r'register_cuda_ci\([^)]*suite\s*=\s*["\']([^"\']+)["\']', content + r'^[^#\n]*register_cuda_ci\([^)]*suite\s*=\s*["\']([^"\']+)["\']', + content, + re.MULTILINE, ) if not match: return ( @@ -553,7 +558,7 @@ def _resolve_and_dispatch_ut(gh_repo, pr, test_spec, token): ) try: - workflow_name = "Rerun UT" + workflow_name = "Rerun Test" workflows = gh_repo.get_workflows() target_workflow = None for wf in workflows: @@ -610,13 +615,13 @@ def _resolve_and_dispatch_ut(gh_repo, pr, test_spec, token): "error": f"Dispatch failed: {dispatch_resp.status_code}", } - print(f"Successfully triggered rerun-ut: {test_command}") + print(f"Successfully triggered rerun-test: {test_command}") run_url = find_workflow_run_url( gh_repo, target_workflow.id, ref, - "rerun-ut", + "rerun-test", token, dispatch_time, pr_head_sha=pr_head_sha, @@ -632,16 +637,16 @@ def _resolve_and_dispatch_ut(gh_repo, pr, test_spec, token): } except Exception as e: - print(f"Error triggering rerun-ut for {test_spec}: {e}") + print(f"Error triggering rerun-test for {test_spec}: {e}") return {"spec": test_spec, "success": False, "error": str(e)} -def handle_rerun_ut(gh_repo, pr, comment, user_perms, test_specs, token): +def handle_rerun_test(gh_repo, pr, comment, user_perms, test_specs, token): """ - Handles the /rerun-ut command. Accepts a list of test specs and dispatches + Handles the /rerun-test command. Accepts a list of test specs and dispatches a workflow run for each, posting a single consolidated comment. """ - # SECURITY: For fork PRs, only allow /rerun-ut if the commenter has write+ permission. + # SECURITY: For fork PRs, only allow /rerun-test if the commenter has write+ permission. 
# This command checks out and executes code from the PR branch on self-hosted GPU # runners, so we must ensure the commenter is a trusted collaborator. is_fork = pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login @@ -649,10 +654,10 @@ def handle_rerun_ut(gh_repo, pr, comment, user_perms, test_specs, token): commenter = comment.user.login perm = gh_repo.get_collaborator_permission(commenter) if perm not in ("admin", "write"): - print(f"Permission denied: /rerun-ut on fork PR by {commenter}.") + print(f"Permission denied: /rerun-test on fork PR by {commenter}.") comment.create_reaction("confused") pr.create_issue_comment( - "āŒ `/rerun-ut` is not available for fork PRs unless the commenter " + "āŒ `/rerun-test` is not available for fork PRs unless the commenter " "has write permission on the repo.\n\n" "Please ask a maintainer to run this command, or use the normal CI flow." ) @@ -660,21 +665,21 @@ def handle_rerun_ut(gh_repo, pr, comment, user_perms, test_specs, token): print(f"Fork PR, but commenter {commenter} has write+ permission. 
Proceeding.") if not ( - user_perms.get("can_rerun_ut", False) + user_perms.get("can_rerun_test", False) or user_perms.get("can_rerun_stage", False) ): - print("Permission denied: neither can_rerun_ut nor can_rerun_stage is true.") + print("Permission denied: neither can_rerun_test nor can_rerun_stage is true.") return False if not test_specs: comment.create_reaction("confused") pr.create_issue_comment( - "āŒ Please specify a test: `/rerun-ut ::`\n\n" + "āŒ Please specify a test: `/rerun-test ::`\n\n" "Examples:\n" - "- `/rerun-ut test/registered/core/test_srt_endpoint.py::TestSRTEndpoint.test_simple_decode`\n" - "- `/rerun-ut registered/core/test_srt_endpoint.py::TestSRTEndpoint`\n" - "- `/rerun-ut test_srt_endpoint.py`\n" - "- `/rerun-ut test_a.py test_b.py test_c.py` (multiple tests)" + "- `/rerun-test test/registered/core/test_srt_endpoint.py::TestSRTEndpoint.test_simple_decode`\n" + "- `/rerun-test registered/core/test_srt_endpoint.py::TestSRTEndpoint`\n" + "- `/rerun-test test_srt_endpoint.py`\n" + "- `/rerun-test test_a.py test_b.py test_c.py` (multiple tests)" ) return False @@ -737,7 +742,7 @@ def main(): # PR authors can always rerun failed CI and rerun individual UTs on their own PRs, # even if they are not listed in CI_PERMISSIONS.json. # Note: /tag-run-ci-label and /rerun-stage still require CI_PERMISSIONS.json. - # Note: /rerun-ut is blocked entirely for fork PRs in handle_rerun_ut() itself. + # Note: /rerun-test is blocked entirely for fork PRs in handle_rerun_test() itself. if pr.user.login == user_login: if user_perms is None: print( @@ -750,7 +755,7 @@ def main(): f"User {user_login} is the PR author and has existing CI permissions." ) user_perms["can_rerun_failed_ci"] = True - user_perms["can_rerun_ut"] = True + user_perms["can_rerun_test"] = True if not user_perms: print(f"User {user_login} does not have any configured permissions. 
Exiting.") @@ -795,9 +800,9 @@ def main(): stage_name = parts[1].strip() if len(parts) > 1 else None handle_rerun_stage(repo, pr, comment, user_perms, stage_name, token) - elif first_line.startswith("/rerun-ut"): + elif first_line.startswith("/rerun-test"): test_specs = first_line.split()[1:] - handle_rerun_ut(repo, pr, comment, user_perms, test_specs or None, token) + handle_rerun_test(repo, pr, comment, user_perms, test_specs or None, token) else: print(f"Unknown or ignored command: {first_line}") diff --git a/scripts/ci_monitor/ci_failures_analysis.py b/scripts/ci_monitor/ci_failures_analysis.py index d5a4f6242940..c64a2b43b7f2 100644 --- a/scripts/ci_monitor/ci_failures_analysis.py +++ b/scripts/ci_monitor/ci_failures_analysis.py @@ -586,6 +586,10 @@ def analyze_runner_health( runner_instance_first_failure: Dict[str, Optional[Dict]] = {} runner_instance_last_failure: Dict[str, Optional[Dict]] = {} runner_instance_recovery: Dict[str, Optional[Dict]] = {} + runner_instance_all_failures_in_streak: Dict[str, List[Dict]] = defaultdict( + list + ) + runner_instance_all_failures: Dict[str, List[Dict]] = defaultdict(list) total_runs_processed = len(sorted_runs) for i, run in enumerate(sorted_runs, 1): @@ -802,6 +806,12 @@ def analyze_runner_health( runner_instance_first_failed_job[runner_instance_key] ) runner_instance_last_failure[runner_instance_key] = failure_info + runner_instance_all_failures_in_streak[runner_instance_key].append( + failure_info + ) + runner_instance_all_failures[runner_instance_key].append( + failure_info + ) if ( runner_instance_current_streak[runner_instance_key] @@ -823,6 +833,7 @@ def analyze_runner_health( runner_instance_current_streak[runner_instance_key] = 0 runner_instance_first_failure[runner_instance_key] = None + runner_instance_all_failures_in_streak[runner_instance_key] = [] runner_instance_last_failure[runner_instance_key] = None time.sleep(0.05) @@ -903,6 +914,9 @@ def analyze_runner_health( "avg_queue_time_seconds": 
avg_queue_time, "p90_queue_time_seconds": p90_queue_time, "queue_time_samples": len(queue_times), + "all_failures": list( + runner_instance_all_failures.get(instance_key, []) + ), } # Build runner streak data @@ -951,6 +965,9 @@ def analyze_runner_health( "last_failure_in_streak": runner_instance_last_failure.get( instance_key ), + "all_failures_in_streak": list( + runner_instance_all_failures_in_streak.get(instance_key, []) + ), "recovery_info": runner_instance_recovery.get(instance_key), } @@ -2058,8 +2075,10 @@ def generate_job_section_md( "total_jobs": stats["total_jobs"], "unique_jobs": len(stats.get("jobs_failed", {})), "avg_queue": stats.get("avg_queue_time_seconds", 0), - "first_failure": streak_data.get("first_failure_in_streak"), - "last_failure": streak_data.get("last_failure_in_streak"), + "all_failures_in_streak": streak_data.get( + "all_failures_in_streak", [] + ), + "all_failures": stats.get("all_failures", []), } ) @@ -2096,10 +2115,10 @@ def generate_job_section_md( ) summary_lines.append("") summary_lines.append( - "| Machine Name | Current Streak | Max | Fail Rate | Avg Queue | Total Jobs | Unique Jobs | First Failure | Last Failure |" + "| Machine Name | Current Streak | Max | Fail Rate | Avg Queue | Total Jobs | Failed Jobs | Unique Jobs | Jobs |" ) summary_lines.append( - "|--------------|----------------|-----|-----------|-----------|------------|-------------|---------------|--------------|" + "|--------------|----------------|-----|-----------|-----------|------------|-------------|-------------|------|" ) for runner_data in runners_with_streak[:15]: @@ -2115,17 +2134,14 @@ def generate_job_section_md( else "N/A" ) - first_failure = runner_data.get("first_failure") - first_str = ( - f"[Run #{first_failure['run_number']}]({first_failure.get('job_url', first_failure['url'])})" - if first_failure - else "N/A" - ) - - last_failure = runner_data.get("last_failure") - last_str = ( - f"[Run 
#{last_failure['run_number']}]({last_failure.get('job_url', last_failure['url'])})" - if last_failure + all_failures = runner_data.get("all_failures_in_streak", []) + failed_jobs_count = len(all_failures) + jobs_str = ( + " ".join( + f"[#{f.get('run_number', '?')}]({f.get('job_url', f['url'])})" + for f in all_failures + ) + if all_failures else "N/A" ) @@ -2133,12 +2149,12 @@ def generate_job_section_md( if runner_data["current_streak"] >= 3: summary_lines.append( f"| `{display_name}` | {runner_data['current_streak']} | {runner_data['max_streak']} | " - f"{runner_data['failure_rate']:.1f}% | {avg_queue_str} | {runner_data['total_jobs']} | {runner_data.get('unique_jobs', 0)} | {first_str} | {last_str} |" + f"{runner_data['failure_rate']:.1f}% | {avg_queue_str} | {runner_data['total_jobs']} | {failed_jobs_count} | {runner_data.get('unique_jobs', 0)} | {jobs_str} |" ) else: summary_lines.append( f"| `{display_name}` | {runner_data['current_streak']} | {runner_data['max_streak']} | " - f"{runner_data['failure_rate']:.1f}% | {avg_queue_str} | {runner_data['total_jobs']} | {runner_data.get('unique_jobs', 0)} | {first_str} | {last_str} |" + f"{runner_data['failure_rate']:.1f}% | {avg_queue_str} | {runner_data['total_jobs']} | {failed_jobs_count} | {runner_data.get('unique_jobs', 0)} | {jobs_str} |" ) summary_lines.append("") @@ -2150,10 +2166,10 @@ def generate_job_section_md( ) summary_lines.append("") summary_lines.append( - "| Machine Name | Fail Rate | Avg Queue | Total Jobs | Unique Jobs |" + "| Machine Name | Fail Rate | Avg Queue | Total Jobs | Failed Jobs | Unique Jobs | Jobs |" ) summary_lines.append( - "|--------------|-----------|-----------|------------|-------------|" + "|--------------|-----------|-----------|------------|-------------|-------------|------|" ) for runner_data in runners_high_fail_rate[:15]: @@ -2169,10 +2185,21 @@ def generate_job_section_md( else "N/A" ) + all_failures = runner_data.get("all_failures", []) + failed_jobs_count = 
len(all_failures) + jobs_str = ( + " ".join( + f"[#{f.get('run_number', '?')}]({f.get('job_url', f['url'])})" + for f in all_failures + ) + if all_failures + else "N/A" + ) + summary_lines.append( f"| `{display_name}` | {runner_data['failure_rate']:.1f}% | " f"{avg_queue_str} | {runner_data['total_jobs']} | " - f"{runner_data.get('unique_jobs', 0)} |" + f"{failed_jobs_count} | {runner_data.get('unique_jobs', 0)} | {jobs_str} |" ) summary_lines.append("") @@ -2512,7 +2539,9 @@ def main(): ) # Choosing nvidia pr test and nightly for runner health analysis - runner_runs = pr_test_nvidia_general_runs + nightly_nvidia_general_runs + # Use scheduled runs (already limited to 12 PR + 6 nightly) to avoid + # pulling months of history from the unfiltered general fetch. + runner_runs = pr_test_nvidia_scheduled_runs + nightly_nvidia_scheduled_runs if not runner_runs and not pr_test_nvidia_scheduled_runs: print("No workflow runs found") diff --git a/scripts/playground/bench_speculative.py b/scripts/playground/bench_speculative.py index 806699f7121c..5373df5169e8 100644 --- a/scripts/playground/bench_speculative.py +++ b/scripts/playground/bench_speculative.py @@ -119,7 +119,7 @@ def send_one_batch(base_url, num_prompts, batch_size, processor, is_multimodal): acc_length = results["accept_length"] or 1.0 avg_output_token = results["total_output_tokens"] / results["completed"] - server_info = requests.get(base_url + "/get_server_info").json() + server_info = requests.get(base_url + "/server_info").json() # We use 20% percentile instead of median on purpose step_time = np.percentile( server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20 diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 743c29104b51..bbacf6dc4021 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -260,13 +260,11 @@ endif() set(SOURCES "csrc/allreduce/custom_all_reduce.cu" "csrc/allreduce/mscclpp_allreduce.cu" - "csrc/attention/cascade.cu" 
"csrc/attention/cutlass_mla_kernel.cu" "csrc/attention/merge_attn_states.cu" "csrc/attention/vertical_slash_index.cu" "csrc/common_extension.cc" "csrc/elementwise/activation.cu" - "csrc/elementwise/cast.cu" "csrc/elementwise/concat_mla.cu" "csrc/elementwise/copy.cu" "csrc/elementwise/fused_add_rms_norm_kernel.cu" diff --git a/sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py b/sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py index 647543089b5e..234d09342774 100644 --- a/sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py +++ b/sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py @@ -29,20 +29,18 @@ sys.path.insert(0, python_dir) # Try to import custom all-reduce if available +from sglang.srt.environ import envs + try: import sglang.srt.distributed.device_communicators.custom_all_reduce_ops as custom_ar_ops from sglang.srt.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce, ) - from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( - is_weak_contiguous, - ) CUSTOM_AR_AVAILABLE = custom_ar_ops.IS_CUSTOM_AR_AVAILABLE except (ImportError, AttributeError): CUSTOM_AR_AVAILABLE = False CustomAllreduce = None - is_weak_contiguous = None # Note: sglang's optimized all-reduce requires full runtime initialization # and won't work in standalone benchmarks, so we skip it @@ -110,6 +108,7 @@ def reduce_scatter_then_all_gather(tensor, rank, world_size, custom_ar=None): def worker(world_size, rank, port, results_queue): + envs.SGLANG_USE_1STAGE_ALLREDUCE.set("1") device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) @@ -240,7 +239,7 @@ def worker(world_size, rank, port, results_queue): results_deterministic_kernel = [] latencies_deterministic_kernel = [] deterministic_kernel_available = False - if custom_ar is not None and hasattr(custom_ar, "deterministic_all_reduce"): + if custom_ar is not None: # Check if input size fits in buffer input_size_bytes = base_input.numel() * 
base_input.element_size() if input_size_bytes > custom_ar.max_size: @@ -259,9 +258,7 @@ def worker(world_size, rank, port, results_queue): # Measure latency torch.cuda.synchronize() start = time.perf_counter() - result_kernel = custom_ar.deterministic_all_reduce( - inp_kernel, registered=False - ) + result_kernel = custom_ar.custom_all_reduce(inp_kernel) torch.cuda.synchronize() end = time.perf_counter() latencies_deterministic_kernel.append(end - start) diff --git a/sgl-kernel/benchmark/bench_fp4_gemm.py b/sgl-kernel/benchmark/bench_fp4_gemm.py index f8f0bd666a21..0f1023af8fd6 100755 --- a/sgl-kernel/benchmark/bench_fp4_gemm.py +++ b/sgl-kernel/benchmark/bench_fp4_gemm.py @@ -1,13 +1,20 @@ import argparse import csv import os +from functools import partial +from typing import List, Tuple import torch import triton from flashinfer import mm_fp4 +from flashinfer.testing import bench_gpu_time from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant -from sglang.srt.utils import get_device_capability, is_sm100_supported +from sglang.srt.utils import ( + get_device_capability, + is_sm100_supported, + is_sm120_supported, +) from sglang.utils import is_in_ci IS_CI = is_in_ci() @@ -15,30 +22,102 @@ FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +DEEPSEEK_R1_MODEL = "deepseek-ai/DeepSeek-R1-0528-FP4" -def get_weight_shapes(args): - models_tps = args.tp_sizes +# Weight shapes are in the format: ([K, N], TP_SPLIT_DIM) +# TP split dim 0 means split K by tp size; dim 1 means split N by tp size. 
+WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "Qwen/Qwen3.5-27B": [ + ([5120, 8192], 1), + ([6144, 5120], 0), + ([5120, 34816], 1), + ([17408, 5120], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} - if models_tps == [4]: - return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]] +DEEPSEEK_R1_WEIGHT_SHAPES = { + 4: [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]], + 8: [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]], +} - if models_tps == [8]: - return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]] - return [ - [1024, 3584], - [7168, 256], - [7168, 2304], - [9216, 3584], - [512, 3584], - [7168, 128], - [7168, 1152], - [4608, 3584], - ] + +def get_weight_shapes(args) -> List[Tuple[int, int, str]]: + shapes: List[Tuple[int, int, str]] = [] + for model in args.models: + if model == DEEPSEEK_R1_MODEL: + for tp_size in args.tp_sizes: + if tp_size in DEEPSEEK_R1_WEIGHT_SHAPES: + selected = DEEPSEEK_R1_WEIGHT_SHAPES[tp_size] + else: + selected = ( + DEEPSEEK_R1_WEIGHT_SHAPES[4] + DEEPSEEK_R1_WEIGHT_SHAPES[8] + ) + for n, 
packed_k in selected: + shapes.append((n, packed_k, model)) + continue + + if model not in WEIGHT_SHAPES: + raise ValueError(f"Unsupported model: {model}") + for tp_size in args.tp_sizes: + for k_n, tp_split_dim in WEIGHT_SHAPES[model]: + k, n = k_n + if tp_split_dim == 0: + k = k // tp_size + else: + n = n // tp_size + packed_k = k // 2 + shapes.append((n, packed_k, model)) + return shapes -# CI environment uses simplified parameters if IS_CI: - batch_sizes = [1, 8] # Simplified for CI + batch_sizes = [1, 8] else: batch_sizes = [ 1, @@ -60,29 +139,54 @@ def get_weight_shapes(args): ] +def _run_mm_fp4(a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, dtype, res_fi, backend): + return mm_fp4(a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, dtype, res_fi, backend=backend) + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], x_vals=batch_sizes, - # x_vals = [64], x_log=False, line_arg="provider", - line_vals=["sglang_cutlass", "cutlass", "cudnn", "trtllm", "auto"], - line_names=[ - "sglang cutlass fp4", - "flashinfer cutlass fp4", - "cudnn fp4", - "trtllm fp4", - "auto fp4 (cudnn/cutlass)", - ], - styles=[ - ("red", "solid"), - ("orange", "solid"), - ("blue", "solid"), - ("green", "solid"), - ("purple", "solid"), - ], - ylabel="latency (ms)", + line_vals=( + ["sglang_cutlass", "cutlass", "cudnn", "trtllm", "auto"] + if is_sm100_supported() + else ["sglang_cutlass", "cutlass", "cudnn", "auto"] + ), + line_names=( + [ + "sglang cutlass fp4", + "flashinfer cutlass fp4", + "cudnn fp4", + "trtllm fp4", + "auto fp4 (cudnn/cutlass)", + ] + if is_sm100_supported() + else [ + "sglang cutlass fp4", + "flashinfer cutlass fp4", + "cudnn fp4", + "auto fp4", + ] + ), + styles=( + [ + ("red", "solid"), + ("orange", "solid"), + ("blue", "solid"), + ("green", "solid"), + ("purple", "solid"), + ] + if is_sm100_supported() + else [ + ("red", "solid"), + ("orange", "solid"), + ("blue", "solid"), + ("purple", "solid"), + ] + ), + ylabel="bandwidth (GB/s)", 
plot_name="fp4_gemm_benchmark", args={}, ) @@ -99,87 +203,93 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): b_global_scale = ( (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) ).to(torch.float32) - alpha = 1.0 / (a_global_scale * b_global_scale) a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale) - # print("a_fp4", a_fp4) b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale) + b_fp4_T = b_fp4.T + b_sf_T = b_scale_interleaved.T res_fi = torch.empty((M, N), dtype=dtype, device="cuda") - quantiles = [0.5, 0.2, 0.8] if provider == "sglang_cutlass": - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: cutlass_scaled_fp4_mm( - a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype - ), - quantiles=quantiles, - ) - if provider == "cutlass": - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: mm_fp4( + times_ms = bench_gpu_time( + fn=cutlass_scaled_fp4_mm, + input_args=( a_fp4, - b_fp4.T, + b_fp4, a_scale_interleaved, - b_scale_interleaved.T, + b_scale_interleaved, alpha, dtype, - res_fi, - backend="cutlass", ), - quantiles=quantiles, + use_cuda_graph=True, ) - if provider == "cudnn": - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: mm_fp4( + elif provider == "cutlass": + times_ms = bench_gpu_time( + fn=partial(_run_mm_fp4, backend="cutlass"), + input_args=( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, - backend="cudnn", ), - quantiles=quantiles, + use_cuda_graph=True, ) - if provider == "trtllm": - a_scale_interleaved = a_scale_interleaved.to(torch.uint8) - b_scale_interleaved = b_scale_interleaved.to(torch.uint8) - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: mm_fp4( + elif provider == "cudnn": + times_ms = bench_gpu_time( + fn=partial(_run_mm_fp4, backend="cudnn"), + input_args=( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - 
b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, - backend="trtllm", ), - quantiles=quantiles, + use_cuda_graph=True, ) - if provider == "auto": - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: mm_fp4( + elif provider == "trtllm": + a_sf_u8 = a_scale_interleaved.to(torch.uint8) + b_sf_u8_T = b_sf_T.to(torch.uint8) + times_ms = bench_gpu_time( + fn=partial(_run_mm_fp4, backend="trtllm"), + input_args=(a_fp4, b_fp4_T, a_sf_u8, b_sf_u8_T, alpha, dtype, res_fi), + use_cuda_graph=True, + ) + elif provider == "auto": + times_ms = bench_gpu_time( + fn=partial(_run_mm_fp4, backend="auto"), + input_args=( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, ), - quantiles=quantiles, + use_cuda_graph=True, ) + + ms = torch.tensor(times_ms).median().item() + + # A: MƗpacked_k bytes (fp4 packed), B: NƗpacked_k bytes, C: MƗNƗelement_size bytes + element_size = torch.finfo(dtype).bits // 8 + total_bytes = M * packed_k + N * packed_k + M * N * element_size + bandwidth_gbs = total_bytes / (ms * 1e-3) / 1e9 + if correctness: res_cutlass = cutlass_scaled_fp4_mm( a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype ) mm_fp4( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, @@ -190,9 +300,9 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): ), "cudnn fp4 doesn't match cutlass fp4" mm_fp4( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, @@ -205,13 +315,20 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): if csv_file: with open(csv_file, "a", newline="") as f: writer = csv.writer(f) - writer.writerow([provider, M, N, K, ms]) + writer.writerow([provider, M, N, K, ms, bandwidth_gbs]) - return ms, min_ms, max_ms + return bandwidth_gbs if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument( + "--models", 
+ nargs="+", + type=str, + default=[DEEPSEEK_R1_MODEL], + help="List of models to benchmark. Supported: Llama 8B/70B, Qwen, Mistral, DeepSeek.", + ) parser.add_argument( "--tp-sizes", nargs="+", @@ -223,7 +340,7 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): "--dtype", type=torch.dtype, default=torch.bfloat16, - help="Data type", + help="Output data type", ) parser.add_argument( "--correctness", @@ -238,34 +355,29 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): ) args = parser.parse_args() - # Simplify for CI environment if IS_CI: - args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size + args.tp_sizes = [args.tp_sizes[0]] if args.csv: with open(args.csv, "w", newline="") as f: writer = csv.writer(f) - writer.writerow(["provider", "m", "n", "k", "time_ms"]) + writer.writerow(["provider", "m", "n", "k", "time_ms", "bandwidth_gbs"]) - # FP4 operations require Blackwell SM100 support major, minor = get_device_capability() - if not is_sm100_supported(): + if not (is_sm100_supported() or is_sm120_supported()): print("Skipping FP4 GEMM benchmark") if major is not None: - print( - f"FP4 operations require SM100 (Blackwell), but found sm{major}{minor}" - ) + print(f"FP4 operations require sm100+, but found sm{major}{minor}") else: print("Could not determine device capability") else: NKs = get_weight_shapes(args) - # Limit iterations in CI if IS_CI: - NKs = NKs[:2] # Only test first 2 shapes in CI + NKs = NKs[:2] - for N, K in NKs: - print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ") + for N, K, model_name in NKs: + print(f"{model_name} N={N} packed_k={K}: ") benchmark.run( print_data=True, N=N, diff --git a/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py b/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py deleted file mode 100644 index eeb5842edec2..000000000000 --- a/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py +++ /dev/null @@ -1,192 +0,0 @@ -import argparse -import copy -import itertools -import os - -import torch 
-import triton - -from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant -from sglang.srt.utils import get_device_capability - -# CI environment detection -IS_CI = ( - os.getenv("CI", "false").lower() == "true" - or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" -) - -FLOAT4_E2M1_MAX = 6.0 -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - -# Weight Shapes are in the format -# ([K, N], TP_SPLIT_DIM) -# Example: -# A shape of ([14336, 4096], 0) indicates the following GEMM shape, -# - TP1 : K = 14336, N = 4096 -# - TP2 : K = 7168, N = 4096 -# A shape of ([4096, 6144], 1) indicates the following GEMM shape, -# - TP1 : K = 4096, N = 6144 -# - TP4 : K = 4096, N = 1536 - -# TP1 shapes -WEIGHT_SHAPES = { - "meta-llama/Llama-3.1-8B-Instruct": [ - ([4096, 6144], 1), - ([4096, 4096], 0), - ([4096, 28672], 1), - ([14336, 4096], 0), - ], - "meta-llama/Llama-3.3-70B-Instruct": [ - ([8192, 10240], 1), - ([8192, 8192], 0), - ([8192, 57344], 1), - ([28672, 8192], 0), - ], - "mistralai/Mistral-Large-Instruct-2407": [ - ([12288, 14336], 1), - ([12288, 12288], 0), - ([12288, 57344], 1), - ([28672, 12288], 0), - ], - "Qwen/Qwen2.5-7B-Instruct": [ - ([3584, 4608], 1), - ([3584, 3584], 0), - ([3584, 37888], 1), - ([18944, 3584], 0), - ], - "Qwen/Qwen2.5-32B-Instruct": [ - ([5120, 7168], 1), - ([5120, 5120], 0), - ([5120, 55296], 1), - ([27648, 5120], 0), - ], - "Qwen/Qwen2.5-72B-Instruct": [ - ([8192, 10240], 1), - ([8192, 8192], 0), - ([8192, 59136], 1), - ([29568, 8192], 0), - ], - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ - ([2048, 3072], 1), - ([2048, 4096], 1), - ([2048, 2048], 0), - ([2048, 576], 0), - ([2048, 21888], 1), - ([10944, 2048], 0), - ([2048, 2816], 1), - ([1408, 2048], 0), - ], -} - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size"], - x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], - x_log=False, - line_arg="provider", - line_vals=[ - "sglang-fp4-fp16", - "sglang-fp4-bf16", - ], - 
line_names=[ - "sglang-fp4-fp16", - "sglang-fp4-bf16", - ], - styles=[("green", "-"), ("blue", "-")], - ylabel="TFLOPS", - plot_name="fp4 block scaled matmul", - args={}, - ) -) -def benchmark(batch_size, provider, N, K): - # M, N, K = batch_size, 4096, 8192 - run_step = 100 - dtype = torch.float16 if "fp16" in provider else torch.bfloat16 - M = batch_size - a = torch.randn((M, K), dtype=dtype, device="cuda") - b = torch.randn((N, K), dtype=dtype, device="cuda") - a_global_scale = ( - (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) - ).to(torch.float32) - b_global_scale = ( - (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b.flatten(), dim=-1) - ).to(torch.float32) - alpha = 1.0 / (a_global_scale * b_global_scale) - a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale) - b_fp4, b_scale_interleaved = scaled_fp4_quant(b, b_global_scale) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Bridging the gap between CPU and GPU - for _ in range(25): - c = a @ b.t() - # Warmup - for _ in range(5): - cutlass_scaled_fp4_mm( - a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype - ) - start_event.record() - for _ in range(run_step): - cutlass_scaled_fp4_mm( - a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype - ) - end_event.record() - end_event.synchronize() - torch.cuda.synchronize() - ms = start_event.elapsed_time(end_event) / run_step - - tflops = lambda ms: (2 * M * N * K) * 1e-9 / ms - return tflops(ms) - - -def prepare_shapes(args): - KN_model_names = [] - models_tps = list(itertools.product(args.models, args.tp_sizes)) - for model, tp_size in models_tps: - assert model in WEIGHT_SHAPES - for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): - KN[tp_split_dim] = KN[tp_split_dim] // tp_size - KN.append(model) - KN_model_names.append(KN) - return KN_model_names - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - 
parser.add_argument( - "--models", - nargs="+", - type=str, - default=["meta-llama/Llama-3.1-8B-Instruct"], - help="List of models to benchmark", - ) - parser.add_argument( - "--tp-sizes", - nargs="+", - type=int, - default=[1], - help="List of tensor parallel sizes", - ) - args = parser.parse_args() - - # Check architecture compatibility - FP4 operations require sm100a/sm103a - major, minor = get_device_capability() - if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a) - print("Skipping NVIDIA FP4 scaled GEMM benchmark") - if major is not None: - print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}") - else: - print("Could not determine device capability") - else: - KN_model_names = prepare_shapes(args) - - # Limit iterations in CI - if IS_CI: - KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI - - for K, N, model_name in KN_model_names: - print(f"{model_name} N={N} K={K}: ") - benchmark.run(print_data=True, N=N, K=K) - print("Benchmark finished!") diff --git a/sgl-kernel/cmake/flashmla.cmake b/sgl-kernel/cmake/flashmla.cmake index b1546b151020..d52aadf3f082 100644 --- a/sgl-kernel/cmake/flashmla.cmake +++ b/sgl-kernel/cmake/flashmla.cmake @@ -4,7 +4,7 @@ include(FetchContent) FetchContent_Declare( repo-flashmla GIT_REPOSITORY https://github.com/sgl-project/FlashMLA - GIT_TAG be055fb7df0090fde45f08e9cb5b8b4c0272da73 + GIT_TAG 9804b12079e4c873514d3457aa588d3ccf40da28 GIT_SHALLOW OFF ) FetchContent_Populate(repo-flashmla) @@ -34,8 +34,9 @@ if(${CUDA_VERSION} VERSION_GREATER_EQUAL "13.0") # Patch FlashMLA sources for SM103a support. # These patches are only needed (and only valid) with CUDA 13+. - # Patch flashmla_utils.h: widen IS_SM100 to cover the full SM100 family - set(FLASHMLA_UTILS_FILE "${repo-flashmla_SOURCE_DIR}/csrc/flashmla_utils.h") + # Patch utils.h: widen IS_SM100 to cover the full SM100 family. + # Newer FlashMLA versions use csrc/utils.h. 
+ set(FLASHMLA_UTILS_FILE "${repo-flashmla_SOURCE_DIR}/csrc/utils.h") file(READ "${FLASHMLA_UTILS_FILE}" FLASHMLA_UTILS_CONTENT) string(REPLACE "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000) @@ -44,7 +45,7 @@ if(${CUDA_VERSION} VERSION_GREATER_EQUAL "13.0") #define IS_SM100 1" FLASHMLA_UTILS_CONTENT "${FLASHMLA_UTILS_CONTENT}") file(WRITE "${FLASHMLA_UTILS_FILE}" "${FLASHMLA_UTILS_CONTENT}") - message(STATUS "Patched flashmla_utils.h for SM103a support") + message(STATUS "Patched utils.h for SM103a support") # Patch cutlass/arch/config.h: add SM103 architecture defines. # The new block is inserted right before the existing "// SM101 and SM101a" @@ -87,16 +88,46 @@ endif() set(FlashMLA_SOURCES "csrc/flashmla_extension.cc" + + # Compatibility shim for sgl-kernel torch.ops API. ${repo-flashmla_SOURCE_DIR}/csrc/python_api.cpp - ${repo-flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu - ${repo-flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu - ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu - ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu + + # Decode metadata/combine kernels. + ${repo-flashmla_SOURCE_DIR}/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu + ${repo-flashmla_SOURCE_DIR}/csrc/smxx/decode/combine/combine.cu + + # sm90 dense decode. + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/instantiations/fp16.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/instantiations/bf16.cu + + # sm90 sparse decode. + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu + + # sm90 sparse prefill. 
${repo-flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu - ${repo-flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu + + # sm100 dense prefill/bwd. ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu - ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu + + # sm100 sparse prefill. + ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu + + # sm100 sparse decode. 
+ ${repo-flashmla_SOURCE_DIR}/csrc/sm100/decode/head64/instantiations/v32.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm100/decode/head64/instantiations/model1.cu + ${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu ${repo-flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp ${repo-flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu @@ -104,9 +135,14 @@ set(FlashMLA_SOURCES ) Python_add_library(flashmla_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FlashMLA_SOURCES}) -target_compile_options(flashmla_ops PRIVATE $<$:${FLASHMLA_CUDA_FLAGS}>) +target_compile_options(flashmla_ops PRIVATE + $<$:-std=c++20> + $<$:-std=c++20> + $<$:${FLASHMLA_CUDA_FLAGS}> +) target_include_directories(flashmla_ops PRIVATE ${repo-flashmla_SOURCE_DIR}/csrc + ${repo-flashmla_SOURCE_DIR}/csrc/kerutils/include ${repo-flashmla_SOURCE_DIR}/csrc/sm90 ${repo-flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/ ${repo-flashmla_SOURCE_DIR}/csrc/cutlass/include diff --git a/sgl-kernel/csrc/attention/cascade.cu b/sgl-kernel/csrc/attention/cascade.cu deleted file mode 100644 index 9d49360ddee4..000000000000 --- a/sgl-kernel/csrc/attention/cascade.cu +++ /dev/null @@ -1,55 +0,0 @@ -// Adapted from -// https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/csrc/cascade.cu - -#include -#include - -#include - -#include "pytorch_extension_utils.h" - -using namespace flashinfer; - -void merge_state( - at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) { - CHECK_INPUT(v_a); - CHECK_INPUT(s_a); - CHECK_INPUT(v_b); - CHECK_INPUT(s_b); - auto device = v_a.device(); - CHECK_EQ(s_a.device(), device); - CHECK_EQ(v_b.device(), device); - CHECK_EQ(s_b.device(), device); - CHECK_DIM(3, v_a); - CHECK_DIM(2, s_a); - CHECK_DIM(3, v_b); - CHECK_DIM(2, s_b); - CHECK_SHAPE(v_a, v_b); - CHECK_SHAPE(s_a, s_b); - 
CHECK_EQ(v_a.size(0), s_a.size(0)); - CHECK_EQ(v_a.size(1), s_b.size(1)); - unsigned int seq_len = v_a.size(0); - unsigned int num_heads = v_a.size(1); - unsigned int head_dim = v_a.size(2); - - const c10::cuda::OptionalCUDAGuard device_guard(v_a.device()); - auto stream = at::cuda::getCurrentCUDAStream(); - - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] { - cudaError_t status = MergeState( - static_cast(v_a.data_ptr()), - static_cast(s_a.data_ptr()), - static_cast(v_b.data_ptr()), - static_cast(s_b.data_ptr()), - static_cast(v_merged.data_ptr()), - static_cast(s_merged.data_ptr()), - seq_len, - num_heads, - head_dim, - stream); - TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status)); - return true; - }); - - TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type"); -} diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index cdce0064b2f0..b7c01a08327e 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -50,8 +50,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { /* * From csrc/attention */ - m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()"); - m.impl("merge_state", torch::kCUDA, &merge_state); m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()"); m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2); m.def( @@ -90,11 +88,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { " Tensor cos_sin_cache, bool is_neox) -> ()"); m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); - m.def( - "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, " - "int mult, int offset) -> ()"); - m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8); - m.def("copy_to_gpu_no_ce(Tensor input, Tensor! 
output) -> ()"); m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce); m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()"); @@ -364,9 +357,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()"); m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); - m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()"); - m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits); - /* * From Sparse Flash Attention */ diff --git a/sgl-kernel/csrc/common_extension_musa.cc b/sgl-kernel/csrc/common_extension_musa.cc index 33bc639a1baa..00a83f5b53a9 100644 --- a/sgl-kernel/csrc/common_extension_musa.cc +++ b/sgl-kernel/csrc/common_extension_musa.cc @@ -43,9 +43,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, " "float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()"); m.impl("top_k_top_p_sampling_from_probs", torch::kMUSA, &top_k_top_p_sampling_from_probs); - - m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()"); - m.impl("top_k_mask_logits", torch::kMUSA, &top_k_mask_logits); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h index 31c6295a011a..48f25e21c286 100644 --- a/sgl-kernel/csrc/cpu/common.h +++ b/sgl-kernel/csrc/cpu/common.h @@ -45,7 +45,7 @@ namespace { } \ }() -// dispatch: bfloat16, float16, int8_t, fp8_e4m3 +// dispatch: bfloat16, float16, int8_t, fp8_e4m3, uint8_t(mxfp4/int4) #define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) 
\ [&] { \ switch (TYPE) { \ @@ -65,6 +65,10 @@ namespace { using packed_t = at::Float8_e4m3fn; \ return __VA_ARGS__(); \ } \ + case at::ScalarType::Byte: { \ + using packed_t = uint8_t; \ + return __VA_ARGS__(); \ + } \ default: \ TORCH_CHECK(false, "Unsupported floating data type.\n"); \ } \ diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp index e2fdc8951f23..13d329be84f9 100644 --- a/sgl-kernel/csrc/cpu/gemm.cpp +++ b/sgl-kernel/csrc/cpu/gemm.cpp @@ -65,6 +65,43 @@ inline void pack_vnni(int8_t* __restrict__ packed, const int8_t* __restr s8s8_compensation(packed, K); } +// uint8_t: mxfp4 or int4 +// pack to vnni2 format as they are computed with bfloat16 +// +// from [N, K'/2, 2] to [K'/2, N, 2], view 2x int4 as unit8: +// from [N, K ] to [K, N ] where K = K'/2 +// +template <> +inline void pack_vnni(uint8_t* __restrict__ packed, const uint8_t* __restrict__ weight, int N, int K) { + constexpr int BLOCK_N = block_size_n(); + + uint8_t unpacked[2 * BLOCK_N]; + + // 32-way pack (align with BLOCK_N), faster for avx512 unpacking + // + // for a range of (64): + // {0, 1, 2, ..., 63} + // + // original format: + // { 1|0, 3|2, ..., 63|62} + // + // packed format: + // {32|0, 31|1, ..., 63|31} + // + for (int k = 0; k < K; ++k) { + // unpack first + for (int n = 0; n < N; ++n) { + uint8_t value = weight[n * K + k]; + unpacked[n * 2 + 0] = value & 0xF; // lower 4 bits + unpacked[n * 2 + 1] = value >> 4; // higher 4 bits + } + // re-pack to 32-way + for (int n = 0; n < N; ++n) { + packed[k * N + n] = (unpacked[n + BLOCK_N] << 4) | unpacked[n]; + } + } +} + template inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { using bVec = at::vec::Vectorized; @@ -600,9 +637,12 @@ at::Tensor convert_weight_packed(at::Tensor& weight) { const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); const int64_t IC = ndim == 3 ? 
weight.size(2) : weight.size(1); + // mxfp4 or int4 are packed with uint8 + const int64_t actual_IC = st == at::kByte ? IC * 2 : IC; + // we handle 2 TILE_N at a time. TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); - TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); + TORCH_CHECK(actual_IC % TILE_K == 0, "invalid weight input features ", actual_IC); constexpr int64_t BLOCK_N = block_size_n(); const int64_t NB = div_up(OC, BLOCK_N); @@ -611,13 +651,14 @@ at::Tensor convert_weight_packed(at::Tensor& weight) { auto packed_weight = at::empty({}, weight.options()); const int64_t stride = OC * IC; + // Note: for `kByte` (uint8), it represents either `mxfp4` or `int4`. TORCH_CHECK( - st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn, - "expect weight to be bfloat16, float16, int8 or fp8_e4m3."); + st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn || st == at::kByte, + "expect weight to be bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4)."); CPU_DISPATCH_PACKED_TYPES(st, [&] { // adjust most inner dimension size - const int packed_row_size = get_row_size(IC); + const int packed_row_size = get_row_size(actual_IC); auto sizes = weight.sizes().vec(); sizes[ndim - 1] = packed_row_size; packed_weight.resize_(sizes); @@ -646,6 +687,41 @@ at::Tensor convert_weight_packed(at::Tensor& weight) { return packed_weight; } +at::Tensor convert_scale_packed(at::Tensor& scale) { + CHECK_INPUT(scale); + + const int64_t ndim = scale.ndimension(); + TORCH_CHECK(ndim == 2 || ndim == 3, "expect scale to be 2d or 3d, got ", ndim, "d tensor."); + const auto st = scale.scalar_type(); + const int64_t E = ndim == 3 ? scale.size(0) : 1; + const int64_t N = ndim == 3 ? scale.size(1) : scale.size(0); + // number of groups, e.g. K/32 + const int64_t G = ndim == 3 ? 
scale.size(2) : scale.size(1); + + constexpr int64_t BLOCK_N = block_size_n(); + TORCH_CHECK(N % BLOCK_N == 0, "invalid weight out features ", N); + const int64_t NB = N / BLOCK_N; + + auto packed_scale = at::empty_like(scale); + TORCH_CHECK(st == at::kByte, "expect scale to be uint8."); + + const uint8_t* s_data = scale.data_ptr(); + uint8_t* packed_data = packed_scale.data_ptr(); + + // parallel on src {E, NB, BLOCK_N, G}, dst {E, NB, G, BLOCK_N} + at::parallel_for(0, E * NB * BLOCK_N * G, 0, [&](int64_t begin, int64_t end) { + int64_t e{0}, nb{0}, n{0}, g{0}; + data_index_init(begin, e, E, nb, NB, n, BLOCK_N, g, G); + + for (int64_t i = begin; i < end; ++i) { + packed_data[e * N * G + nb * G * BLOCK_N + g * BLOCK_N + n] = s_data[i]; + // move to the next index + data_index_step(e, E, nb, NB, n, BLOCK_N, g, G); + } + }); + return packed_scale; +} + // mat1 : [M, K] // mat2 : [N, K] ([K, N] if use_fma_gemm) // bias : [N] diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h index e11b224fe193..fc2b199bfea6 100644 --- a/sgl-kernel/csrc/cpu/gemm.h +++ b/sgl-kernel/csrc/cpu/gemm.h @@ -33,6 +33,11 @@ inline bool can_use_brgemm(int M) { return M > 4; } +template <> +inline bool can_use_brgemm(int M) { + return M > 4; +} + template <> inline bool can_use_brgemm(int M) { return M > 4; @@ -52,6 +57,12 @@ inline int64_t get_row_size(int64_t K) { return K + sizeof(int32_t); } +// uint8: mxfp4 or int4 +template <> +inline int64_t get_row_size(int64_t K) { + return K >> 1; +} + inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { return use_int8_w8a8 ? 
K + sizeof(int32_t) : K; } @@ -287,3 +298,22 @@ void tinygemm_kernel( int64_t ldc_s, bool store_out, bool use_brgemm); + +// mxfp4 +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const uint8_t* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const uint8_t* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K, + bool do_unpack = true); diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp index 15bd44434e0e..245fd3a075d7 100644 --- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp +++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp @@ -65,15 +65,15 @@ inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ inline void unpack_B( at::BFloat16* __restrict__ Btmp, const at::Float8_e4m3fn* __restrict__ packed_B, - int N, - int K, - int ldb, - int ldb_tmp, + int64_t N, + int64_t K, + int64_t ldb, + int64_t ldb_tmp, float scale) { #if defined(CPU_CAPABILITY_AVX512) // [K/2, N, 2] - const int K2 = K >> 1; - const int ldb2 = ldb; // ldb * 2 >> 1; + const int64_t K2 = K >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; const uint16_t* b_ptr = reinterpret_cast(packed_B); const __m512 vexp = _mm512_castsi512_ps(_mm512_set1_epi32(kFP8_BIAS)); const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(scale), vexp); @@ -85,7 +85,7 @@ inline void unpack_B( constexpr int PREFETCH_SIZE_K = 64; #pragma GCC unroll 4 - for (int k = 0; k < K2; ++k) { + for (int64_t k = 0; k < K2; ++k) { __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); if constexpr (PREFETCH_SIZE_K > 0) { _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); @@ -154,18 +154,67 @@ inline void unpack_B( #endif } -template +// mxfp4 +inline void unpack_B( + at::BFloat16* __restrict__ Btmp, + const uint8_t* __restrict__ packed_B, + int64_t N, + int64_t K, + int64_t ldb, + int64_t ldb_tmp, + const uint8_t* __restrict__ scale) { +#if 
defined(CPU_CAPABILITY_AVX512) + // [K/2, N, 2] + const int64_t K2 = K >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const uint8_t* b_ptr = reinterpret_cast(packed_B); // 2 * 4 bit = 8 bit + + constexpr int BLOCK_N = block_size_n(); + static_assert(BLOCK_N == 32); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + + // exponent bias 127 + const __m512i off = _mm512_set1_epi16(0x7F); + + // load 32 bytes only once for each block + __m256i s8 = _mm256_loadu_si256(reinterpret_cast(scale)); + __m512i s16 = _mm512_slli_epi16(_mm512_sub_epi16(_mm512_cvtepu8_epi16(s8), off), 0x7); + + // holds Nx2(64) scales, interleaved as 2 belongs to K dimension + // e.g. vs0: { s0, s0, s1, s1, ..., s15, s15} + // vs1: {s16, s16, s17, s17, ..., s31, s31} + auto [vscale0, vscale1] = transpose_2x32_16bit(s16, s16); + +#pragma GCC unroll 4 + for (int64_t k = 0; k < K2; ++k) { + __m256i b4 = _mm256_loadu_si256(reinterpret_cast(b_ptr + k * ldb2)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); + } + auto [vb0, vb1] = CVT_MXFP4_TO_BF16(b4, vscale0, vscale1); + + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)vb0); + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)vb1); + } +#else + TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); +#endif +} + +template struct tinygemm_kernel_nn { static inline void apply( const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C, const float* __restrict__ bias, - const float* __restrict__ scale, - int K, - int lda, - int ldb, - int ldc, + const param_t* __restrict__ scale, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, int64_t block_size_K) { TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); } @@ -187,22 +236,22 @@ struct tinygemm_kernel_nn2 { }; #if defined(CPU_CAPABILITY_AVX512) template -struct tinygemm_kernel_nn { +struct tinygemm_kernel_nn { static inline void apply( const 
at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C, const float* __restrict__ bias, const float* __restrict__ scale, - int K, - int lda, - int ldb, - int ldc, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, int64_t block_size_K) { constexpr int ROWS = BLOCK_M; constexpr int COLS = BLOCK_N / 16; - const int KB = div_up(K, BLOCK_K); + const int64_t KB = div_up(K, (int64_t)BLOCK_K); // prefetch distance constexpr int PREFETCH_SIZE_K = 64; @@ -228,8 +277,8 @@ struct tinygemm_kernel_nn{}(loadc); - const int lda2 = lda >> 1; - const int ldb2 = ldb; // ldb * 2 >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; const float* a_ptr = reinterpret_cast(A); const uint16_t* b_ptr = reinterpret_cast(B); @@ -256,10 +305,10 @@ struct tinygemm_kernel_nn> 1; - for (int kb = 0; kb < KB; ++kb) { - int kb_start = kb * BLOCK_K2; - int kb_end = std::min(K >> 1, kb_start + BLOCK_K2); + constexpr int64_t BLOCK_K2 = BLOCK_K >> 1; + for (int64_t kb = 0; kb < KB; ++kb) { + int64_t kb_start = kb * BLOCK_K2; + int64_t kb_end = std::min(K >> 1, kb_start + BLOCK_K2); // 1. 
load scale vector vscale = _mm512_set1_ps(scale[kb]); vscale = _mm512_mul_ps(vscale, vexp); @@ -359,10 +408,110 @@ struct tinygemm_kernel_nn2 { Unroll{}(storec); } }; + +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, + const uint8_t* __restrict__ B, + at::BFloat16* __restrict__ C, + const float* __restrict__ bias, + const uint8_t* __restrict__ scale, + int K, + int lda, + int ldb, + int ldc, + int64_t block_size_K) { + // mxfp4 supports only group size of 32 + // expect weight packed in 32-way, vnni2 format Nx2(64) + assert(block_size_K == 32); + assert(BLOCK_N == 32); + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + constexpr int PREFETCH_SIZE_KB = 1; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + // holds Nx2(64) scales, interleaved as 2 belongs to K dimension + // e.g. vs0: { s0, s0, s1, s1, ..., s15, s15} + // vs1: {s16, s16, s17, s17, ..., s31, s31} + __m512i vscale[COLS]; + + // exponent bias 127 + const __m512i off = _mm512_set1_epi16(0x7F); + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const uint8_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } + } + if constexpr (row == 0) { + // load 32 * 2 (64) int4 at a time + if constexpr (col % 2 == 0) { + __m256i b4 = _mm256_loadu_si256(reinterpret_cast(b_ptr + k 
* ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + std::tie(vb[col + 0], vb[col + 1]) = CVT_MXFP4_TO_BF16(b4, vscale[col + 0], vscale[col + 1]); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + + for (int64_t k = 0; k < K2; ++k) { + // update scales every 16x2 K + if ((k & 15) == 0) { + __m256i s8 = _mm256_loadu_si256(reinterpret_cast(scale + (k >> 4) * 32)); + __m512i s16 = _mm512_slli_epi16(_mm512_sub_epi16(_mm512_cvtepu8_epi16(s8), off), 0x7); + std::tie(vscale[0], vscale[1]) = transpose_2x32_16bit(s16, s16); + } + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2,4 use 512bit store + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + }; + Unroll{}(storec); + } +}; #endif #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ - tinygemm_kernel_nn::apply( \ + tinygemm_kernel_nn::apply( \ A + mb_start * lda, \ B + nb_start * 2, \ C + mb_start * ldc + nb_start, \ @@ -378,7 +527,7 @@ struct tinygemm_kernel_nn2 { tinygemm_kernel_nn2::apply( \ A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, scale, K, lda, ldb, ldc); -template +template struct brgemm { static inline void apply( const scalar_t* __restrict__ A, @@ -387,7 +536,7 @@ struct brgemm { scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, const float* __restrict__ bias, - const float* __restrict__ scale, + const param_t* __restrict__ scale, int M, int N, int K, @@ -402,7 +551,7 @@ template struct brgemm2 {}; template -struct brgemm { +struct brgemm { static inline void apply( const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, @@ -481,14 +630,56 @@ struct brgemm2 { } }; -template +template +struct brgemm { + static inline 
void apply( + const at::BFloat16* __restrict__ A, + const uint8_t* __restrict__ B, + at::BFloat16* __restrict__ C, + at::BFloat16* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const uint8_t* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + bool do_unpack = true) { + constexpr int BLOCK_N = block_size_n(); + + // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] + const int ldb_tmp = BLOCK_N; + + if (do_unpack) { + // group size 32 for mxfp4 + for (int k = 0; k < K; k += 32) { + unpack_B(Btmp + k * ldb_tmp, B + k * (ldb >> 1), N, 32, ldb, ldb_tmp, scale + (k >> 5) * BLOCK_N); + } + } + + at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); + + // copy from Ctmp to C + for (int m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template void tinygemm_kernel( const scalar_t* __restrict__ A, - const at::Float8_e4m3fn* __restrict__ B, + const packed_t* __restrict__ B, scalar_t* __restrict__ C, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, - const float* __restrict__ scale, + const param_t* __restrict__ scale, const float* __restrict__ bias, int64_t M, int64_t N, @@ -500,7 +691,7 @@ void tinygemm_kernel( int64_t block_size_K, bool do_unpack = true) { if (brg) { - brgemm::apply( + brgemm::apply( A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack); return; } @@ -536,6 +727,7 @@ void tinygemm_kernel( } } } + template void tinygemm_kernel2( const scalar_t* __restrict__ A, @@ -633,12 +825,19 @@ void tinygemm_kernel2( } } } -template -void fp8_scaled_mm_kernel_impl( + +// NB: fp8/fp4 scaled mm kernel implementation +// +// scalar_t packed_t param_t +// FP8 BF16 FP8 FP32 +// MXFP4 BF16 U8 U8 +// +template +void fp_scaled_mm_kernel_impl( scalar_t* __restrict__ out, const scalar_t* __restrict__ mat1, - const 
at::Float8_e4m3fn* __restrict__ mat2, - const float* __restrict__ scales2, + const packed_t* __restrict__ mat2, + const param_t* __restrict__ scales2, const float* __restrict__ bias, scalar_t* __restrict__ buffer, int64_t M, @@ -648,16 +847,17 @@ void fp8_scaled_mm_kernel_impl( int64_t out_strideM, int64_t block_size_N, int64_t block_size_K, - int64_t buffer_size_per_thread) { + int64_t buffer_size_per_thread, + const func_t& scale_offset_per_block) { constexpr int64_t BLOCK_M = block_size_m(); constexpr int64_t BLOCK_N = block_size_n(); const int64_t MB = div_up(M, BLOCK_M); const int64_t NB = div_up(N, BLOCK_N); - const int64_t scale_size_K = div_up(K, block_size_K); - const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + const bool use_brgemm = can_use_brgemm(M); - const bool use_brgemm = can_use_brgemm(M); + // use K/2 for mxfp4 and K for fp8 + const int64_t packed_K = get_row_size(K); // parallel on [MB, NB] AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { @@ -666,8 +866,8 @@ void fp8_scaled_mm_kernel_impl( scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K)); - loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { - const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { + const param_t* scale_ptr = scales2 + scale_offset_per_block(nb); int64_t mb_start = mb * BLOCK_M; int64_t mb_size = std::min(M - mb_start, BLOCK_M); @@ -677,9 +877,9 @@ void fp8_scaled_mm_kernel_impl( // only do unpacking for the first row bool do_unpack = (mb == mb0); - tinygemm_kernel( + tinygemm_kernel( /* A */ mat1 + mb_start * mat1_strideM, - /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K + /* B */ mat2 + nb_start * packed_K, // nb * BLOCK_N * K /* C */ out + mb_start * out_strideM + nb_start, /* Btmp */ Btmp + nb_offset * 
BLOCK_N * K, /* Ctmp */ Ctmp, @@ -723,9 +923,10 @@ void tinygemm_kernel( bool brg, int64_t block_size_K, bool do_unpack) { - tinygemm_kernel( + tinygemm_kernel( A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack); } + template void tinygemm_kernel( const scalar_t* __restrict__ A, @@ -743,24 +944,51 @@ void tinygemm_kernel( bool brg) { tinygemm_kernel2(A, B, C, Btmp, Ctmp, scale, M, N, K, lda, ldb, ldc, brg); } -#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ - template void tinygemm_kernel( \ - const TYPE* __restrict__ A, \ - const at::Float8_e4m3fn* __restrict__ B, \ - TYPE* __restrict__ C, \ - TYPE* __restrict__ Btmp, \ - float* __restrict__ Ctmp, \ - const float* __restrict__ scale, \ - int64_t M, \ - int64_t N, \ - int64_t K, \ - int64_t lda, \ - int64_t ldb, \ - int64_t ldc, \ - bool brg, \ - int64_t block_size_K, \ + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const uint8_t* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const uint8_t* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K, + bool do_unpack) { + tinygemm_kernel( + A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE_A, TYPE_B, TYPE_S) \ + template void tinygemm_kernel( \ + const TYPE_A* __restrict__ A, \ + const TYPE_B* __restrict__ B, \ + TYPE_A* __restrict__ C, \ + TYPE_A* __restrict__ Btmp, \ + float* __restrict__ Ctmp, \ + const TYPE_S* __restrict__ scale, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg, \ + int64_t block_size_K, \ bool do_unpack) +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16, at::Float8_e4m3fn, float); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half, at::Float8_e4m3fn, float); +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16, uint8_t, 
uint8_t); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half, uint8_t, uint8_t); + #define INSTANTIATE_TINYGEMM_TEMPLATE2(TYPE) \ template void tinygemm_kernel( \ const TYPE* __restrict__ A, \ @@ -777,10 +1005,28 @@ void tinygemm_kernel( int64_t ldc, \ bool brg) -INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); -INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); INSTANTIATE_TINYGEMM_TEMPLATE2(at::BFloat16); +inline const float* get_bias_data(const std::optional& bias, int64_t N) { + if (bias.has_value()) { + const auto& bias_ref = bias.value(); + CHECK_EQ(bias_ref.size(0), N); + return bias_ref.data_ptr(); + } + return nullptr; +} + +// FP8 and MXFP4 WoQ uses the same pattern: +// Btmp : [T, BLOCK_N * K] +// Ctmp : [T, BLOCK_M * BLOCK_N] +inline at::Tensor alloc_thread_buffer(const at::TensorOptions& options, int64_t K) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + int num_threads = at::get_num_threads(); + int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2; + return at::empty({num_threads, size_per_thread}, options); +} + at::Tensor fp8_scaled_mm_cpu( at::Tensor& mat1, at::Tensor& mat2, @@ -807,11 +1053,9 @@ at::Tensor fp8_scaled_mm_cpu( CHECK_DIM(2, mat2); TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2."); - int64_t block_size_N = block_size[0]; int64_t block_size_K = block_size[1]; - constexpr int64_t BLOCK_M = block_size_m(); constexpr int64_t BLOCK_N = block_size_n(); TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); @@ -825,39 +1069,90 @@ at::Tensor fp8_scaled_mm_cpu( TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales to be float32."); auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); - // strides - int64_t mat1_strideM = mat1.stride(0); - int64_t out_strideM 
= out.stride(0); - - const bool has_bias = bias.has_value(); - const float* bias_data = nullptr; - if (has_bias) { - CHECK_EQ(bias.value().size(0), N); - bias_data = bias.value().data_ptr(); - } - - // Btmp : [T, BLOCK_N * K] - // Ctmp : [T, BLOCK_M * BLOCK_N] - int num_threads = at::get_num_threads(); - int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2; - auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); + auto buffer = alloc_thread_buffer(mat1.options(), K); AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { - fp8_scaled_mm_kernel_impl( + // used for lambda computing scale offset for each block + // fp8 block gemm scale shape: [N/128, K/128] + // for each block: [1, K/128] + const int64_t scale_size_K = div_up(K, block_size_K); + const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + fp_scaled_mm_kernel_impl( out.data_ptr(), mat1.data_ptr(), packed_w.data_ptr(), scales2.data_ptr(), - bias_data, + get_bias_data(bias, N), buffer.data_ptr(), M, N, K, - mat1_strideM, - out_strideM, + mat1.stride(0), + out.stride(0), block_size_N, block_size_K, - size_per_thread); + buffer.size(-1), + [&](int64_t nb) { return (nb / blocks_n_per_group) * scale_size_K; }); }); + + return out; +} + +// mat1 : [M, K] bfloat16 +// mat2 : [N, K / 2] uint8, actual layout: [N / BLOCK_N, K / 2, BLOCK_N, 2] +// scales2: [N, K / G], actual layout: [N / BLOCK_N, K / G, BLOCK_N] +at::Tensor mxfp4_scaled_mm_cpu( + at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, const std::optional& bias, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::mxfp4_scaled_mm_cpu", std::vector({mat1, mat2, scales2, bias})); + + auto packed_w = is_vnni ?
mat2 : convert_weight_packed(mat2); + + CHECK_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1) * 2; + + // mxfp4 supports only group size of 32 (2^5) + constexpr int64_t group_size = 32; + constexpr int64_t BLOCK_N = block_size_n(); + + CHECK_EQ(mat1.size(1), K); + CHECK_EQ(scales2.numel(), N * K >> 5); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "mxfp4_scaled_mm_cpu: expect A to be bfloat16 or half."); + TORCH_CHECK(mat2.scalar_type() == at::kByte, "mxfp4_scaled_mm_cpu: expect mat2 to be uint8."); + TORCH_CHECK(scales2.scalar_type() == at::kByte, "mxfp4_scaled_mm_cpu: expect scales to be uint8."); + auto out = at::empty({M, N}, mat1.options()); + + auto buffer = alloc_thread_buffer(mat1.options(), K); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "mxfp4_scaled_mm_kernel_impl", [&] { + // used for lambda computing scale offset for each block + // mxfp4 block gemm scale shape: [N/BLOCK_N, K/32, BLOCK_N] + // for each block: [K/32, BLOCK_N] + const int64_t s_strideN = (K >> 5) * BLOCK_N; + + fp_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales2.data_ptr(), + get_bias_data(bias, N), + buffer.data_ptr(), + M, + N, + K, + mat1.stride(0), + out.stride(0), + /* block_size_N */ 1, + /* block_size_K */ group_size, + buffer.size(-1), + [&](int64_t nb) { return nb * s_strideN; }); }); return out; diff --git a/sgl-kernel/csrc/cpu/rope.cpp b/sgl-kernel/csrc/cpu/rope.cpp index 7efc816b46d2..6646d17b69e3 100644 --- a/sgl-kernel/csrc/cpu/rope.cpp +++ b/sgl-kernel/csrc/cpu/rope.cpp @@ -169,6 +169,214 @@ void rotary_embedding_neox_4D_kernel_impl( } } +template +void apply_rotary_pos_emb_kernel_impl( + scalar_t* __restrict__ query, + scalar_t* __restrict__ key, + float* __restrict__ cos, + float* __restrict__ sin, + int64_t query_stride_s, + int64_t key_stride_s, + int64_t num_heads, + int64_t
num_kv_heads, + int64_t head_size, + int64_t num_tokens) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int64_t bVecSize = bVec::size(); + constexpr int64_t fVecSize = fVec::size(); + + int64_t embed_dim = head_size / 2; + bool flag = (embed_dim % bVecSize == 0); + int64_t loop_upper = flag ? embed_dim : embed_dim - bVecSize; + + auto compute_loop = [&](int64_t token_head, float* cos_ptr, float* sin_ptr, scalar_t* qk) { + int64_t j = 0; + for (; j < loop_upper; j += bVecSize) { + int64_t rot_offset = j; + int64_t x_index = rot_offset; + int64_t y_index = embed_dim + rot_offset; + + int64_t out_x = token_head + x_index; + int64_t out_y = token_head + y_index; + + fVec _cos_x_0 = fVec::loadu(cos_ptr + x_index); + fVec _sin_x_0 = fVec::loadu(sin_ptr + x_index); + fVec _cos_x_1 = fVec::loadu(cos_ptr + x_index + fVecSize); + fVec _sin_x_1 = fVec::loadu(sin_ptr + x_index + fVecSize); + + fVec _cos_y_0 = fVec::loadu(cos_ptr + y_index); + fVec _sin_y_0 = fVec::loadu(sin_ptr + y_index); + fVec _cos_y_1 = fVec::loadu(cos_ptr + y_index + fVecSize); + fVec _sin_y_1 = fVec::loadu(sin_ptr + y_index + fVecSize); + + bVec _q_x = bVec::loadu(qk + out_x); + bVec _q_y = bVec::loadu(qk + out_y); + fVec _q_x_0, _q_x_1; + std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x); + fVec _q_y_0, _q_y_1; + std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y); + + auto out1_0 = _q_x_0 * _cos_x_0 - _q_y_0 * _sin_x_0; + auto out1_1 = _q_x_1 * _cos_x_1 - _q_y_1 * _sin_x_1; + auto out1 = convert_from_float_ext(out1_0, out1_1); + out1.store(qk + out_x); + + auto out2_0 = _q_y_0 * _cos_y_0 + _q_x_0 * _sin_y_0; + auto out2_1 = _q_y_1 * _cos_y_1 + _q_x_1 * _sin_y_1; + auto out2 = convert_from_float_ext(out2_0, out2_1); + out2.store(qk + out_y); + } + if (!flag) { + for (; j < embed_dim; ++j) { + int64_t x_index = j; + int64_t y_index = embed_dim + j; + + int64_t out_x = token_head + x_index; + int64_t out_y = token_head + y_index; + + float _cos_x = 
cos_ptr[x_index]; + float _sin_x = sin_ptr[x_index]; + float _cos_y = cos_ptr[y_index]; + float _sin_y = sin_ptr[y_index]; + + float _q_x = qk[out_x]; + float _q_y = qk[out_y]; + + qk[out_x] = _q_x * _cos_x - _q_y * _sin_x; + qk[out_y] = _q_y * _cos_y + _q_x * _sin_y; + } + } + }; + + at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) { + int64_t token_idx = {0}; + data_index_init(begin, token_idx, num_tokens); + for (int i = begin; i < end; ++i) { + float* cos_ptr = cos + token_idx * head_size; + float* sin_ptr = sin + token_idx * head_size; + + for (int64_t i = 0; i < num_heads; ++i) { + int64_t head_idx = i; + int64_t token_head = token_idx * query_stride_s + head_idx * head_size; + compute_loop(token_head, cos_ptr, sin_ptr, query); + } + + for (int64_t i = 0; i < num_kv_heads; ++i) { + int64_t head_idx = i; + int64_t token_head = token_idx * key_stride_s + head_idx * head_size; + compute_loop(token_head, cos_ptr, sin_ptr, key); + } + data_index_step(token_idx, num_tokens); + } + }); +} + +template +void apply_rotary_pos_emb_kernel_impl( + scalar_t* __restrict__ query, + scalar_t* __restrict__ key, + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + int64_t query_stride_s, + int64_t key_stride_s, + int64_t num_heads, + int64_t num_kv_heads, + int64_t head_size, + int64_t num_tokens) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int64_t bVecSize = bVec::size(); + + int64_t embed_dim = head_size / 2; + bool flag = (embed_dim % bVecSize == 0); + int64_t loop_upper = flag ? 
embed_dim : embed_dim - bVecSize; + + auto compute_loop = [&](int64_t token_head, scalar_t* cos_ptr, scalar_t* sin_ptr, scalar_t* qk) { + int64_t j = 0; + for (; j < loop_upper; j += bVecSize) { + int64_t rot_offset = j; + int64_t x_index = rot_offset; + int64_t y_index = embed_dim + rot_offset; + + int64_t out_x = token_head + x_index; + int64_t out_y = token_head + y_index; + + bVec _cos_x = bVec::loadu(cos_ptr + x_index); + bVec _sin_x = bVec::loadu(sin_ptr + x_index); + bVec _cos_y = bVec::loadu(cos_ptr + y_index); + bVec _sin_y = bVec::loadu(sin_ptr + y_index); + fVec _cos_x_0, _cos_x_1; + std::tie(_cos_x_0, _cos_x_1) = at::vec::convert_to_float(_cos_x); + fVec _sin_x_0, _sin_x_1; + std::tie(_sin_x_0, _sin_x_1) = at::vec::convert_to_float(_sin_x); + fVec _cos_y_0, _cos_y_1; + std::tie(_cos_y_0, _cos_y_1) = at::vec::convert_to_float(_cos_y); + fVec _sin_y_0, _sin_y_1; + std::tie(_sin_y_0, _sin_y_1) = at::vec::convert_to_float(_sin_y); + + bVec _q_x = bVec::loadu(qk + out_x); + bVec _q_y = bVec::loadu(qk + out_y); + fVec _q_x_0, _q_x_1; + std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x); + fVec _q_y_0, _q_y_1; + std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y); + + auto out1_0 = _q_x_0 * _cos_x_0 - _q_y_0 * _sin_x_0; + auto out1_1 = _q_x_1 * _cos_x_1 - _q_y_1 * _sin_x_1; + auto out1 = convert_from_float_ext(out1_0, out1_1); + out1.store(qk + out_x); + + auto out2_0 = _q_y_0 * _cos_y_0 + _q_x_0 * _sin_y_0; + auto out2_1 = _q_y_1 * _cos_y_1 + _q_x_1 * _sin_y_1; + auto out2 = convert_from_float_ext(out2_0, out2_1); + out2.store(qk + out_y); + } + if (!flag) { + for (; j < embed_dim; ++j) { + int64_t x_index = j; + int64_t y_index = embed_dim + j; + + int64_t out_x = token_head + x_index; + int64_t out_y = token_head + y_index; + + float _cos_x = cos_ptr[x_index]; + float _sin_x = sin_ptr[x_index]; + float _cos_y = cos_ptr[y_index]; + float _sin_y = sin_ptr[y_index]; + + float _q_x = qk[out_x]; + float _q_y = qk[out_y]; + + qk[out_x] = _q_x * 
_cos_x - _q_y * _sin_x; + qk[out_y] = _q_y * _cos_y + _q_x * _sin_y; + } + } + }; + + at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) { + int64_t token_idx = {0}; + data_index_init(begin, token_idx, num_tokens); + for (int i = begin; i < end; ++i) { + scalar_t* cos_ptr = cos + token_idx * head_size; + scalar_t* sin_ptr = sin + token_idx * head_size; + + for (int64_t i = 0; i < num_heads; ++i) { + int64_t head_idx = i; + int64_t token_head = token_idx * query_stride_s + head_idx * head_size; + compute_loop(token_head, cos_ptr, sin_ptr, query); + } + + for (int64_t i = 0; i < num_kv_heads; ++i) { + int64_t head_idx = i; + int64_t token_head = token_idx * key_stride_s + head_idx * head_size; + compute_loop(token_head, cos_ptr, sin_ptr, key); + } + data_index_step(token_idx, num_tokens); + } + }); +} + template inline scalar_t* get_cache_ptr( int64_t j, @@ -561,6 +769,68 @@ std::tuple rotary_embedding_cpu( return std::make_tuple(query_out, key_out); } +// query: [num_tokens, num_heads, head_size] +// key: [num_tokens, num_heads, head_size] +// cos: [num_tokens, head_size] +// sin: [num_tokens, head_size] +std::tuple +apply_rotary_pos_emb_cpu(at::Tensor& query, at::Tensor& key, at::Tensor& cos, at::Tensor& sin) { + RECORD_FUNCTION("sgl-kernel::apply_rotary_pos_emb_cpu", std::vector({query, key})); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(query); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(key); + CHECK_INPUT(cos); + CHECK_INPUT(sin); + CHECK_DIM(3, query); + CHECK_DIM(3, key); + CHECK_DIM(2, cos); + CHECK_DIM(2, sin); + const auto input_dtype = query.scalar_type(); + int64_t num_tokens = query.size(0); + CHECK_EQ(num_tokens, key.size(0)); + CHECK_EQ(num_tokens, cos.size(0)); + CHECK_EQ(num_tokens, sin.size(0)); + int64_t num_heads = query.size(1); + CHECK_EQ(num_heads, key.size(1)); + int64_t head_size = query.size(2); + CHECK_EQ(head_size, key.size(2)); + CHECK_EQ(head_size, cos.size(1)); + CHECK_EQ(head_size, sin.size(1)); + int64_t q_stride_s = query.stride(0); + 
int64_t k_stride_s = key.stride(0); + TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type"); + AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "apply_rotary_pos_emb_cpu", [&] { + if (cos.scalar_type() == at::kFloat && sin.scalar_type() == at::kFloat) { + apply_rotary_pos_emb_kernel_impl( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + q_stride_s, + k_stride_s, + num_heads, + num_heads, + head_size, + num_tokens); + } else if (cos.scalar_type() == input_dtype && sin.scalar_type() == input_dtype) { + apply_rotary_pos_emb_kernel_impl( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + q_stride_s, + k_stride_s, + num_heads, + num_heads, + head_size, + num_tokens); + } else { + TORCH_CHECK( + false, "cos and sin must have the same data type, and must be either float or the same type as query/key"); + } + }); + return std::make_tuple(query, key); +} + // positions: [num_tokens] (text only) or [3, num_tokens] (T/H/W positions with multimodal inputs) // query: [num_tokens, num_heads * head_size] // key: [num_tokens, num_kv_heads * head_size] diff --git a/sgl-kernel/csrc/cpu/topk.cpp b/sgl-kernel/csrc/cpu/topk.cpp index 0471661e58a7..100e87a7c9ce 100644 --- a/sgl-kernel/csrc/cpu/topk.cpp +++ b/sgl-kernel/csrc/cpu/topk.cpp @@ -227,11 +227,9 @@ void topk_softmax_kernel_impl( queue[e] = {scores[e], e}; } - std::partial_sort( - queue.begin(), - queue.begin() + num_experts_per_group, - queue.end(), - [](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; }); + std::partial_sort(queue.begin(), queue.begin() + topk, queue.end(), [](const elem_t& x, const elem_t& y) -> bool { + return x.first > y.first; + }); for (int64_t j = 0; j < topk; ++j) { topk_weights[i * topk + j] = queue[j].first; diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 88f4228a50fc..31a6d95fd7f0 100644 --- 
a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -143,6 +143,9 @@ std::tuple chunk_gated_delta_rule_cpu( // weight prepack at::Tensor convert_weight_packed(at::Tensor& weight); +// scale prepack for mxfp4 +at::Tensor convert_scale_packed(at::Tensor& scale); + // quant std::tuple per_token_quant_int8_cpu(at::Tensor& A); @@ -178,6 +181,10 @@ at::Tensor fp8_scaled_mm_cpu( at::ScalarType out_dtype, bool is_vnni); +// mxfp4 gemm +at::Tensor mxfp4_scaled_mm_cpu( + at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, const std::optional& bias, bool is_vnni); + // quant + igemm at::Tensor int8_scaled_mm_with_quant( at::Tensor& mat1, @@ -318,6 +325,8 @@ std::tuple rotary_embedding_cpu( int64_t head_size, at::Tensor& cos_sin_cache, bool is_neox); +std::tuple +apply_rotary_pos_emb_cpu(at::Tensor& query, at::Tensor& key, at::Tensor& cos, at::Tensor& sin); // mrope std::tuple multimodal_rotary_embedding_cpu( @@ -463,6 +472,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("convert_weight_packed(Tensor weight) -> Tensor"); m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed); + // scale prepack for mxfp4 + m.def("convert_scale_packed(Tensor scale) -> Tensor"); + m.impl("convert_scale_packed", torch::kCPU, &convert_scale_packed); + // quant m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)"); m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu); @@ -488,6 +501,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "out_dtype, bool is_vnni) -> Tensor"); m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu); + // mxfp4 gemm + m.def("mxfp4_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, bool is_vnni) -> Tensor"); + m.impl("mxfp4_scaled_mm_cpu", torch::kCPU, &mxfp4_scaled_mm_cpu); + // quant + igemm m.def( "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? 
bias, ScalarType out_dtype, bool " @@ -572,6 +589,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, " "bool is_neox) -> (Tensor, Tensor)"); m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu); + m.def("apply_rotary_pos_emb_cpu(Tensor query, Tensor key, Tensor cos, Tensor sin) -> (Tensor, Tensor)"); + m.impl("apply_rotary_pos_emb_cpu", torch::kCPU, &apply_rotary_pos_emb_cpu); + // multimodal rope m.def( "multimodal_rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor " diff --git a/sgl-kernel/csrc/cpu/vec.h b/sgl-kernel/csrc/cpu/vec.h index 107022ffd237..a37bc6ba2467 100644 --- a/sgl-kernel/csrc/cpu/vec.h +++ b/sgl-kernel/csrc/cpu/vec.h @@ -145,6 +145,49 @@ inline __attribute__((always_inline)) __m512bh CVT_FP8_TO_BF16_EXT(__m256i a) { // bias for conversion of fp8 to bf16 1/256 in float32 #define kFP8_BIAS 0x3b800000 +// remove warning: ignoring attributes on template argument ā€˜__m512bh’ [-Wignored-attributes] +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-attributes" + +#define MXFP4_VALUES \ + -6.0f, -4.0f, -3.0f, -2.0f, -1.5f, -1.0f, -0.5f, -0.0f, 6.0f, 4.0f, 3.0f, 2.0f, 1.5f, 1.0f, 0.5f, 0.0f + +// convert 64 mxfp4 to 2x bf16 vectors, expect input 32-way packing +inline std::tuple<__m512bh, __m512bh> cvt_mxfp4_e2m1_bf16_intrinsic_lut(__m256i a, __m512i s0, __m512i s1) { + // LUT + const __m512 values = _mm512_set_ps(MXFP4_VALUES); + const __m512i lut = (__m512i)(_mm512_cvtne2ps_pbh(values, values)); + + const __m512i abs_mask = _mm512_set1_epi16(0x7FFF); + const __m512i zero = _mm512_setzero_si512(); + + // expand values to 16-bit integers + __m512i x0 = _mm512_cvtepu8_epi16(a); + __m512i x1 = _mm512_srli_epi32(x0, 4); + + // LUT to convert mxfp4 values to bf16 + x0 = _mm512_permutexvar_epi16(x0, lut); + x1 = _mm512_permutexvar_epi16(x1, lut); + + // check for zeros + __mmask32 mask0 = 
_mm512_cmp_epi16_mask(_mm512_and_si512(x0, abs_mask), zero, _MM_CMPINT_EQ); + __mmask32 mask1 = _mm512_cmp_epi16_mask(_mm512_and_si512(x1, abs_mask), zero, _MM_CMPINT_EQ); + + // emulate bf16 mul with scale factor + x0 = _mm512_add_epi16(x0, s0); + x1 = _mm512_add_epi16(x1, s1); + + // blend with zero + x0 = _mm512_mask_blend_epi16(mask0, x0, zero); + x1 = _mm512_mask_blend_epi16(mask1, x1, zero); + + return std::make_tuple(__m512bh(x0), __m512bh(x1)); +} + +#define CVT_MXFP4_TO_BF16(a, s0, s1) cvt_mxfp4_e2m1_bf16_intrinsic_lut(a, s0, s1) + +#pragma GCC diagnostic pop + #endif // vector to scalar reduction diff --git a/sgl-kernel/csrc/elementwise/cast.cu b/sgl-kernel/csrc/elementwise/cast.cu deleted file mode 100644 index a6a3b31a1d1a..000000000000 --- a/sgl-kernel/csrc/elementwise/cast.cu +++ /dev/null @@ -1,172 +0,0 @@ -#include - -#include "utils.h" - -template -struct ConvertToFP8 { - static __device__ __nv_fp8_storage_t convert_to_fp8(T value) { - return 0; - } -}; - -template <> -struct ConvertToFP8<__nv_bfloat16> { - static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) { - return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3); - } -}; - -template <> -struct ConvertToFP8 { - static __device__ __nv_fp8_storage_t convert_to_fp8(half value) { - return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3); - } -}; - -template -struct ConvertFromFloat { - static __device__ T convert_from_float(float value) { - return 0; - } -}; - -template <> -struct ConvertFromFloat<__nv_bfloat16> { - static __device__ __nv_bfloat16 convert_from_float(float value) { - return __float2bfloat16(value); - } -}; - -template <> -struct ConvertFromFloat { - static __device__ half convert_from_float(float value) { - return __float2half(value); - } -}; - -template -__global__ void fused_downcast_kernel( - const T* cache_k, - const T* cache_v, - const float* k_scale, - const float* v_scale, - __nv_fp8_storage_t* output_k, - __nv_fp8_storage_t* 
output_v, - const int input_sl, - const int head, - const int dim, - const T max_fp8, - const T min_fp8, - const int64_t mult, - const int64_t offset, - const int64_t* loc) { - // TODO: change name - int token_idx = blockIdx.x; - int thread_idx = threadIdx.x; - int total_threads = blockDim.x; - - T k_scale_val = ConvertFromFloat::convert_from_float(k_scale[0]); - T v_scale_val = ConvertFromFloat::convert_from_float(v_scale[0]); - - T k_scale_inv = static_cast(1.f) / k_scale_val; - T v_scale_inv = static_cast(1.f) / v_scale_val; - - auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); }; - - if (token_idx < input_sl) { - int out_seq_idx = loc[token_idx]; - -#pragma unroll - for (int i = thread_idx; i < head * dim; i += total_threads) { - int in_idx = token_idx * head * dim + i; - int out_idx = (out_seq_idx * mult + offset) * head * dim + i; - - T k_val = cache_k[in_idx] * k_scale_inv; - k_val = clamp(k_val); - output_k[out_idx] = ConvertToFP8::convert_to_fp8(k_val); - - T v_val = cache_v[in_idx] * v_scale_inv; - v_val = clamp(v_val); - output_v[out_idx] = ConvertToFP8::convert_to_fp8(v_val); - } - } -} - -template -void downcast_fp8_impl( - at::Tensor& k, - at::Tensor& v, - at::Tensor& k_out, - at::Tensor& v_out, - at::Tensor& k_scale, - at::Tensor& v_scale, - at::Tensor& loc, - int64_t mult, - int64_t offset, - cudaStream_t stream) { - CHECK_INPUT(k); - CHECK_INPUT(v); - CHECK_INPUT(k_out); - CHECK_INPUT(v_out); - CHECK_INPUT(k_scale); - CHECK_INPUT(v_scale); - CHECK_INPUT(loc); - - int64_t input_sl = k.size(0); - int64_t head = k.size(1); - int64_t dim = k.size(2); - - dim3 grid(input_sl * head); - int vec_size = 8; - dim3 block(std::min(int(dim) / vec_size, 1024)); - - const T max_fp8 = static_cast(FP8_E4M3_MAX); - const T min_fp8 = static_cast(-FP8_E4M3_MAX); - - fused_downcast_kernel<<>>( - static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - static_cast(k_scale.data_ptr()), - static_cast(v_scale.data_ptr()), - 
static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()), - static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()), - input_sl, - head, - dim, - max_fp8, - min_fp8, - mult, - offset, - static_cast(loc.data_ptr())); - - cudaError_t status = cudaGetLastError(); - TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status))); -} - -void downcast_fp8( - at::Tensor& k, - at::Tensor& v, - at::Tensor& k_out, - at::Tensor& v_out, - at::Tensor& k_scale, - at::Tensor& v_scale, - at::Tensor& loc, - int64_t mult, - int64_t offset) { - CHECK_INPUT(k); - CHECK_INPUT(v); - CHECK_INPUT(k_out); - CHECK_INPUT(v_out); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (k.scalar_type()) { - case at::ScalarType::BFloat16: - downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream); - break; - case at::ScalarType::Half: - downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream); - break; - default: - TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16."); - } -} diff --git a/sgl-kernel/csrc/gemm/marlin/marlin_template.h b/sgl-kernel/csrc/gemm/marlin/marlin_template.h index 01eb338782c4..19f5d5477c4e 100644 --- a/sgl-kernel/csrc/gemm/marlin/marlin_template.h +++ b/sgl-kernel/csrc/gemm/marlin/marlin_template.h @@ -487,11 +487,11 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == sglang::kFE2M1f ? 2 : 1) - : 1; + // FP4 (kFE2M1f) uses FP8 scales (1 byte/element), others use FP16 (2 bytes) + int s_gl_stride = prob_n / (w_type == sglang::kFE2M1f ? 
16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == sglang::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -543,8 +543,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == sglang::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -566,15 +565,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == sglang::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -879,7 +870,7 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == sglang::kFE2M1f ? 
2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 8bb8f4684999..77068fb8d621 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -103,8 +103,6 @@ void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, /* * From csrc/attention */ -void merge_state( - at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged); void merge_state_v2( at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged); void cutlass_mla_decode( @@ -143,17 +141,6 @@ void rotary_embedding( torch::Tensor& cos_sin_cache, bool is_neox); -void downcast_fp8( - at::Tensor& k, - at::Tensor& v, - at::Tensor& k_out, - at::Tensor& v_out, - at::Tensor& k_scale, - at::Tensor& v_scale, - at::Tensor& loc, - int64_t mult, - int64_t offset); - void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output); void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope); void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out); @@ -604,9 +591,6 @@ void top_k_renorm_probs( void top_p_renorm_probs( at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, double top_p_val); -void top_k_mask_logits( - at::Tensor logits, at::Tensor mask_logits, std::optional maybe_top_k_arr, int64_t top_k_val); - namespace flash { /* * From fa2 sparse diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 9affd67d0603..0ae708604025 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sglang-kernel" -version = "0.4.0" +version = "0.4.1" authors = [ { name="SGLang Kernel Team", email="sglang@lmsys.org" }, ] diff --git a/sgl-kernel/pyproject_cpu.toml b/sgl-kernel/pyproject_cpu.toml 
index 05b618d1a67f..a5588a467b57 100644 --- a/sgl-kernel/pyproject_cpu.toml +++ b/sgl-kernel/pyproject_cpu.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sglang-kernel-cpu" -version = "0.4.0" +version = "0.4.1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.10" diff --git a/sgl-kernel/pyproject_musa.toml b/sgl-kernel/pyproject_musa.toml index 669927be3dc7..5692b48db8c7 100644 --- a/sgl-kernel/pyproject_musa.toml +++ b/sgl-kernel/pyproject_musa.toml @@ -3,14 +3,14 @@ requires = [ "setuptools>=75.0", "scikit-build-core>=0.10", "torch", - "torchada>=0.1.14", + "torchada>=0.1.45", "wheel", ] build-backend = "setuptools.build_meta" [project] name = "sglang-kernel" -version = "0.4.0" +version = "0.4.1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.10" diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml index cc117979874d..f615f7be78d9 100644 --- a/sgl-kernel/pyproject_rocm.toml +++ b/sgl-kernel/pyproject_rocm.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang-kernel" -version = "0.4.0" +version = "0.4.1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.10" diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index b5dbca95e021..ed08e3f0faee 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -14,7 +14,6 @@ from sgl_kernel.attention import ( cutlass_mla_decode, cutlass_mla_get_workspace_size, - merge_state, merge_state_v2, ) from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data @@ -22,7 +21,6 @@ concat_mla_absorb_q, concat_mla_k, copy_to_gpu_no_ce, - downcast_fp8, fused_add_rmsnorm, gelu_and_mul, gelu_tanh_and_mul, @@ -92,7 +90,6 @@ ggml_mul_mat_vec_a8, ) from sgl_kernel.sampling import ( - top_k_mask_logits, top_k_renorm_prob, 
top_p_renorm_prob, ) @@ -128,7 +125,6 @@ "copy_to_gpu_no_ce", "cutlass_mla_decode", "cutlass_mla_get_workspace_size", - "downcast_fp8", "dsv3_fused_a_gemm", "dsv3_router_gemm", "es_fp8_blockwise_scaled_grouped_mm", @@ -151,7 +147,6 @@ "gptq_shuffle", "int8_scaled_mm", "kimi_k2_moe_fused_gate", - "merge_state", "merge_state_v2", "moe_align_block_size", "moe_fused_gate", @@ -170,7 +165,6 @@ "sgl_per_token_quant_fp8", "shuffle_rows", "silu_and_mul", - "top_k_mask_logits", "top_k_renorm_prob", "top_p_renorm_prob", "topk_sigmoid", diff --git a/sgl-kernel/python/sgl_kernel/_fa4_interface.py b/sgl-kernel/python/sgl_kernel/_fa4_interface.py deleted file mode 100644 index 1b6ab5305960..000000000000 --- a/sgl-kernel/python/sgl_kernel/_fa4_interface.py +++ /dev/null @@ -1,940 +0,0 @@ -# Adapted from https://github.com/Dao-AILab/flash-attention/blob/5d4c9537a1e0f1adcc3e4c3e11ae46fe94a18b11/flash_attn/cute/interface.py - -# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -# [2025-10-14] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.2.1. 
- - -import copy -import gc -import logging -import math -import os -from functools import lru_cache -from typing import Callable, Optional, Tuple - -logger = logging.getLogger(__name__) - - -import cuda.bindings.driver as cuda -import cutlass -import cutlass.cute as cute -import torch -from cutlass.cute.runtime import from_dlpack -from flash_attn_origin.cute import utils -from flash_attn_origin.cute.block_sparsity import ( - BlockSparseTensorsTorch, - get_block_sparse_expected_shapes, - normalize_block_sparse_tensors, - to_cute_block_sparse_tensors, -) -from flash_attn_origin.cute.flash_fwd import FlashAttentionForwardSm90 -from flash_attn_origin.cute.flash_fwd_combine import FlashAttentionForwardCombine -from flash_attn_origin.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 - - -@lru_cache(maxsize=None) -def _get_device_capability(): - """Cached device capability check.""" - return torch.cuda.get_device_capability()[0] - - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): - assert ( - t.shape == expected_shape - ), f"{name} shape {t.shape} != expected {expected_shape}" - assert ( - t.dtype == expected_dtype - ), f"{name} dtype {t.dtype} != expected {expected_dtype}" - assert ( - t.device == expected_device - ), f"{name} device {t.device} != expected {expected_device}" - assert t.is_cuda, f"{name} must be on CUDA" - - -def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False): - """Convert torch tensor to cute tensor for TVM FFI. 
leading_dim=-1 defaults to t.ndim-1.""" - tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True) - if fully_dynamic: - return tensor.mark_layout_dynamic() - if leading_dim == -1: - leading_dim = t.ndim - 1 - return tensor.mark_layout_dynamic(leading_dim=leading_dim) - - -torch2cute_dtype_map = { - torch.float16: cutlass.Float16, - torch.bfloat16: cutlass.BFloat16, - torch.float32: cutlass.Float32, -} - - -def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): - # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. - if num_n_blocks <= 4: - return 1 - - # NOTE: We should revisit this heuristic after persistence is supported for split KV. - # Sometimes, it's ideal to over-schedule splits for better efficiency. - return min(num_SMs // total_mblocks, max_splits, num_n_blocks) - - -def _flash_attn_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, - page_table: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - softcap: Optional[float] = None, - window_size_left: Optional[int] = None, - window_size_right: Optional[int] = None, - learnable_sink: Optional[torch.Tensor] = None, - # m_block_size: int = 128, - # n_block_size: int = 64, - # num_threads: int = 128, - m_block_size: int = 128, - n_block_size: int = 128, - num_threads: int = 384, - num_splits: int = 1, - pack_gqa: Optional[bool] = None, - _compute_capability: Optional[int] = None, - score_mod: Optional[Callable] = None, - mask_mod: Optional[Callable] = None, - block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, - return_lse: bool = False, - out: Optional[torch.Tensor] = None, - lse: 
Optional[torch.Tensor] = None, - aux_tensors: Optional[list[torch.Tensor]] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward pass for FlashAttention. - - Args: - ... - score_mod: A callable that takes the attention scores and applies a modification. - mask_mod: A callable that takes token position information and selectively masks - block_sparse_tensors: A tuple of tensors used for block sparsity. - return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate - out: Optional pre-allocated output tensor. If None, will be allocated internally. - lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. - aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. - """ - q, k, v = [maybe_contiguous(t) for t in (q, k, v)] - num_head, head_dim = q.shape[-2:] - if cu_seqlens_q is None: - batch_size, seqlen_q = q.shape[:2] - total_q = batch_size * seqlen_q - else: - batch_size = cu_seqlens_q.shape[0] - 1 - seqlen_q = None - total_q = q.shape[0] - if page_table is not None: - assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" - assert page_table.dtype == torch.int32, "page_table must be int32" - assert ( - page_table.stride(-1) == 1 - ), "page_table must be contiguous in the last dimension" - max_num_pages_per_seq = page_table.shape[1] - assert page_table.shape == (batch_size, max_num_pages_per_seq) - num_pages, page_size = k.shape[:2] - seqlen_k = num_pages * page_size - else: - num_pages, page_size = None, None - seqlen_k = k.shape[-3] - num_head_kv = k.shape[-2] - head_dim_v = v.shape[-1] - if cu_seqlens_k is None: - if page_table is None: - assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) - assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) - else: - assert k.shape == (num_pages, page_size, num_head_kv, head_dim) - assert v.shape == (num_pages, page_size, 
num_head_kv, head_dim_v) - else: - assert k.shape == (seqlen_k, num_head_kv, head_dim) - assert v.shape == (seqlen_k, num_head_kv, head_dim_v) - assert cu_seqlens_k.shape == ( - batch_size + 1, - ), "cu_seqlens_k must have shape (batch_size + 1,)" - - if cu_seqlens_q is not None: - assert cu_seqlens_q.shape == ( - batch_size + 1, - ), "cu_seqlens_q must have shape (batch_size + 1,)" - assert seqused_q is None or seqused_q.shape == ( - batch_size, - ), "seqused_q must have shape (batch_size,)" - assert seqused_k is None or seqused_k.shape == ( - batch_size, - ), "seqused_k must have shape (batch_size,)" - assert q.dtype in [ - torch.float16, - torch.bfloat16, - ], "inputs must be float16 or bfloat16" - assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" - for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: - if t is not None: - assert ( - t.dtype == torch.int32 - ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" - assert ( - t.stride(0) == 1 - ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" - if learnable_sink is not None: - assert learnable_sink.shape == (num_head,) - assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" - - assert all( - t is None or t.is_cuda - for t in ( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table, - learnable_sink, - ) - ), "inputs must be on CUDA device" - assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" - assert head_dim <= 256, "head_dim must be less than or equal to 256" - alignment = 16 // q.element_size() - assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" - assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" - if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(head_dim) - if softcap == 0.0: - softcap = None - qhead_per_kvhead = num_head // num_head_kv - if pack_gqa is None: - pack_gqa = 
qhead_per_kvhead > 1 - - out_torch_dtype = q.dtype - device = q.device - q_batch_seqlen_shape = ( - (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) - ) - lse_shape = ( - (batch_size, num_head, seqlen_q) - if cu_seqlens_q is None - else (num_head, total_q) - ) - requires_grad = q.requires_grad or k.requires_grad or v.requires_grad - - if out is None: - out = torch.empty( - *q_batch_seqlen_shape, - num_head, - head_dim_v, - dtype=out_torch_dtype, - device=device, - ) - else: - _validate_tensor( - out, - "out", - (*q_batch_seqlen_shape, num_head, head_dim_v), - out_torch_dtype, - device, - ) - - if lse is None: - lse = ( - torch.empty(lse_shape, dtype=torch.float32, device=device) - if requires_grad or return_lse - else None - ) - elif lse is not None: - _validate_tensor(lse, "lse", lse_shape, torch.float32, device) - - dtype = torch2cute_dtype_map[q.dtype] - compute_capability = ( - _get_device_capability() if _compute_capability is None else _compute_capability - ) - - assert compute_capability in [ - 9, - 10, - 11, - ], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x" - - use_block_sparsity = block_sparse_tensors is not None - - if mask_mod is None: - if causal: - window_size_right = 0 - local = window_size_left is not None or window_size_right is not None - if window_size_left is not None or window_size_right is not None: - if window_size_left is None and window_size_right == 0: - causal, local = True, False - window_size_right = None - else: - causal, local = False, True - else: - causal, local = False, False - - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - if compute_capability == 9: # TODO: tune block size according to hdim. 
- if ( - head_dim == head_dim_v == 128 - and not causal - and not local - and not use_block_sparsity - ): - n_block_size = 192 - - if compute_capability in [10, 11]: - if pack_gqa and (128 % qhead_per_kvhead != 0): - pack_gqa = False - # TODO: fix GQA + SplitKV + non-varlen - if pack_gqa and num_splits != 1 and cu_seqlens_q is None: - pack_gqa = False - - if max_seqlen_q is None: - max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q - if max_seqlen_k is None: - max_seqlen_k = seqlen_k - seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead - if compute_capability == 10: - q_stage = 2 if seqlen_q_packgqa > m_block_size else 1 - else: - q_stage = 1 - - if num_splits < 1: - m_block_size_effective = q_stage * m_block_size - seqlen_k_loaded = ( - max_seqlen_k - if not local - else max( - 0, - min( - max_seqlen_k, - window_size_right + window_size_left + 1 + m_block_size, - ), - ) - ) - num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size - num_m_blocks = ( - seqlen_q_packgqa + m_block_size_effective - 1 - ) // m_block_size_effective - total_mblocks = batch_size * num_head_kv * num_m_blocks - num_splits = num_splits_heuristic( - total_mblocks, - torch.cuda.get_device_properties(device).multi_processor_count, - num_n_blocks, - 128, - ) - - is_split_kv = num_splits > 1 - if is_split_kv: - out_partial = torch.empty( - num_splits, - *q_batch_seqlen_shape, - num_head, - head_dim_v, - dtype=torch.float32, - device=device, - ) - lse_partial = torch.empty( - num_splits, *lse_shape, dtype=torch.float32, device=device - ) - - # hash score and mask mods for compile cache - score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False - mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False - - if softcap is not None: - assert score_mod is None, "softcap and score_mod cannot be used together" - score_mod = utils.create_softcap_scoremod(softcap) - - is_varlen = ( - cu_seqlens_q is not None - or cu_seqlens_k is not 
None - or seqused_q is not None - or seqused_k is not None - ) - - if mask_mod is not None: - if is_varlen: - raise NotImplementedError( - "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." - ) - - if use_block_sparsity: - if is_varlen: - raise NotImplementedError( - "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." - ) - # NB: pack_gqa requires block sparse head dim == 1 (broadcasted) - if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1: - pack_gqa = False - if is_split_kv: - raise NotImplementedError( - "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." - ) - - compile_key = ( - dtype, - head_dim, - head_dim_v, - qhead_per_kvhead, - causal, - score_mod_hash, - mask_mod_hash, - use_block_sparsity, - len(aux_tensors) if aux_tensors is not None else 0, - lse is None, - cu_seqlens_q is None, - cu_seqlens_k is None, - seqused_q is None, - seqused_k is None, - page_table is not None, - window_size_left is not None, - window_size_right is not None, - learnable_sink is not None, - m_block_size, - n_block_size, - q_stage, - num_threads, - is_split_kv, - pack_gqa, - compute_capability, - page_size not in [None, 128], # paged KV non-TMA - ) - if compile_key not in _flash_attn_fwd.compile_cache: - ( - cu_seqlens_q_tensor, - cu_seqlens_k_tensor, - seqused_q_tensor, - seqused_k_tensor, - learnable_sink_tensor, - ) = [ - to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None - for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) - ] - page_table_tensor = ( - to_cute_tensor(page_table, assumed_align=4, leading_dim=1) - if page_table is not None - else None - ) - q_tensor, k_tensor, v_tensor, o_tensor = [ - to_cute_tensor(t) - for t in (q, k, v, out if not is_split_kv else out_partial) - ] - if is_split_kv: - lse_tensor = to_cute_tensor(lse_partial, assumed_align=4) - elif lse 
is not None: - lse_tensor = to_cute_tensor(lse, assumed_align=4) - else: - lse_tensor = None - - sparse_tensors = None - if block_sparse_tensors is not None: - if seqlen_q is None: - raise ValueError( - "Block sparsity requires fixed-length sequences (seqlen_q must be known)." - ) - expected_count_shape, expected_index_shape = ( - get_block_sparse_expected_shapes( - batch_size, - num_head, - seqlen_q, - seqlen_k, - m_block_size, - n_block_size, - q_stage, - ) - ) - compile_time_normalized = normalize_block_sparse_tensors( - block_sparse_tensors, - expected_count_shape=expected_count_shape, - expected_index_shape=expected_index_shape, - ) - sparse_tensors = to_cute_block_sparse_tensors(compile_time_normalized) - - cute_aux_tensors = None - if aux_tensors is not None: - cute_aux_tensors = [ - to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) - for buf in aux_tensors - ] - - if compute_capability == 9: - assert page_table is None, "paged KV not supported on SM 9.0" - assert not is_split_kv, "SplitKV not supported on SM 9.0" - # fa_fwd = FlashAttentionForwardSm80( - fa_fwd = FlashAttentionForwardSm90( - dtype, - head_dim, - head_dim_v, - qhead_per_kvhead, - is_causal=causal, - is_local=local, - pack_gqa=pack_gqa, - tile_m=m_block_size, - tile_n=n_block_size, - # num_stages=1, - num_stages=2, - num_threads=num_threads, - Q_in_regs=False, - intra_wg_overlap=True, - mma_pv_is_rs=True, - mask_mod=mask_mod, - score_mod=score_mod, - has_aux_tensors=aux_tensors is not None, - ) - elif compute_capability in [10, 11]: - fa_fwd = FlashAttentionForwardSm100( - head_dim, - head_dim_v, - qhead_per_kvhead=qhead_per_kvhead, - is_causal=causal, - is_local=local, - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - m_block_size=m_block_size, - n_block_size=n_block_size, - q_stage=q_stage, - is_persistent=not causal - and not local - and cu_seqlens_q is None - and seqused_q is None - and not is_split_kv, - score_mod=score_mod, - mask_mod=mask_mod, - has_aux_tensors=aux_tensors is 
not None, - paged_kv_non_tma=page_size not in [None, 128], - is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, - ) - else: - raise ValueError( - f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x, 11.x" - ) - # TODO: check @can_implement - _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - fa_fwd, - q_tensor, - k_tensor, - v_tensor, - o_tensor, - lse_tensor, - softmax_scale, - current_stream, - cu_seqlens_q_tensor, - cu_seqlens_k_tensor, - seqused_q_tensor, - seqused_k_tensor, - page_table_tensor, - window_size_left, - window_size_right, - learnable_sink_tensor, - sparse_tensors, - cute_aux_tensors, - options="--enable-tvm-ffi", - ) - - # Expand block sparse tensors to match actual head count (may be broadcast from 1) - normalized_block_sparse_tensors = None - if block_sparse_tensors is not None: - expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes( - batch_size, - num_head, - seqlen_q, - seqlen_k, - m_block_size, - n_block_size, - q_stage, - ) - normalized_block_sparse_tensors = normalize_block_sparse_tensors( - block_sparse_tensors, - expected_count_shape=expected_count_shape, - expected_index_shape=expected_index_shape, - ) - _flash_attn_fwd.compile_cache[compile_key]( - q, - k, - v, - out if not is_split_kv else out_partial, - lse_partial if is_split_kv else lse, - softmax_scale, - current_stream, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table, - window_size_left, - window_size_right, - learnable_sink, - normalized_block_sparse_tensors, - aux_tensors, - ) - if is_split_kv: - _flash_attn_fwd_combine( - out_partial, - lse_partial.transpose(-1, -2), - out, - lse.transpose(-1, -2) if lse is not None else None, - cu_seqlens_q, - seqused_q, - ) - return out, lse - - -_flash_attn_fwd.compile_cache = {} - - -def _flash_attn_fwd_combine( - out_partial: torch.Tensor, - lse_partial: torch.Tensor, - out: torch.Tensor, - lse: Optional[torch.Tensor] = None, - cu_seqlens: 
Optional[torch.Tensor] = None, - seqused: Optional[torch.Tensor] = None, - num_splits_dynamic_ptr: Optional[torch.Tensor] = None, - semaphore_to_reset: Optional[torch.Tensor] = None, -) -> None: - """Forward combine kernel for split attention computation. - - Combines partial outputs and log-sum-exp values from multiple splits - of attention computation into final outputs. - - Args: - out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or - (num_splits, total_q, nheads, headdim) if there's cu_seqlens - lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or - (num_splits, total_q, nheads) if there's cu_seqlens - out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens - lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens. - cu_seqlens: Cumulative sequence lengths for variable length sequences - seqused: Used sequence lengths for each batch - num_splits_dynamic_ptr: Dynamic number of splits per batch - semaphore_to_reset: Semaphore for synchronization - k_block_size: Block size for head dimension - - Returns: - None - """ - # Input validation - assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" - assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" - assert out_partial.dtype in [ - torch.float16, - torch.bfloat16, - torch.float32, - ], "out_partial must be fp16, bf16, or fp32" - assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" - assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" - assert ( - out_partial.stride(-1) == 1 - ), "out_partial must be contiguous in the last dimension" - assert ( - lse_partial.stride(-2) == 1 - ), "lse_partial must be contiguous in the seqlen dimension" - assert lse_partial.shape == out_partial.shape[:-1] - - # Determine if this is variable length based on dimensions - is_varlen = out_partial.dim() == 4 - - # 
Validate output tensor shapes and types - assert out.shape == out_partial.shape[1:], "out shape mismatch" - if lse is not None: - assert lse.shape == lse_partial.shape[1:], "lse shape mismatch" - assert lse.dtype == torch.float32, "lse must be fp32" - - # Validate optional tensors - for t, name in [ - (cu_seqlens, "cu_seqlens"), - (seqused, "seqused"), - (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), - ]: - if t is not None: - assert t.dtype == torch.int32, f"{name} must be int32" - assert t.is_cuda, f"{name} must be on CUDA device" - assert t.is_contiguous(), f"{name} must be contiguous" - - head_dim = out_partial.shape[-1] - num_splits = out_partial.shape[0] - assert num_splits <= 256 - # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively - # so that kBlockM is smaller and we have more parallelism. - k_block_size = 64 if head_dim <= 64 else 128 - # We want kBlockM to be as small as possible to maximize parallelism. - # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). - m_block_size = ( - 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) - ) - log_max_splits = max(math.ceil(math.log2(num_splits)), 4) - if m_block_size == 8: - # If kBlockM == 8 then the minimum number of splits is 32. 
- # TODO: we can deal w this by using 128 threads instead - log_max_splits = max(log_max_splits, 5) - - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - # Create combine kernel configuration - dtype = torch2cute_dtype_map[out.dtype] - dtype_partial = torch2cute_dtype_map[out_partial.dtype] - - compile_key = ( - dtype, - dtype_partial, - head_dim, - m_block_size, - k_block_size, - log_max_splits, - cu_seqlens is not None, - seqused is not None, - lse is not None, - ) - - if compile_key not in _flash_attn_fwd_combine.compile_cache: - out_partial_tensor = to_cute_tensor( - out_partial, leading_dim=4 if not is_varlen else 3 - ) - lse_partial_tensor = to_cute_tensor( - lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2 - ) - out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2) - lse_tensor = ( - to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2) - if lse is not None - else None - ) - - optional_tensors = [ - to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None - for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) - ] - ( - cu_seqlens_tensor, - seqused_tensor, - num_splits_dynamic_tensor, - semaphore_tensor, - ) = optional_tensors - fa_combine = FlashAttentionForwardCombine( - dtype=dtype, - dtype_partial=dtype_partial, - head_dim=head_dim, - m_block_size=m_block_size, - k_block_size=k_block_size, - log_max_splits=log_max_splits, - ) - - # Check if implementation is supported - if not fa_combine.can_implement( - dtype, - dtype_partial, - head_dim, - m_block_size, - k_block_size, - log_max_splits, - num_threads=256, - ): - raise RuntimeError( - "FlashAttention combine kernel cannot be implemented with given parameters" - ) - - _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( - fa_combine, - out_partial_tensor, - lse_partial_tensor, - out_tensor, - lse_tensor, - cu_seqlens_tensor, - seqused_tensor, - num_splits_dynamic_tensor, - 
semaphore_tensor, - current_stream, - options="--enable-tvm-ffi", - ) - _flash_attn_fwd_combine.compile_cache[compile_key]( - out_partial, - lse_partial, - out, - lse, - cu_seqlens, - seqused, - num_splits_dynamic_ptr, - semaphore_to_reset, - current_stream, - ) - - -_flash_attn_fwd_combine.compile_cache = {} - - -def warmup_flash_attn(f): - """ - Decorator for flash_attn_varlen_func: - - On first call, run several warmup passes with different flag combinations: - * return_softmax_lse in {False, True} - * global noncausal (window_size=(None,None)) - * causal (window_size=(None,0)) - * local sliding window (window_size=(64,64)) - * optionally pack_gqa=True if qheads > kvheads and allowed - - No score_mod / softcap (not supported for varlen yet) - - Executes sequentially to minimize peak GPU mem - - Does not modify user tensors (clones) - """ - disable_warmup = os.getenv("SGLANG_DISABLE_FA4_WARMUP", "").lower() in ( - "1", - "true", - "yes", - "on", - ) - if disable_warmup: - return f - - done = False - - def _clone_args(args, kwargs): - """Clone tensor arguments to avoid sharing storage; deepcopy for others.""" - - def maybe_clone(x): - if isinstance(x, torch.Tensor): - return x.detach().clone() # detach to avoid autograd edges - return copy.deepcopy(x) - - return tuple(maybe_clone(a) for a in args), { - k: maybe_clone(v) for k, v in kwargs.items() - } - - def _infer_heads(args, kwargs): - """Infer q and kv head counts from arguments.""" - # Expect signature: (q, k, v, cu_seqlens_q, cu_seqlens_k, ...) 
- q = args[0] if len(args) > 0 else kwargs.get("q") - k = args[1] if len(args) > 1 else kwargs.get("k") - try: - qh = int(q.shape[-2]) - kvh = int(k.shape[-2]) - return qh, kvh - except Exception: - return None, None - - def _run_warmups(args, kwargs): - """Run warmup calls sequentially and release memory after each.""" - base_args, base_kwargs = _clone_args(args, kwargs) - - qh, kvh = _infer_heads(base_args, base_kwargs) - can_pack_gqa = ( - qh is not None and kvh is not None and qh % kvh == 0 and qh // kvh > 1 - ) - has_page_table = ( - "page_table" in base_kwargs and base_kwargs["page_table"] is not None - ) - - # Window presets covering global, causal, and local - window_presets = [ - (None, None), # global noncausal - (None, 0), # causal - (64, 64), # local sliding window - ] - - lse_flags = [False, True] - - # Base combo list - combos = [] - for ws in window_presets: - for return_lse_flag in lse_flags: - combos.append(dict(window_size=ws, return_softmax_lse=return_lse_flag)) - - # Optionally add a pack_gqa=True variant (FA4 may disable it internally for some varlen shapes/SMs) - if can_pack_gqa: - for ws in window_presets: - combos.append( - dict(window_size=ws, return_softmax_lse=False, pack_gqa=True) - ) - - # If page_table is present, warm one combo with it (page_table in compile key for SM100) - if has_page_table: - combos.append(dict(window_size=(None, None), return_softmax_lse=False)) - - # Run sequentially - for combo in combos: - wa, wk = _clone_args(base_args, base_kwargs) - # Keep user-provided softcap/score_mod OUT (varlen+score_mod unsupported) - wk.pop("score_mod", None) - if "softcap" in wk and wk["softcap"]: - wk["softcap"] = 0.0 - # Apply combo - wk.update(combo) - with torch.cuda.stream(torch.cuda.current_stream()): - try: - f(*wa, **wk) - except Exception as e: - # Some combos can be invalid for specific head dims / arch. Ignore and continue. 
- logger.debug("Warmup combo skipped: %s", e) - del wa, wk - torch.cuda.empty_cache() - gc.collect() - - def wrapper(*args, **kwargs): - nonlocal done - if not done: - logger.info( - "Running FA4 warmup (global/causal/local, LSE on/off, optional GQA pack)..." - ) - _run_warmups(args, kwargs) - done = True - return f(*args, **kwargs) - - return wrapper - - -@warmup_flash_attn -def flash_attn_varlen_func( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - page_table: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - window_size: Tuple[Optional[int], Optional[int]] = (None, None), - learnable_sink: Optional[torch.Tensor] = None, - softcap: float = 0.0, - num_splits: int = 1, - pack_gqa: Optional[bool] = None, - return_softmax_lse: Optional[bool] = False, - score_mod: Optional[Callable] = None, - aux_tensors: Optional[list] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - out, lse = _flash_attn_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table=page_table, - softmax_scale=softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - learnable_sink=learnable_sink, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - return_lse=return_softmax_lse, - score_mod=score_mod, - aux_tensors=aux_tensors, - ) - - return (out, lse) if return_softmax_lse else out diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py index 44dd6111ada2..faf23a4f0396 100644 --- a/sgl-kernel/python/sgl_kernel/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -3,25 +3,6 @@ import torch -def merge_state( - v_a: torch.Tensor, - s_a: torch.Tensor, - v_b: torch.Tensor, - s_b: torch.Tensor, - v_merged: 
Optional[torch.Tensor] = None, - s_merged: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - s_a = s_a.to(torch.float32) - s_b = s_b.to(torch.float32) - # Avoid creating new tensors if they are already provided - if v_merged is None: - v_merged = torch.empty_like(v_a) - if s_merged is None: - s_merged = torch.empty_like(s_a) - torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged) - return v_merged, s_merged - - def merge_state_v2( v_a: torch.Tensor, s_a: torch.Tensor, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 1ed1ae474a79..62a3f646c4db 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -344,22 +344,6 @@ def rotary_embedding( ) -def downcast_fp8( - k: torch.Tensor, - v: torch.Tensor, - k_out: torch.Tensor, - v_out: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - loc: torch.Tensor, - mult: int = 1, - offset: int = 0, -) -> None: - torch.ops.sgl_kernel.downcast_fp8( - k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset - ) - - def copy_to_gpu_no_ce(input: torch.Tensor, output: torch.Tensor): torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output) diff --git a/sgl-kernel/python/sgl_kernel/flash_mla.py b/sgl-kernel/python/sgl_kernel/flash_mla.py index 144ddc31a705..3b4643cded62 100644 --- a/sgl-kernel/python/sgl_kernel/flash_mla.py +++ b/sgl-kernel/python/sgl_kernel/flash_mla.py @@ -35,6 +35,9 @@ def get_mla_metadata( tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ + if _flashmla_import_error is not None: + raise _IMPORT_ERROR from _flashmla_import_error + if is_fp8_kvcache and topk is None: return torch.ops.sgl_kernel.get_mla_decoding_metadata_dense_fp8.default( cache_seqlens, @@ -86,6 +89,9 @@ def flash_mla_with_kvcache( out: (batch_size, seq_len_q, num_heads_q, head_dim_v). 
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ + if _flashmla_import_error is not None: + raise _IMPORT_ERROR from _flashmla_import_error + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if indices is not None: @@ -149,6 +155,9 @@ def flash_mla_sparse_fwd( - max_logits: [s_q, h_q], float - lse: [s_q, h_q], float, 2-based log-sum-exp """ + if _flashmla_import_error is not None: + raise _IMPORT_ERROR from _flashmla_import_error + results = torch.ops.sgl_kernel.sparse_prefill_fwd.default( q, kv, indices, sm_scale, d_v ) diff --git a/sgl-kernel/python/sgl_kernel/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py index ccf98cb6ba10..f72033f52708 100644 --- a/sgl-kernel/python/sgl_kernel/sampling.py +++ b/sgl-kernel/python/sgl_kernel/sampling.py @@ -113,76 +113,3 @@ def top_p_renorm_probs( top_p_renorm_prob = top_p_renorm_probs - - -def _top_k_mask_logits_internal( - logits: torch.Tensor, - maybe_top_k_arr: Optional[torch.Tensor], - top_k_val: int, -) -> torch.Tensor: - logits = logits.float() - maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None - mask_logits = torch.empty_like(logits) - torch.ops.sgl_kernel.top_k_mask_logits.default( - logits, mask_logits, maybe_top_k_arr, top_k_val - ) - return mask_logits - - -def top_k_mask_logits( - logits: torch.Tensor, - top_k: Union[torch.Tensor, int], -) -> torch.Tensor: - r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py - Fused GPU kernel for masking logits by top-k thresholding. - - Parameters - ---------- - logits: torch.Tensor - Logits before softmax, shape ``(batch_size, num_classes)``. - top_k: Union[torch.Tensor, int] - Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for - for masking logits, should be in ``(0, num_classes)``. - If a scalar, the same threshold is used for all requests. - If a tensor, each request has its own threshold. 
- We keep the top-k logits, set the rest to negative infinity. - - Returns - ------- - masked_logits: torch.Tensor - Masked logits, shape ``(batch_size, num_classes)``. - - Examples - -------- - - >>> import torch - >>> import flashinfer - >>> torch.manual_seed(42) - >>> batch_size = 4 - >>> vocab_size = 5 - >>> top_k = 3 - >>> logits = torch.randn(batch_size, vocab_size).to(0) - >>> logits - tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581], - [ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866], - [-0.4934, 0.2415, -0.2316, 0.0418, -0.2516], - [ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]], device='cuda:0') - >>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k) - >>> masked_logits - tensor([[ 1.9269, 1.4873, 0.9007, -inf, -inf], - [ 1.0783, 0.8008, 1.6806, -inf, -inf], - [ -inf, 0.2415, -0.2316, 0.0418, -inf], - [ 0.8599, -0.3097, -inf, 0.8034, -inf]], device='cuda:0') - - Note - ---- - The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``. 
- - See Also - -------- - top_k_renorm_probs - """ - if logits.device.type == "musa" or not _has_flashinfer: - return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k)) - else: - return _flashinfer_sampling.top_k_mask_logits(logits, top_k) diff --git a/sgl-kernel/python/sgl_kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py index 6a9beea82f65..3d26edf777ed 100644 --- a/sgl-kernel/python/sgl_kernel/version.py +++ b/sgl-kernel/python/sgl_kernel/version.py @@ -1 +1 @@ -__version__ = "0.4.0" +__version__ = "0.4.1" diff --git a/sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py b/sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py index 7e3e82c9bfb7..aa71259251a7 100644 --- a/sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py +++ b/sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py @@ -22,6 +22,8 @@ import torch import torch.distributed as dist +from sglang.srt.environ import envs + def get_open_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -30,6 +32,7 @@ def get_open_port(): def worker(world_size, rank, port): + envs.SGLANG_USE_1STAGE_ALLREDUCE.set("1") device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) @@ -60,12 +63,6 @@ def worker(world_size, rank, port): print("āœ— Custom AR not available or disabled") dist.destroy_process_group() return - - if not hasattr(custom_ar, "deterministic_all_reduce"): - if rank == 0: - print("āœ— Deterministic kernel not available") - dist.destroy_process_group() - return except Exception as e: if rank == 0: print(f"āœ— Failed to initialize deterministic kernel: {e}") @@ -115,18 +112,7 @@ def worker(world_size, rank, port): # Clone the same input inp = base_input.clone() - # Use deterministic kernel - # Check if input fits in buffer, use registered mode if too large - input_size_bytes = inp.numel() * inp.element_size() - use_registered = input_size_bytes > custom_ar.max_size - - if use_registered: - # For large inputs, register buffer 
first - custom_ar.register_buffer(inp) - result = custom_ar.deterministic_all_reduce(inp, registered=True) - else: - # For smaller inputs, use unregistered mode (copies to internal buffer) - result = custom_ar.deterministic_all_reduce(inp, registered=False) + result = custom_ar.custom_all_reduce(inp) torch.cuda.synchronize() # Store checksum @@ -179,22 +165,7 @@ def worker(world_size, rank, port): # Flatten for all-reduce: (bs * hidden_dim,) batch_flat = batch.view(-1) - # Use deterministic kernel - # Check if input fits in buffer, use registered mode if too large - input_size_bytes = batch_flat.numel() * batch_flat.element_size() - use_registered = input_size_bytes > custom_ar.max_size - - if use_registered: - # For large inputs, register buffer first - custom_ar.register_buffer(batch_flat) - result_flat = custom_ar.deterministic_all_reduce( - batch_flat, registered=True - ) - else: - # For smaller inputs, use unregistered mode - result_flat = custom_ar.deterministic_all_reduce( - batch_flat, registered=False - ) + result_flat = custom_ar.custom_all_reduce(batch_flat) torch.cuda.synchronize() # Reshape back to (bs, hidden_dim) diff --git a/sgl-kernel/tests/test_hadamard.py b/sgl-kernel/tests/test_hadamard.py deleted file mode 100644 index a0eea45b2eea..000000000000 --- a/sgl-kernel/tests/test_hadamard.py +++ /dev/null @@ -1,86 +0,0 @@ -import math -import sys - -import pytest -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from scipy.linalg import hadamard - -try: - from sgl_kernel import hadamard_transform -except Exception: - pytest.skip( - "sgl-kernel hadamard interface was removed (migrated to jit_kernel)", - allow_module_level=True, - ) - - -def hadamard_transform_ref(x, scale=1.0): - """ - x: (..., dim) - out: (..., dim) - """ - if hadamard is None: - raise ImportError("Please install scipy") - x_shape = x.shape - dim = x.shape[-1] - x = x.reshape(-1, dim) - log_dim = math.ceil(math.log2(dim)) - dim_padded = 2**log_dim - 
if dim != dim_padded: - x = F.pad(x, (0, dim_padded - dim)) - out = F.linear( - x, - torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device), - ) - out = out * scale - return out[..., :dim].reshape(*x_shape) - - -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize( - "dim", - [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 137, 1024, 2048, 4096, 8192, 16384, 32768], -) -def test_fast_hadamard_transform(dim, dtype): - device = "cuda" - - if dtype == torch.float32: - rtol, atol = 3e-4, 3e-3 - elif dtype == torch.bfloat16: - rtol, atol = 1e-2, 5e-2 - else: # float16 - rtol, atol = 3e-3, 5e-3 - - torch.random.manual_seed(0) - batch_size = 15 - - x = torch.randn(batch_size, dim, device=device, dtype=dtype) - x_ref = x.detach().clone().to(torch.float32) - x_pt = x.detach().clone() - - scale = 1 / math.sqrt(dim) - - out = hadamard_transform(x, scale=scale) - out_ref = hadamard_transform_ref(x_ref, scale=scale) - out_pt = hadamard_transform_ref(x_pt, scale=scale) - - torch.testing.assert_close( - out_pt.float(), - out_ref, - rtol=rtol, - atol=atol, - msg="Reference implementations mismatch", - ) - torch.testing.assert_close( - out.float(), - out_ref, - rtol=rtol, - atol=atol, - msg="fast_hadamard_transform output mismatch", - ) - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/sgl-kernel/tests/test_merge_state.py b/sgl-kernel/tests/test_merge_state.py deleted file mode 100644 index 3aedb0f94bc6..000000000000 --- a/sgl-kernel/tests/test_merge_state.py +++ /dev/null @@ -1,143 +0,0 @@ -# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/triton/kernels/cascade.py - -import sys -from typing import List - -import pytest -import torch -import triton -import triton.language as tl -from sgl_kernel import merge_state - - -def check_input(x: torch.Tensor): - assert x.is_cuda, f"{str(x)} must be a CUDA Tensor" - assert 
x.is_contiguous(), f"{str(x)} must be contiguous" - - -def check_dim(d, x: torch.Tensor): - assert x.dim() == d, f"{str(x)} must be a {d}D tensor" - - -def check_shape(a: torch.Tensor, b: torch.Tensor): - assert a.dim() == b.dim(), "tensors should have same dim" - for i in range(a.dim()): - assert a.size(i) == b.size( - i - ), f"tensors shape mismatch, {a.size()} and {b.size()}" - - -def check_device(tensors: List[torch.Tensor]): - device = tensors[0].device - for t in tensors: - assert ( - t.device == device - ), f"All tensors should be on the same device, but got {device} and {t.device}" - - -@triton.jit -def state_merge(o, m, d, other_o, other_m, other_d): - m_max = tl.maximum(m, other_m) - d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max) - o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max) - return o, m_max, d - - -@triton.jit -def state_normalize(o, m, d): - o = o / d - return o, m, d - - -@triton.jit -def state_get_lse(o, m, d): - return m + tl.log2(d) - - -@triton.jit -def merge_state_kernel( - v_a_ptr, - s_a_ptr, - v_b_ptr, - s_b_ptr, - v_merged_ptr, - s_merged_ptr, - num_heads, - head_dim, - bdx: tl.constexpr, - bdy: tl.constexpr, -): - pos = tl.program_id(axis=0) - for tx in tl.range(bdx): - for head_idx in tl.range(bdy): - s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx) - s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx) - - offsets = (pos * num_heads + head_idx) * head_dim + tx - v_a = tl.load(v_a_ptr + offsets) - v_b = tl.load(v_b_ptr + offsets) - - v_merged, s_max, d = state_merge( - o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1 - ) - v_merged, s_max, d = state_normalize(v_merged, s_max, d) - v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx - tl.store(v_merged_ptr + v_merged_offset, v_merged) - - if s_merged_ptr: - tl.store( - s_merged_ptr + pos * num_heads + head_idx, - tl.log2(d) + s_max, - ) - - -def merge_state_triton( - v_a: torch.Tensor, s_a: torch.Tensor, v_b: 
torch.Tensor, s_b: torch.Tensor -): - check_input(v_a) - check_input(s_a) - check_input(v_b) - check_input(s_b) - check_device([v_a, s_a, v_b, s_b]) - check_dim(3, v_a) - check_dim(2, s_a) - check_dim(3, v_b) - check_dim(2, s_b) - check_shape(v_a, v_b) - check_shape(s_a, s_b) - assert v_a.size(0) == s_a.size(0) - assert v_a.size(1) == s_b.size(1) - s_a = s_a.to(torch.float32) - s_b = s_b.to(torch.float32) - seq_len = v_a.size(0) - num_heads = v_a.size(1) - head_dim = v_a.size(2) - v_merged = torch.empty_like(v_a).to(s_a.device) - s_merged = torch.empty((seq_len, num_heads)).to(s_a.device) - bdx = head_dim - bdy = num_heads - - merge_state_kernel[lambda meta: (seq_len,)]( - v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy - ) - - return v_merged, s_merged - - -@pytest.mark.parametrize("seq_len", [2048]) -@pytest.mark.parametrize("num_heads", [32]) -@pytest.mark.parametrize("head_dim", [128]) -def test_merge_state(seq_len, num_heads, head_dim): - va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") - sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") - vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") - sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") - v_merged, s_merged = merge_state_triton(va, sa, vb, sb) - v_merged_std, s_merged_std = merge_state(va, sa, vb, sb) - - assert torch.allclose(v_merged, v_merged_std, atol=1e-2) - assert torch.allclose(s_merged, s_merged_std, atol=1e-2) - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/sgl-kernel/tests/test_merge_state_v2.py b/sgl-kernel/tests/test_merge_state_v2.py index 4bbf1704da60..7b285d10ff79 100644 --- a/sgl-kernel/tests/test_merge_state_v2.py +++ b/sgl-kernel/tests/test_merge_state_v2.py @@ -5,7 +5,7 @@ import torch import triton import triton.language as tl -from sgl_kernel import merge_state, merge_state_v2 +from sgl_kernel import merge_state_v2 @triton.jit @@ -146,11 +146,9 @@ 
def generate_markdown_table(): global all_case_info table_header = ( "| tokens | heads | headsize | dtype " - "| device | torch | triton | v1 | v2 | speedup(vs triton) | speedup(vs v1)|" - ) - table_separator = ( - "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |" + "| device | torch | triton | v2 | speedup(vs triton) |" ) + table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |" def shortly_dtype(dtype: torch.dtype) -> str: return str(dtype).removeprefix("torch.") @@ -169,21 +167,17 @@ def shortly_device(device: str) -> str: device, time_torch, time_triton, - time_v1, time_v2, ) = info dtype = shortly_dtype(dtype) device = shortly_device(device) improved_triton = time_triton / time_v2 - improved_v1 = time_v1 / time_v2 print( f"| {num_tokens} | {num_heads} | {head_size} " f"| {dtype} | {device} | {time_torch:.4f}ms " f"| {time_triton:.4f}ms " - f"| {time_v1:.4f}ms " f"| {time_v2:.4f}ms " - f"| {improved_triton:.4f}x " - f"| {improved_v1:.4f}x |" + f"| {improved_triton:.4f}x |" ) @@ -259,11 +253,6 @@ def perf_kernel_fn( prefix_lse_ = prefix_lse suffix_lse_ = suffix_lse - if fn_type == "cuda_v1": - # merge_state v1 kernel not support float32 - if output_dtype not in (torch.half, torch.bfloat16): - return 0, output_fn, output_lse_fn - total_time = 0 start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -316,29 +305,21 @@ def perf_kernel_fn( fn_type="triton", ) - # 2. Run the merge_state V1 kernel - output_v1 = output.clone() - output_lse_v1 = output_lse.clone() - time_v1, output_v1, output_lse_v1 = perf_kernel_fn( - output_v1, output_lse_v1, merge_state, fn_type="cuda_v1" - ) - - # 3. Run the merge_state V2 kernel + # 2. Run the merge_state V2 kernel output_v2 = output.clone() output_lse_v2 = output_lse.clone() time_v2, output_v2, output_lse_v2 = perf_kernel_fn( output_v2, output_lse_v2, merge_state_v2, fn_type="cuda_v2" ) - # 4. Performance compare + # 3. 
Performance compare improved = time_triton / time_v2 print(f" Torch time: {time_torch:.6f}ms") print(f" Triton time: {time_triton:.6f}ms") - print(f"CUDA v1 time: {time_v1:.6f}ms") print(f"CUDA v2 time: {time_v2:.6f}ms, Performance: {improved:.5f}x") print("-" * 100) - # 5. Correctness compare + # 4. Correctness compare # Liger Kernel: Efficient Triton Kernels for LLM Training # https://arxiv.org/pdf/2410.10989, 3.3 Correctness # use rtol = 1e-2 for bfloat16. @@ -387,7 +368,6 @@ def diff(a: torch.Tensor, b: torch.Tensor): device, time_torch, time_triton, - time_v1, time_v2, ) ) diff --git a/sgl-model-gateway/README.md b/sgl-model-gateway/README.md index 4c4f92da0256..046cf352a14e 100644 --- a/sgl-model-gateway/README.md +++ b/sgl-model-gateway/README.md @@ -407,7 +407,7 @@ Use upstream SGLang binaries to start dedicated worker processes. ### Worker Lifecycle & Job Queue - `JobQueue` handles asynchronous add/remove operations to avoid blocking clients. -- `WorkerManager` inspects worker metadata (`/get_server_info`, `/get_model_info`), tracks load, and exposes `flush_cache` and `get_loads`. +- `WorkerManager` inspects worker metadata (`/server_info`, `/get_model_info`), tracks load, and exposes `flush_cache` and `get_loads`. - Per-worker circuit breakers and health probes keep the registry healthy; load monitor feeds metrics to cache-aware and power-of-two policies. ### Administrative & Worker APIs diff --git a/sgl-model-gateway/bindings/python/src/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/src/sglang_router/mini_lb.py index f5fd0b5a7323..11ef2f4c6b64 100644 --- a/sgl-model-gateway/bindings/python/src/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/src/sglang_router/mini_lb.py @@ -269,6 +269,8 @@ async def flush_cache(): return Response(status_code=200) +# TODO: Remove `/get_server_info` alias after one release-cycle deprecation window. 
+@app.get("/server_info") @app.get("/get_server_info") async def get_server_info(): prefill_infos = [] @@ -277,10 +279,10 @@ async def get_server_info(): async with aiohttp.ClientSession() as session: for server in lb.prefill_urls: - server_info = await session.get(f"{server}/get_server_info") + server_info = await session.get(f"{server}/server_info") prefill_infos.append(await server_info.json()) for server in lb.decode_urls: - server_info = await session.get(f"{server}/get_server_info") + server_info = await session.get(f"{server}/server_info") info_json = await server_info.json() decode_infos.append(info_json) # Extract internal_states from decode servers diff --git a/sgl-model-gateway/e2e_test/infra/simple_eval_common.py b/sgl-model-gateway/e2e_test/infra/simple_eval_common.py index 92e72937d9e9..7be4358172c7 100644 --- a/sgl-model-gateway/e2e_test/infra/simple_eval_common.py +++ b/sgl-model-gateway/e2e_test/infra/simple_eval_common.py @@ -457,7 +457,7 @@ def download_dataset(path: str, url: str) -> None: """Download a dataset from URL to path.""" logger.info("Downloading dataset from %s", url) try: - response = requests.get(url, stream=True) + response = requests.get(url, stream=True, timeout=30) response.raise_for_status() total_size = int(response.headers.get("content-length", 0)) diff --git a/sgl-model-gateway/e2e_test/infra/simple_eval_mmlu.py b/sgl-model-gateway/e2e_test/infra/simple_eval_mmlu.py index 1083e56ca60c..a83ed1d2eaaf 100644 --- a/sgl-model-gateway/e2e_test/infra/simple_eval_mmlu.py +++ b/sgl-model-gateway/e2e_test/infra/simple_eval_mmlu.py @@ -93,7 +93,10 @@ class MMLUEval(Eval): """MMLU benchmark evaluation.""" def __init__(self, filename: str, num_examples: int | None, num_threads: int): - df = pandas.read_csv(filename) + if "://" in filename: + df = pandas.read_csv(filename, storage_options={"timeout": 30}) + else: + df = pandas.read_csv(filename) examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: examples = 
random.Random(0).sample(examples, num_examples) diff --git a/sgl-model-gateway/src/routers/http/pd_router.rs b/sgl-model-gateway/src/routers/http/pd_router.rs index 9bcf06c9d265..a939b6427a95 100644 --- a/sgl-model-gateway/src/routers/http/pd_router.rs +++ b/sgl-model-gateway/src/routers/http/pd_router.rs @@ -1223,7 +1223,7 @@ impl RouterTrait for PDRouter { async fn get_server_info(&self, _req: Request) -> Response { // Get info from the first decode server to match sglang's server info format // Note: We use decode workers for server info to match expected format - self.proxy_to_first_prefill_worker("get_server_info", None) + self.proxy_to_first_prefill_worker("server_info", None) .await } diff --git a/sgl-model-gateway/src/routers/http/router.rs b/sgl-model-gateway/src/routers/http/router.rs index 0fbf2e422abb..b02f6638dbfb 100644 --- a/sgl-model-gateway/src/routers/http/router.rs +++ b/sgl-model-gateway/src/routers/http/router.rs @@ -724,7 +724,7 @@ impl RouterTrait for Router { } async fn get_server_info(&self, req: Request) -> Response { - self.proxy_get_request(req, "get_server_info").await + self.proxy_get_request(req, "server_info").await } async fn get_models(&self, req: Request) -> Response { diff --git a/sgl-model-gateway/src/server.rs b/sgl-model-gateway/src/server.rs index 4cea623de27f..db23d0aad39e 100644 --- a/sgl-model-gateway/src/server.rs +++ b/sgl-model-gateway/src/server.rs @@ -610,6 +610,8 @@ pub fn build_app( .route("/engine_metrics", get(engine_metrics)) .route("/v1/models", get(v1_models)) .route("/get_model_info", get(get_model_info)) + .route("/server_info", get(get_server_info)) + // TODO: Remove `/get_server_info` alias after one release-cycle deprecation window. 
.route("/get_server_info", get(get_server_info)); // Build admin routes with control plane auth if configured, otherwise use simple API key auth diff --git a/sgl-model-gateway/tests/api/api_endpoints_test.rs b/sgl-model-gateway/tests/api/api_endpoints_test.rs index 7a0d36676581..6e6ff125e9b2 100644 --- a/sgl-model-gateway/tests/api/api_endpoints_test.rs +++ b/sgl-model-gateway/tests/api/api_endpoints_test.rs @@ -314,7 +314,7 @@ mod model_info_tests { let req = Request::builder() .method("GET") - .uri("/get_server_info") + .uri("/server_info") .body(Body::empty()) .unwrap(); @@ -445,7 +445,7 @@ mod model_info_tests { let req = Request::builder() .method("GET") - .uri("/get_server_info") + .uri("/server_info") .body(Body::empty()) .unwrap(); let resp = app.clone().oneshot(req).await.unwrap(); diff --git a/sgl-model-gateway/tests/common/mock_worker.rs b/sgl-model-gateway/tests/common/mock_worker.rs index 23d6bb6f5d32..166fc9d8314b 100755 --- a/sgl-model-gateway/tests/common/mock_worker.rs +++ b/sgl-model-gateway/tests/common/mock_worker.rs @@ -82,7 +82,7 @@ impl MockWorker { let app = Router::new() .route("/health", get(health_handler)) .route("/health_generate", get(health_generate_handler)) - .route("/get_server_info", get(server_info_handler)) + .route("/server_info", get(server_info_handler)) .route("/get_model_info", get(model_info_handler)) .route("/generate", post(generate_handler)) .route("/v1/chat/completions", post(chat_completions_handler)) diff --git a/sgl-model-gateway/tests/common/tls_mock_worker.rs b/sgl-model-gateway/tests/common/tls_mock_worker.rs index 270866aa1130..36a6c542d877 100644 --- a/sgl-model-gateway/tests/common/tls_mock_worker.rs +++ b/sgl-model-gateway/tests/common/tls_mock_worker.rs @@ -101,7 +101,7 @@ impl TlsMockWorker { let app = Router::new() .route("/health", get(health_handler)) .route("/health_generate", get(health_generate_handler)) - .route("/get_server_info", get(server_info_handler)) + .route("/server_info", 
get(server_info_handler)) .route("/generate", post(generate_handler)) .route("/v1/chat/completions", post(chat_completions_handler)) .with_state(config); diff --git a/sgl-model-gateway/tests/routing/test_pd_routing.rs b/sgl-model-gateway/tests/routing/test_pd_routing.rs index 1853474723df..b6c69576f1c5 100644 --- a/sgl-model-gateway/tests/routing/test_pd_routing.rs +++ b/sgl-model-gateway/tests/routing/test_pd_routing.rs @@ -765,7 +765,7 @@ mod pd_routing_unit_tests { let implemented_endpoints = vec![ ("/health", "GET", true), ("/health_generate", "GET", true), // Note: Python uses POST, we use GET - ("/get_server_info", "GET", true), + ("/server_info", "GET", true), ("/v1/models", "GET", true), ("/get_model_info", "GET", true), ("/generate", "POST", true), diff --git a/test/manual/ep/test_moe_deepep_eval_accuracy_large.py b/test/manual/ep/test_moe_deepep_eval_accuracy_large.py index e79e87ed4770..4781bb9ae8b0 100644 --- a/test/manual/ep/test_moe_deepep_eval_accuracy_large.py +++ b/test/manual/ep/test_moe_deepep_eval_accuracy_large.py @@ -7,7 +7,6 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_DEEPEP_MODEL_NAME_FOR_TEST, @@ -44,18 +43,19 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=64, num_shots=8, - data_path=None, - num_questions=200, - parallel=64, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) def test_mmlu(self): args = SimpleNamespace( diff --git 
a/test/manual/ep/test_mooncake_expert_backup.py b/test/manual/ep/test_mooncake_expert_backup.py index fc3089d8875f..c6cec9cbd242 100644 --- a/test/manual/ep/test_mooncake_expert_backup.py +++ b/test/manual/ep/test_mooncake_expert_backup.py @@ -5,7 +5,7 @@ import requests from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import get_rdma_devices_args from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, @@ -24,6 +24,7 @@ class TestBackup(CustomTestCase): def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.base_port = 20000 + cls.base_url = f"http://127.0.0.1:{cls.base_port}" cls.num_processes = 2 # TODO (stage 100): in the future, implement a specified multiprocess launcher cls.processes = [ @@ -124,18 +125,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=self.base_port, + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) if __name__ == "__main__": diff --git a/test/manual/ep/test_nixl_ep.py b/test/manual/ep/test_nixl_ep.py index 3a4be3b8c89a..9be2ff037b47 100644 --- a/test/manual/ep/test_nixl_ep.py +++ b/test/manual/ep/test_nixl_ep.py @@ -4,7 +4,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import get_rdma_devices_args from 
sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, @@ -71,21 +71,21 @@ def tearDownClass(cls): def _run_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) return metrics def test_gsm8k(self): metrics = self._run_gsm8k() - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestNixlEPTP(_EPTestBase): @@ -108,7 +108,7 @@ class TestNixlMoeMooncakeElasticEP(_EPTestBase): def test_gsm8k_fault_1(self): os.system(f"pkill -f {self.pkill_process_1}") metrics = self._run_gsm8k() - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) if __name__ == "__main__": diff --git a/test/manual/hicache/test_pp_with_hicache.py b/test/manual/hicache/test_pp_with_hicache.py index 761c0fe46e69..9c14d173b9b6 100644 --- a/test/manual/hicache/test_pp_with_hicache.py +++ b/test/manual/hicache/test_pp_with_hicache.py @@ -13,7 +13,7 @@ import requests from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -189,24 +189,24 @@ def flush_cache(self): def test_eval_accuracy(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=40, - max_new_tokens=256, - parallel=24, - host=f"http://{self.base_host}", - port=int(self.base_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=40, + num_threads=24, ) - metrics_initial = 
run_eval_few_shot_gsm8k(args) - self.assertGreater(metrics_initial["accuracy"], 0.6) + metrics_initial = run_eval(args) + self.assertGreater(metrics_initial["score"], 0.6) self.flush_cache() - metrics_cached = run_eval_few_shot_gsm8k(args) - self.assertGreater(metrics_cached["accuracy"], 0.6) + metrics_cached = run_eval(args) + self.assertGreater(metrics_cached["score"], 0.6) - accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"]) + accuracy_diff = abs(metrics_initial["score"] - metrics_cached["score"]) self.assertLess(accuracy_diff, 0.05) diff --git a/test/manual/kv_transfer/test_mooncake_transfer_engine_init.py b/test/manual/kv_transfer/test_mooncake_transfer_engine_init.py new file mode 100755 index 000000000000..33c92328adf2 --- /dev/null +++ b/test/manual/kv_transfer/test_mooncake_transfer_engine_init.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 +""" +Test script for validating Mooncake transfer-engine gating and initialization. +Tests the Mooncake-related branches in the current model-runner flow. + +This test verifies: +1. MooncakeTransferEngine initialization conditions +2. Different server argument combinations that trigger mooncake TE +3. 
Mooncake transfer engine initialization with hostname, gpu_id, and ib_device + +Usage: + # Run from project root on 2 GPUs + CUDA_VISIBLE_DEVICES=0,1 python test/manual/kv_transfer/test_mooncake_transfer_engine_init.py +""" + +import argparse +import multiprocessing +import os +import sys +import time +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Optional +from unittest.mock import patch + + +@dataclass +class ServerArgs: + """Mock ServerArgs for testing.""" + + disaggregation_mode: str = "null" + disaggregation_transfer_backend: str = "mooncake" + enable_hierarchical_cache: bool = False + hicache_storage_backend: str = "mooncake" + encoder_only: bool = False + language_only: bool = False + encoder_transfer_backend: str = "mooncake" + enable_elastic_expert_backup: bool = False + elastic_ep_backend: Optional[str] = None + disaggregation_ib_device: Optional[str] = None + mooncake_ib_device: Optional[str] = None + + +def test_mooncake_te_condition(server_args: ServerArgs) -> bool: + """ + Test the condition logic for using MooncakeTransferEngine. 
+ """ + from sglang.srt.model_executor.model_runner import ModelRunner + + dummy_runner = SimpleNamespace(server_args=server_args, gpu_id=0) + init_called = False + + def _fake_init_mooncake_transfer_engine(*, hostname, gpu_id, ib_device): + nonlocal init_called + init_called = True + return SimpleNamespace( + hostname=hostname, + gpu_id=gpu_id, + ib_device=ib_device, + ) + + with patch( + "sglang.srt.distributed.device_communicators.mooncake_transfer_engine.init_mooncake_transfer_engine", + side_effect=_fake_init_mooncake_transfer_engine, + ), patch( + "sglang.srt.model_executor.model_runner.get_local_ip_auto", + return_value="127.0.0.1", + ): + ModelRunner.init_shared_mooncake_transfer_engine(dummy_runner) + + return init_called + + +def run_mooncake_init( + rank: int, + world_size: int, + master_port: int, + args: argparse.Namespace, + server_args: ServerArgs, +): + """Worker function for testing mooncake transfer engine initialization.""" + os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank) + + # Import before try block to avoid NameError in finally + import torch + import torch.distributed as dist + + dist_initialized = False + + try: + # Initialize distributed environment + print(f"[Rank {rank}] Initializing distributed environment...") + dist.init_process_group( + backend="nccl", + world_size=world_size, + rank=rank, + init_method=f"tcp://127.0.0.1:{master_port}", + device_id=rank, + ) + dist_initialized = True + + # Set device + torch.cuda.set_device(rank) + + # Sync to ensure all ranks are ready + dist.barrier() + print(f"[Rank {rank}] Distributed initialization complete.") + + # Test the condition logic + use_mooncake_te = test_mooncake_te_condition(server_args) + print(f"[Rank {rank}] use_mooncake_te = {use_mooncake_te}") + + if 
use_mooncake_te: + print(f"[Rank {rank}] Attempting to initialize MooncakeTransferEngine...") + + from sglang.srt.distributed.device_communicators.mooncake_transfer_engine import ( + init_mooncake_transfer_engine, + ) + from sglang.srt.utils import get_local_ip_auto + + ib_device = ( + server_args.disaggregation_ib_device or server_args.mooncake_ib_device + ) + + print(f"[Rank {rank}] IB device: {ib_device}") + + # Always actually initialize mooncake + engine = init_mooncake_transfer_engine( + hostname=get_local_ip_auto(), + gpu_id=rank, + ib_device=ib_device, + ) + print(f"[Rank {rank}] Session ID: {engine.get_session_id()}") + print(f"[Rank {rank}] MooncakeTransferEngine initialized successfully!") + + dist.barrier() + + print(f"[Rank {rank}] Test completed successfully!") + sys.exit(0) + + except ImportError as e: + print(f"[Rank {rank}] Mooncake not available (ImportError): {e}") + sys.exit(1) + + except Exception as e: + print(f"[Rank {rank}] Test failed with error: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + finally: + # Cleanup + if dist_initialized and dist.is_initialized(): + dist.destroy_process_group() + print(f"[Rank {rank}] Process group destroyed.") + + +def run_test(args: argparse.Namespace, server_args: ServerArgs) -> bool: + """Run the mooncake transfer engine test.""" + # Set CUDA visible devices + cuda_devices = args.cuda_visible_devices.split(",") + world_size = len(cuda_devices) + + if world_size < 2: + print("ERROR: This test requires at least 2 GPUs.") + print( + "Usage: CUDA_VISIBLE_DEVICES=0,1 python test/manual/kv_transfer/test_mooncake_transfer_engine_init.py" + ) + sys.exit(1) + + # Check GPU availability + import torch + + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available") + sys.exit(1) + + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + print(f"ERROR: Requested {world_size} GPUs but only {available_gpus} available") + sys.exit(1) + + print(f"Testing with 
{world_size} GPUs: {cuda_devices}") + print() + + # Print server args configuration + print("ServerArgs configuration:") + for key, value in vars(server_args).items(): + print(f" {key}: {value}") + print() + + # Check if mooncake should be used + use_mooncake_te = test_mooncake_te_condition(server_args) + print(f"use_mooncake_te = {use_mooncake_te}") + print() + + # Find a free port + import socket + + with socket.socket() as s: + s.bind(("", 0)) + master_port = s.getsockname()[1] + + print(f"Using master port: {master_port}") + + # Spawn worker processes + ctx = multiprocessing.get_context("spawn") + processes = [] + + for rank in range(world_size): + p = ctx.Process( + target=run_mooncake_init, + args=(rank, world_size, master_port, args, server_args), + ) + p.start() + processes.append(p) + + # Wait for all processes to complete + success = True + for i, p in enumerate(processes): + p.join(timeout=60) + if p.exitcode != 0: + print(f"Process {i} failed with exit code: {p.exitcode}") + success = False + + # Cleanup any remaining processes + for p in processes: + if p.is_alive(): + print(f"Process {p.pid} is still alive, terminating...") + p.terminate() + p.join(timeout=5) + + return success + + +def test_condition_logic(): + """Test the condition logic for different server argument combinations.""" + print("=" * 60) + print("Testing condition logic for use_mooncake_te") + print("=" * 60) + print() + + original_hicache_reuse = os.environ.get("SGLANG_HICACHE_MOONCAKE_REUSE_TE") + passed = 0 + failed = 0 + + try: + test_cases = [ + # (name, env_value, server_args, expected_result) + ( + "PD disaggregation with mooncake", + None, + ServerArgs( + disaggregation_mode="prefill", + disaggregation_transfer_backend="mooncake", + ), + True, + ), + ( + "PD disaggregation without mooncake", + None, + ServerArgs( + disaggregation_mode="prefill", + disaggregation_transfer_backend="other", + ), + False, + ), + ( + "No disaggregation", + None, + ServerArgs(), + False, + ), + ( + 
"HiCache with mooncake (env=False)", + "0", + ServerArgs( + enable_hierarchical_cache=True, + hicache_storage_backend="mooncake", + ), + False, + ), + ( + "HiCache with mooncake (env=True)", + "1", + ServerArgs( + enable_hierarchical_cache=True, + hicache_storage_backend="mooncake", + ), + True, + ), + ( + "Encoder only with mooncake", + None, + ServerArgs(encoder_only=True, encoder_transfer_backend="mooncake"), + True, + ), + ( + "Language only with mooncake", + None, + ServerArgs(language_only=True, encoder_transfer_backend="mooncake"), + True, + ), + ( + "Elastic expert backup with backend", + None, + ServerArgs( + enable_elastic_expert_backup=True, + elastic_ep_backend="mooncake", + ), + True, + ), + ( + "Elastic expert backup without backend", + None, + ServerArgs(enable_elastic_expert_backup=True, elastic_ep_backend=None), + False, + ), + ] + + for name, env_value, server_args, expected in test_cases: + if env_value is None: + os.environ.pop("SGLANG_HICACHE_MOONCAKE_REUSE_TE", None) + else: + os.environ["SGLANG_HICACHE_MOONCAKE_REUSE_TE"] = env_value + + result = test_mooncake_te_condition(server_args) + status = "PASS" if result == expected else "FAIL" + + if result == expected: + passed += 1 + else: + failed += 1 + + print(f"{status}: {name}") + print(f" Expected: {expected}, Got: {result}") + print() + finally: + if original_hicache_reuse is None: + os.environ.pop("SGLANG_HICACHE_MOONCAKE_REUSE_TE", None) + else: + os.environ["SGLANG_HICACHE_MOONCAKE_REUSE_TE"] = original_hicache_reuse + + print(f"Condition logic tests: {passed} passed, {failed} failed") + print() + + return failed == 0 + + +def main(): + parser = argparse.ArgumentParser( + description="Validate Mooncake transfer-engine gating and initialization" + ) + parser.add_argument( + "--cuda-visible-devices", + type=str, + default="0,1", + help="CUDA visible devices (default: 0,1)", + ) + parser.add_argument( + "--test-case", + type=str, + choices=[ + "pd_disaggregation", + "hicache", + 
"encoder_only", + "language_only", + "elastic_ep", + ], + default="pd_disaggregation", + help="Test case to run", + ) + + args = parser.parse_args() + + print("=" * 60) + print("Mooncake Transfer Engine Init Test") + print("=" * 60) + print() + + # First run condition logic tests + condition_passed = test_condition_logic() + + if not condition_passed: + print("Condition logic tests failed, skipping distributed test.") + sys.exit(1) + + # Configure server args based on test case + server_args = ServerArgs() + + if args.test_case == "pd_disaggregation": + server_args.disaggregation_mode = "prefill" + server_args.disaggregation_transfer_backend = "mooncake" + elif args.test_case == "hicache": + server_args.enable_hierarchical_cache = True + server_args.hicache_storage_backend = "mooncake" + os.environ["SGLANG_HICACHE_MOONCAKE_REUSE_TE"] = "1" + elif args.test_case == "encoder_only": + server_args.encoder_only = True + server_args.encoder_transfer_backend = "mooncake" + elif args.test_case == "language_only": + server_args.language_only = True + server_args.encoder_transfer_backend = "mooncake" + elif args.test_case == "elastic_ep": + server_args.enable_elastic_expert_backup = True + server_args.elastic_ep_backend = "mooncake" + + start_time = time.time() + success = run_test(args, server_args) + elapsed_time = time.time() - start_time + + print() + print("=" * 60) + if success: + print(f"TEST PASSED (elapsed: {elapsed_time:.2f}s)") + else: + print(f"TEST FAILED (elapsed: {elapsed_time:.2f}s)") + print("=" * 60) + + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/test/manual/lora/test_lora_qwen3_vl.py b/test/manual/lora/test_lora_qwen3_vl.py deleted file mode 100644 index cef3649919a4..000000000000 --- a/test/manual/lora/test_lora_qwen3_vl.py +++ /dev/null @@ -1,233 +0,0 @@ -import random -import unittest -from typing import Sequence - -from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration -from 
sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration -from sglang.test.lora_utils import ( - TORCH_DTYPES, - LoRAAdaptor, - LoRAModelCase, - ensure_reproducibility, -) -from sglang.test.runners import HFRunner, SRTRunner -from sglang.test.test_utils import CustomTestCase, calculate_rouge_l - - -class TestLoRAQwen3VLGating(CustomTestCase): - """Unit tests for should_apply_lora gating on Qwen3‑VL dense and MoE variants.""" - - def _assert_pattern( - self, pattern, positives: Sequence[str], negatives: Sequence[str] - ): - for name in positives: - self.assertTrue(bool(pattern.match(name)), f"Expected to match: {name}") - for name in negatives: - self.assertFalse(bool(pattern.match(name)), f"Should not match: {name}") - - def test_qwen3_vl_should_apply_lora_regex(self): - positives = ( - "model.layers.0.self_attn.qkv_proj", - "model.layers.1.self_attn.o_proj", - "model.layers.2.mlp.gate_up_proj", - "model.layers.3.mlp.down_proj", - ) - negatives = ( - "visual.blocks.0.attn.qkv_proj", - "model.layers.x.self_attn.qkv_proj", - "model.layers.0.attn.qkv_proj", - "model.layers.0.mlp.not_proj", - "model.layers.0.self_attn.q_proj", - ) - self._assert_pattern( - Qwen3VLForConditionalGeneration._lora_pattern, positives, negatives - ) - - def test_qwen3_vl_moe_should_apply_lora_regex(self): - positives = ( - "model.layers.0.self_attn.qkv_proj", - "model.layers.5.self_attn.o_proj", - ) - negatives = ( - "model.layers.0.mlp.gate_up_proj", - "model.layers.0.mlp.down_proj", - "visual.blocks.0.attn.qkv_proj", - "model.layers.x.self_attn.qkv_proj", - "model.layers.0.attn.qkv_proj", - ) - self._assert_pattern( - Qwen3VLMoeForConditionalGeneration._lora_pattern_moe, positives, negatives - ) - - -TEST_MULTIPLE_BATCH_PROMPTS = [ - """ - ### Instruction: - Tell me about llamas and alpacas - ### Response: - Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). 
Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. - ### Question 2: - What do you know about llamas? - ### Answer: - """, - """ - ### Instruction: - Write a poem about the transformers Python library. - Mention the word "large language models" in that poem. - ### Response: - The Transformers are large language models, - They're used to make predictions on text. - """, - "AI is a field of computer science focused on", - "Computer science is the study of", - "Write a short story.", - "What are the main components of a computer?", -] - - -LORA_MODEL_VARIANTS = [ - ( - "Qwen3-VL", - LoRAModelCase( - base="Qwen/Qwen3-VL-4B-Instruct", - adaptors=[ - LoRAAdaptor( - name="mryufei/Qwen3-VL-4B-Instruct-trl-sft", - prefill_tolerance=3e-1, - ), - ], - max_loras_per_batch=1, - ), - ), - # TODO: Move 30B MoE to 2 GPU runner - # ( - # "Qwen3-VL-MoE", - # LoRAModelCase( - # base="Qwen/Qwen3-VL-30B-A3B-Instruct", - # adaptors=[ - # LoRAAdaptor( - # name="sosoai/qwen3_vl_30b_lora", - # prefill_tolerance=3e-1, - # ), - # ], - # max_loras_per_batch=1, - # ), - # ), -] - -LORA_MAX_NEW_TOKENS = 32 - - -def _run_lora_multiple_batch_on_model_cases( - model_cases: Sequence[LoRAModelCase], *, max_new_tokens: int, variant_label: str -): - for model_case in model_cases: - for torch_dtype in TORCH_DTYPES: - backend = "csgmv" - base_path = model_case.base - lora_adapter_paths = [adaptor.name for adaptor in model_case.adaptors] - - batches = [ - ( - [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - 
random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - ], - [None, lora_adapter_paths[0], None], - ), - ( - [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - ], - [lora_adapter_paths[0], None, None], - ), - ( - [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - ], - [None, None, None], - ), - ] - - print( - f"\n=== {variant_label} LoRA parity on '{base_path}', backend={backend}, dtype={torch_dtype} ===" - ) - - ensure_reproducibility() - srt_runner = SRTRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - lora_paths=lora_adapter_paths, - max_loras_per_batch=model_case.max_loras_per_batch, - lora_backend=backend, - sleep_on_idle=True, - attention_backend="torch_native", - disable_radix_cache=True, - ) - - ensure_reproducibility() - hf_runner = HFRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - patch_model_do_sample_false=True, - ) - - with srt_runner, hf_runner: - for i, (prompts, lora_paths) in enumerate(batches): - print( - f"\n--- Running Batch {i + 1} --- prompts: {prompts}, lora_paths: {lora_paths}" - ) - - srt_outputs = srt_runner.batch_forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - - hf_outputs = hf_runner.forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - - print("SRT outputs:", [s for s in srt_outputs.output_strs]) - print("HF outputs:", [s for s in hf_outputs.output_strs]) - - for srt_out, hf_out in zip( - srt_outputs.output_strs, hf_outputs.output_strs - ): - srt_str = srt_out.strip() - hf_str = hf_out.strip() - rouge_tol = model_case.rouge_l_tolerance - rouge_score = calculate_rouge_l([srt_str], [hf_str])[0] - if rouge_score < rouge_tol: - raise AssertionError( - f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} " - f"for base '{base_path}', adaptor 
'{lora_paths}', backend '{backend}', prompt: '{prompts}...'" - ) - - print(f"--- Batch {i + 1} Comparison Passed --- ") - - -class TestLoRAQwen3VLIntegration(CustomTestCase): - """Parity integration tests for Qwen3‑VL dense and MoE LoRA adapters.""" - - def test_ci_lora_models(self): - for label, model_case in LORA_MODEL_VARIANTS: - with self.subTest(variant=label): - _run_lora_multiple_batch_on_model_cases( - [model_case], - max_new_tokens=LORA_MAX_NEW_TOKENS, - variant_label=label, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/manual/models/test_falcon_h1_models.py b/test/manual/models/test_falcon_h1_models.py index 1706cc8594dd..4630e1cdb69a 100644 --- a/test/manual/models/test_falcon_h1_models.py +++ b/test/manual/models/test_falcon_h1_models.py @@ -1,7 +1,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -31,17 +31,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.74) + self.assertGreater(metrics["score"], 0.74) class TestFalconH1TP4(CustomTestCase): @@ -65,17 +65,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, 
+ num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.74) + self.assertGreater(metrics["score"], 0.74) class TestFalconH1NoGatedRMS(CustomTestCase): @@ -99,17 +99,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.74) + self.assertGreater(metrics["score"], 0.74) class TestFalconH1NoGatedTP4(CustomTestCase): @@ -133,14 +133,14 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.74) + self.assertGreater(metrics["score"], 0.74) diff --git a/test/manual/models/test_grok_models.py b/test/manual/models/test_grok_models.py index 625fa1a65bfe..9a3b0e516fa0 100644 --- a/test/manual/models/test_grok_models.py +++ b/test/manual/models/test_grok_models.py @@ -2,7 +2,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -34,13 +34,13 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=64, - 
max_new_tokens=256, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=64, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") diff --git a/test/manual/models/test_kimi_k2_models.py b/test/manual/models/test_kimi_k2_models.py index 6a2fbed71082..6e83ef50c88e 100644 --- a/test/manual/models/test_kimi_k2_models.py +++ b/test/manual/models/test_kimi_k2_models.py @@ -4,7 +4,7 @@ import requests from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -48,22 +48,22 @@ def test_a_gsm8k( requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( - f"### test_gsm8k (Kimi-K2-Thinking)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (Kimi-K2-Thinking)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.95) + self.assertGreater(metrics["score"], 0.95) if __name__ == "__main__": diff --git a/test/manual/models/test_llama4_models.py b/test/manual/models/test_llama4_models.py index cb0c57604ebe..70c4210fe912 100644 --- a/test/manual/models/test_llama4_models.py +++ b/test/manual/models/test_llama4_models.py @@ -2,7 +2,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k 
import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -44,17 +44,16 @@ def test_gsm8k(self): ], ) args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreaterEqual(metrics["accuracy"], model.accuracy) + self.assertGreaterEqual(metrics["score"], model.accuracy) except Exception as e: print(f"Error testing {model.model}: {e}") self.fail(f"Test failed for {model.model}: {e}") diff --git a/test/manual/models/test_mistral_large3_basic.py b/test/manual/models/test_mistral_large3_basic.py index 3b173f1abd7c..2eac4f79b09d 100644 --- a/test/manual/models/test_mistral_large3_basic.py +++ b/test/manual/models/test_mistral_large3_basic.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -53,22 +53,23 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1400, + num_threads=1400, num_shots=8, - data_path=None, - num_questions=1400, - parallel=1400, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( - 
f"### test_gsm8k (mistral-large-3)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (mistral-large-3)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.90) + self.assertGreater(metrics["score"], 0.90) def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) diff --git a/test/manual/models/test_mtp_models.py b/test/manual/models/test_mtp_models.py index 49b53c1e4573..c5f3fc5cd3ad 100644 --- a/test/manual/models/test_mtp_models.py +++ b/test/manual/models/test_mtp_models.py @@ -2,7 +2,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -41,17 +41,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.7) + self.assertGreater(metrics["score"], 0.7) if __name__ == "__main__": diff --git a/test/manual/models/test_unsloth_models.py b/test/manual/models/test_unsloth_models.py index 24660ea34fc6..9f71ff1634e0 100644 --- a/test/manual/models/test_unsloth_models.py +++ b/test/manual/models/test_unsloth_models.py @@ -2,7 +2,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -29,17 +29,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = 
SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.78) + self.assertGreater(metrics["score"], 0.78) class TestUnslothPhi4Bnb4bit(CustomTestCase): @@ -63,17 +63,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.75) + self.assertGreater(metrics["score"], 0.75) class TestUnslothPhi4UnslothBnb4bit(CustomTestCase): @@ -97,17 +97,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.75) + self.assertGreater(metrics["score"], 0.75) class TestUnslothPhi4MiniInstruct(CustomTestCase): @@ -128,17 +128,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + 
eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.65) + self.assertGreater(metrics["score"], 0.65) class TestUnslothPhi4MiniBnb4bit(CustomTestCase): @@ -162,17 +162,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.6) + self.assertGreater(metrics["score"], 0.6) class TestUnslothPhi4MiniUnslothBnb4bit(CustomTestCase): @@ -196,17 +196,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.6) + self.assertGreater(metrics["score"], 0.6) if __name__ == "__main__": diff --git a/test/manual/piecewise_cudagraph/test_disaggregation_piecewise_cuda_graph.py b/test/manual/piecewise_cudagraph/test_disaggregation_piecewise_cuda_graph.py index 086830a90851..9f2e7bfa056b 100644 --- a/test/manual/piecewise_cudagraph/test_disaggregation_piecewise_cuda_graph.py +++ b/test/manual/piecewise_cudagraph/test_disaggregation_piecewise_cuda_graph.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval 
from sglang.test.server_fixtures.disaggregation_fixture import ( PDDisaggregationServerBase, ) @@ -70,18 +70,18 @@ def start_decode(cls): def test_gsm8k_accuracy(self): """Verify that piecewise cuda graph works correctly in prefill server""" args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) - print(f"GSM8K accuracy with piecewise cuda graph: {metrics['accuracy']:.3f}") + metrics = run_eval(args) + print(f"GSM8K accuracy with piecewise cuda graph: {metrics['score']:.3f}") - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["score"], 0.62) if __name__ == "__main__": diff --git a/test/manual/test_async_mm_data_processor.py b/test/manual/test_async_mm_data_processor.py deleted file mode 100644 index 0edc2f5ccc8e..000000000000 --- a/test/manual/test_async_mm_data_processor.py +++ /dev/null @@ -1,365 +0,0 @@ -""" -Unit tests for AsyncMMDataProcessor. 
- -Covers: - - Async and sync processing paths - - Concurrency limiting via semaphore - - Per-call timeout behavior (async and sync) - - Argument passthrough (images, audios, text/ids, request_obj, kwargs) - - Error propagation and shutdown behavior -""" - -import asyncio -import logging -import sys -import threading -import time -from unittest.mock import Mock - -import pytest - -from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor - - -class TestAsyncMMDataProcessor: - """Test suite for AsyncMMDataProcessor.""" - - @pytest.fixture - def async_processor(self): - """Create a processor exposing an async process_mm_data_async.""" - - class AsyncProc: - async def process_mm_data_async( - self, - *, - image_data=None, - audio_data=None, - input_text=None, - request_obj=None, - **kwargs, - ): - # Allow tests to simulate latency via kwargs - delay = kwargs.get("delay_s", 0.0) - if delay: - await asyncio.sleep(delay) - return { - "path": "async", - "images": image_data, - "audios": audio_data, - "text": input_text, - "request": request_obj, - "kwargs": kwargs, - } - - return AsyncProc() - - @pytest.fixture - def sync_processor(self): - """Provide a processor exposing a sync process_mm_data.""" - - class SyncProc: - def process_mm_data( - self, - *, - image_data=None, - audio_data=None, - input_text=None, - request_obj=None, - **kwargs, - ): - delay = kwargs.get("delay_s", 0.0) - if delay: - # Simulate CPU/blocking work - time.sleep(delay) - return { - "path": "sync", - "images": image_data, - "audios": audio_data, - "text": input_text, - "request": request_obj, - "kwargs": kwargs, - } - - return SyncProc() - - @pytest.mark.asyncio - async def test_async_path_basic(self, async_processor): - """Async processor should be awaited directly.""" - proc = AsyncMMDataProcessor(async_processor) - out = await proc.process( - image_data=["img1.png"], - audio_data=["a.wav"], - input_text_or_ids="hello", - request_obj={"rid": 1}, - mode="fast", - ) - assert 
out["path"] == "async" - assert out["images"] == ["img1.png"] - assert out["audios"] == ["a.wav"] - assert out["text"] == "hello" - assert out["request"] == {"rid": 1} - assert out["kwargs"]["mode"] == "fast" - - @pytest.mark.asyncio - async def test_sync_fallback_basic(self, sync_processor): - """Sync processor should run in fallback executor.""" - proc = AsyncMMDataProcessor(sync_processor) - out = await proc.process( - image_data=[b"\x00\x01"], - audio_data=None, - input_text_or_ids=[1, 2, 3], - request_obj="req-obj", - role="user", - ) - assert out["path"] == "sync" - assert out["images"] == [b"\x00\x01"] - assert out["audios"] is None - assert out["text"] == [1, 2, 3] - assert out["request"] == "req-obj" - assert out["kwargs"]["role"] == "user" - - @pytest.mark.asyncio - async def test_timeout_async(self, async_processor): - """Timeout should raise asyncio.TimeoutError for async path.""" - proc = AsyncMMDataProcessor(async_processor, timeout_s=0.01) - with pytest.raises(asyncio.TimeoutError): - await proc.process( - input_text_or_ids="slow", - request_obj=None, - delay_s=0.05, # longer than timeout - ) - - @pytest.mark.asyncio - async def test_timeout_sync(self, sync_processor): - """Timeout should raise asyncio.TimeoutError for sync fallback path.""" - proc = AsyncMMDataProcessor(sync_processor, timeout_s=0.01) - with pytest.raises(asyncio.TimeoutError): - await proc.process( - input_text_or_ids="slow", - request_obj=None, - delay_s=0.05, # longer than timeout - ) - - @pytest.mark.asyncio - async def test_semaphore_release_after_timeout(self, sync_processor): - """ - If a call times out, the semaphore should be released so a subsequent call can proceed. - Use >=2 fallback workers so the timed-out thread doesn't block the next call. 
- """ - proc = AsyncMMDataProcessor( - sync_processor, - max_concurrent_calls=2, - timeout_s=0.01, - ) - - # First call will time out - with pytest.raises(asyncio.TimeoutError): - await proc.process( - input_text_or_ids="slow1", request_obj=None, delay_s=0.05 - ) - - # Second call should be able to acquire the semaphore and complete - out = await proc.process(input_text_or_ids="ok", request_obj=None, delay_s=0.0) - assert out["text"] == "ok" - - @pytest.mark.asyncio - async def test_concurrency_limit_async(self): - """Ensure max_concurrent_calls caps concurrency for async path.""" - current = 0 - max_seen = 0 - - class AsyncProc: - async def process_mm_data_async(self, **kwargs): - nonlocal current, max_seen - current += 1 - max_seen = max(max_seen, current) - try: - await asyncio.sleep(0.02) - return {"ok": True} - finally: - current -= 1 - - proc = AsyncMMDataProcessor(AsyncProc(), max_concurrent_calls=2) - - tasks = [ - proc.process(input_text_or_ids=f"t{i}", request_obj=None) for i in range(6) - ] - await asyncio.gather(*tasks) - - assert max_seen <= 2 - - @pytest.mark.asyncio - async def test_concurrency_limit_sync(self): - """Ensure max_concurrent_calls caps concurrency for sync fallback path.""" - current = 0 - max_seen = 0 - lock = threading.Lock() - - class SyncProc: - def process_mm_data(self, **kwargs): - nonlocal current, max_seen - with lock: - current += 1 - max_seen = max(max_seen, current) - try: - time.sleep(0.02) - return {"ok": True} - finally: - with lock: - current -= 1 - - proc = AsyncMMDataProcessor(SyncProc(), max_concurrent_calls=3) - - tasks = [ - proc.process(input_text_or_ids=f"s{i}", request_obj=None) for i in range(9) - ] - await asyncio.gather(*tasks) - - assert max_seen <= 3 - - @pytest.mark.asyncio - async def test_error_from_async_processor(self): - """Exceptions raised by the async processor should propagate.""" - - class BadAsync: - async def process_mm_data_async(self, **_): - await asyncio.sleep(0) - raise ValueError("async 
boom") - - proc = AsyncMMDataProcessor(BadAsync()) - with pytest.raises(ValueError, match="async boom"): - await proc.process(input_text_or_ids="x", request_obj=None) - - @pytest.mark.asyncio - async def test_error_from_sync_processor(self): - """Exceptions raised by the sync processor should propagate.""" - - class BadSync: - def process_mm_data(self, **_): - raise RuntimeError("sync boom") - - proc = AsyncMMDataProcessor(BadSync()) - with pytest.raises(RuntimeError, match="sync boom"): - await proc.process(input_text_or_ids="x", request_obj=None) - - @pytest.mark.asyncio - async def test_missing_both_methods_raises(self): - """Processor missing both methods should raise at call time.""" - - class Empty: - pass - - proc = AsyncMMDataProcessor(Empty()) - with pytest.raises( - RuntimeError, match="neither 'process_mm_data_async' nor 'process_mm_data'" - ): - await proc.process(input_text_or_ids="x", request_obj=None) - - @pytest.mark.asyncio - async def test_async_attribute_not_coroutine_uses_sync_fallback(self): - """ - If `process_mm_data_async` exists but isn't a coroutine function, - wrapper should treat it as sync and use `process_mm_data`. 
- """ - - class WeirdProc: - # Not a coroutine function: - def process_mm_data_async(self, **_): - return {"path": "would-be-async"} - - def process_mm_data(self, **_): - return {"path": "sync"} - - proc = AsyncMMDataProcessor(WeirdProc()) - out = await proc.process(input_text_or_ids="x", request_obj=None) - assert out["path"] == "sync" - - @pytest.mark.asyncio - async def test_kwargs_and_request_passthrough_async(self, async_processor): - """Extra kwargs and request_obj should be forwarded on async path.""" - proc = AsyncMMDataProcessor(async_processor) - out = await proc.process( - image_data=["i1", "i2"], - audio_data=["a1"], - input_text_or_ids="hello world", - request_obj={"uid": 42}, - return_meta=True, - delay_s=0.0, - ) - assert out["images"] == ["i1", "i2"] - assert out["audios"] == ["a1"] - assert out["text"] == "hello world" - assert out["request"] == {"uid": 42} - assert out["kwargs"]["return_meta"] is True - - @pytest.mark.asyncio - async def test_kwargs_and_request_passthrough_sync(self, sync_processor): - """Extra kwargs and request_obj should be forwarded on sync path.""" - proc = AsyncMMDataProcessor(sync_processor) - out = await proc.process( - image_data=None, - audio_data=[], - input_text_or_ids=[101, 102], - request_obj=("r", 7), - lang="en", - ) - assert out["images"] is None - assert out["audios"] == [] - assert out["text"] == [101, 102] - assert out["request"] == ("r", 7) - assert out["kwargs"]["lang"] == "en" - - def test_shutdown_on_sync_executor(self, sync_processor): - """Explicit shutdown should close fallback executor for sync path.""" - proc = AsyncMMDataProcessor(sync_processor) - # Swap real executor for a mock to assert shutdown behavior - proc.fallback_exec = Mock() - proc.shutdown() - proc.fallback_exec.shutdown.assert_called_once_with(wait=False) - - def test_del_calls_shutdown(self, sync_processor, caplog): - """__del__ should best-effort shutdown without raising.""" - caplog.set_level(logging.DEBUG) - proc = 
AsyncMMDataProcessor(sync_processor) - proc.fallback_exec = Mock() - # Simulate object destruction - proc.__del__() - proc.fallback_exec.shutdown.assert_called_once_with(wait=False) - - @pytest.mark.asyncio - async def test_concurrent_mixed_requests(self, async_processor): - """Mix different payloads and ensure all complete with valid outputs.""" - proc = AsyncMMDataProcessor(async_processor, max_concurrent_calls=4) - - tasks = [ - proc.process(input_text_or_ids="t1", request_obj=1), - proc.process(image_data=["i.png"], input_text_or_ids=[9, 8], request_obj=2), - proc.process( - audio_data=["v.wav"], input_text_or_ids="speech", request_obj=3 - ), - proc.process( - image_data=[], audio_data=[], input_text_or_ids=None, request_obj=4 - ), - ] - outs = await asyncio.gather(*tasks) - assert len(outs) == 4 - for out in outs: - assert "path" in out - assert out["path"] == "async" - - @pytest.mark.asyncio - async def test_many_requests_values_match_inputs(self, sync_processor): - """For sync path, ensure each response corresponds to its specific input.""" - proc = AsyncMMDataProcessor(sync_processor, max_concurrent_calls=8) - texts = [f"msg-{i}" for i in range(10)] - tasks = [ - proc.process(input_text_or_ids=t, request_obj=i) - for i, t in enumerate(texts) - ] - outs = await asyncio.gather(*tasks) - got = [o["text"] for o in outs] - assert got == texts - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/test/manual/test_cross_node_scheduler_info_sync.py b/test/manual/test_cross_node_scheduler_info_sync.py new file mode 100755 index 000000000000..f6e5a835fb73 --- /dev/null +++ b/test/manual/test_cross_node_scheduler_info_sync.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Test cross-node scheduler_infos synchronization for remote weight loading. + +Simulates multi-node setups on a single machine using different GPU subsets. +Validates that scheduler_infos are correctly synced across nodes via Gloo. 
+ +IMPORTANT: For multi-node tests, start both nodes within a few seconds of each +other to avoid port binding conflicts (they share the same network namespace). + +Test cases: + - tp4_nodes2: TP=4 across 2 nodes, validates basic cross-node sync + - dp2_single_node: DP=2 with dp_attention on single node + - dp2_tp2_nodes2: DP=2, TP=4 across 2 nodes with dp_attention + +Usage (multi-node): + Terminal 1: python test_cross_node_scheduler_info_sync.py --test-case tp4_nodes2 --node-rank 0 + Terminal 2: python test_cross_node_scheduler_info_sync.py --test-case tp4_nodes2 --node-rank 1 + Terminal 3: python test_cross_node_scheduler_info_sync.py --test-case tp4_nodes2 --test-only + +Usage (single-node): + Terminal 1: python test_cross_node_scheduler_info_sync.py --test-case dp2_single_node --node-rank 0 + Terminal 2: python test_cross_node_scheduler_info_sync.py --test-case dp2_single_node --test-only + +Requirements: 4 GPUs on single machine +""" + +import argparse +import socket +import subprocess +import sys +import time +from dataclasses import dataclass +from typing import List + +import requests + +from sglang.test.test_utils import ( + DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, +) + + +@dataclass +class TestCase: + name: str + tp_size: int + dp_size: int + nnodes: int + gpus_per_node: int + expected_ranks: int + extra_args: List[str] + + +TEST_CASES = { + "tp4_nodes2": TestCase( + name="tp4_nodes2", + tp_size=4, + dp_size=1, + nnodes=2, + gpus_per_node=2, + expected_ranks=4, + extra_args=[], + ), + "dp2_single_node": TestCase( + name="dp2_single_node", + tp_size=2, + dp_size=2, + nnodes=1, + gpus_per_node=2, + expected_ranks=2, + extra_args=["--enable-dp-attention", "--dp", "2", "--attention-backend", "fa3"], + ), + "dp2_tp2_nodes2": TestCase( + name="dp2_tp2_nodes2", + tp_size=4, + dp_size=2, + nnodes=2, + gpus_per_node=2, + expected_ranks=4, + extra_args=["--enable-dp-attention", "--dp", "2", "--attention-backend", "fa3"], + ), +} + +TEST_CASE_MODELS = { + 
"tp4_nodes2": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "dp2_single_node": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "dp2_tp2_nodes2": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, +} + + +def get_local_ip() -> str: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + except Exception: + return "127.0.0.1" + finally: + s.close() + + +def launch_node( + test_case: TestCase, node_rank: int, model_path: str, dist_init_addr: str +): + cmd = [ + sys.executable, + "-m", + "sglang.launch_server", + "--model-path", + model_path, + "--tp", + str(test_case.tp_size), + "--port", + str(30000 + node_rank * 100), + "--host", + "0.0.0.0", + "--remote-instance-weight-loader-start-seed-via-transfer-engine", + ] + if test_case.nnodes > 1: + cmd.extend( + [ + "--nnodes", + str(test_case.nnodes), + "--node-rank", + str(node_rank), + "--dist-init-addr", + dist_init_addr, + "--base-gpu-id", + str(node_rank * test_case.gpus_per_node), + ] + ) + cmd.extend(test_case.extra_args) + print(f"[Node {node_rank}] {' '.join(cmd)}") + subprocess.run(cmd) + + +def test_api(test_case: TestCase) -> bool: + base_url = "http://127.0.0.1:30000" + print(f"Testing {test_case.name}: expecting {test_case.expected_ranks} ranks") + + for _ in range(60): + try: + if requests.get(f"{base_url}/health", timeout=2).status_code == 200: + break + except Exception: + pass + time.sleep(2) + else: + print("ERROR: Server not ready") + return False + + all_passed = True + for rank in range(test_case.expected_ranks): + try: + resp = requests.get( + f"{base_url}/get_remote_instance_transfer_engine_info", + params={"rank": rank}, + timeout=5, + ) + status = "āœ“" if resp.status_code == 200 else "āœ—" + print(f"{status} Rank {rank}: {resp.status_code}") + if resp.status_code != 200: + all_passed = False + except Exception as e: + print(f"āœ— Rank {rank}: {e}") + all_passed = False + + print("PASSED" if all_passed else "FAILED") + return all_passed + + 
+def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--test-case", type=str, choices=list(TEST_CASES.keys()), required=True + ) + parser.add_argument("--node-rank", type=int, choices=[0, 1]) + parser.add_argument("--model-path", type=str, default=None) + parser.add_argument("--dist-init-addr", type=str, default=None) + parser.add_argument("--test-only", action="store_true") + args = parser.parse_args() + + test_case = TEST_CASES[args.test_case] + model_path = args.model_path or TEST_CASE_MODELS.get( + args.test_case, DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT + ) + + if args.test_only: + sys.exit(0 if test_api(test_case) else 1) + + if test_case.nnodes == 1: + launch_node(test_case, 0, model_path, "") + return + + if args.node_rank is None: + print(f"Usage: --node-rank 0 or 1, then --test-only in another terminal") + sys.exit(0) + + dist_init_addr = args.dist_init_addr or f"{get_local_ip()}:20000" + launch_node(test_case, args.node_rank, model_path, dist_init_addr) + + +if __name__ == "__main__": + main() diff --git a/test/manual/test_mla_tp.py b/test/manual/test_mla_tp.py index e957cf2de89f..5684e7b502d1 100644 --- a/test/manual/test_mla_tp.py +++ b/test/manual/test_mla_tp.py @@ -4,7 +4,7 @@ import torch from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -36,30 +36,30 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) - self.assertGreater(metrics["accuracy"], 0.62) + 
metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.62) def test_gsm8k_bs1(self): # test torch compile accuracy for bs=1 args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=10, - max_new_tokens=512, - parallel=1, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=10, + num_threads=1, ) - metrics = run_eval_few_shot_gsm8k(args) - self.assertGreater(metrics["accuracy"], 0.62) + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.62) if __name__ == "__main__": diff --git a/test/manual/test_torch_flex_attention_backend.py b/test/manual/test_torch_flex_attention_backend.py index 832ac14c49f2..891471bae35a 100644 --- a/test/manual/test_torch_flex_attention_backend.py +++ b/test/manual/test_torch_flex_attention_backend.py @@ -7,7 +7,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -30,17 +30,17 @@ def test_gsm8k(self): try: args = SimpleNamespace( + base_url=base_url, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=100, + num_threads=10, num_shots=8, - data_path=None, - num_questions=100, - parallel=10, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["score"], 0.62) finally: kill_process_tree(process.pid) diff --git a/test/manual/test_tracing.py b/test/manual/test_tracing.py deleted file mode 100644 index 808809278624..000000000000 --- a/test/manual/test_tracing.py +++ /dev/null @@ -1,314 +0,0 @@ -import 
multiprocessing as mp -import os -import subprocess -import time -import unittest -from dataclasses import dataclass -from typing import Optional, Union - -import requests -import zmq - -from sglang import Engine -from sglang.srt.observability.trace import * -from sglang.srt.observability.trace import get_cur_time_ns, set_global_trace_level -from sglang.srt.utils import kill_process_tree -from sglang.srt.utils.network import get_zmq_socket -from sglang.test.test_utils import ( - DEFAULT_SMALL_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - - -@dataclass -class Req: - rid: int - req_context: Optional[Union[TraceReqContext]] = None - - -class TestTrace(CustomTestCase): - def __launch_otel_jaeger(self): - cmd = [ - "docker", - "compose", - "-f", - "../../examples/monitoring/tracing_compose.yaml", - "up", - "-d", - ] - proc = subprocess.run(cmd) - - if proc.returncode != 0: - print("launch opentelemetry collector and jaeger docker err") - return False - return True - - def __stop_otel_jaeger(self): - cmd = [ - "docker", - "compose", - "-f", - "../../examples/monitoring/tracing_compose.yaml", - "down", - ] - proc = subprocess.run(cmd) - - if proc.returncode != 0: - print("stop opentelemetry collector and jaeger docker err") - return False - return True - - def __clear_trace_file(self): - try: - os.remove("/tmp/otel_trace.json") - except: - pass - - def __test_trace_enable(self, trace_level, expect_export_data): - self.__clear_trace_file() - assert self.__launch_otel_jaeger() - self.addCleanup(self.__stop_otel_jaeger) - - process = popen_launch_server( - DEFAULT_SMALL_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_TEST, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--enable-trace", - "--otlp-traces-endpoint", - "0.0.0.0:4317", - ], - ) - - try: - response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate") - self.assertEqual(response.status_code, 200) - - # set trace level - 
response = requests.get( - f"{DEFAULT_URL_FOR_TEST}/set_trace_level?level={trace_level}" - ) - self.assertEqual(response.status_code, 200) - - # Make some requests to generate trace data - response = requests.post( - f"{DEFAULT_URL_FOR_TEST}/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, - }, - "stream": True, - }, - stream=True, - ) - for _ in response.iter_lines(decode_unicode=False): - pass - - # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file. - time.sleep(10) - - # check trace file - assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist" - if expect_export_data: - assert ( - os.path.getsize("/tmp/otel_trace.json") > 0 - ), "trace file is empty" - else: - assert ( - os.path.getsize("/tmp/otel_trace.json") == 0 - ), "trace file is not empty" - - finally: - kill_process_tree(process.pid) - - def test_trace_enable_level_1(self): - self.__test_trace_enable("1", True) - - def test_trace_enable_level_2(self): - self.__test_trace_enable("2", True) - - def test_trace_enable_level_3(self): - self.__test_trace_enable("3", True) - - def test_trace_enable_level_0(self): - self.__test_trace_enable("0", False) - - def test_trace_engine_enable(self): - self.__clear_trace_file() - assert self.__launch_otel_jaeger() - self.addCleanup(self.__stop_otel_jaeger) - - prompt = "Today is a sunny day and I like" - model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - - sampling_params = {"temperature": 0, "max_new_tokens": 8} - - engine = Engine( - model_path=model_path, - random_seed=42, - enable_trace=True, - otlp_traces_endpoint="localhost:4317", - ) - - try: - engine.generate(prompt, sampling_params) - - # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file. 
- time.sleep(10) - - # check trace file - assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist" - assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty" - finally: - engine.shutdown() - - def test_trace_engine_encode(self): - self.__clear_trace_file() - assert self.__launch_otel_jaeger() - self.addCleanup(self.__stop_otel_jaeger) - - prompt = "Today is a sunny day and I like" - model_path = "Qwen/Qwen2-7B" - - engine = Engine( - model_path=model_path, - random_seed=42, - enable_trace=True, - otlp_traces_endpoint="localhost:4317", - is_embedding=True, - ) - - try: - engine.encode(prompt) - - # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file. - time.sleep(10) - - # check trace file - assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist" - assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty" - finally: - engine.shutdown() - - def test_slice_trace_simple(self): - self.__clear_trace_file() - assert self.__launch_otel_jaeger() - self.addCleanup(self.__stop_otel_jaeger) - try: - process_tracing_init("0.0.0.0:4317", "test") - trace_set_thread_info("Test") - set_global_trace_level(3) - req_context = TraceReqContext(0) - req_context.trace_req_start() - req_context.trace_slice_start("test slice", level=1) - time.sleep(1) - req_context.trace_slice_end("test slice", level=1) - req_context.trace_req_finish() - - # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file. 
- time.sleep(10) - # check trace file - assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist" - assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty" - finally: - pass - - def test_slice_trace_complex(self): - self.__clear_trace_file() - assert self.__launch_otel_jaeger() - self.addCleanup(self.__stop_otel_jaeger) - try: - process_tracing_init("0.0.0.0:4317", "test") - trace_set_thread_info("Test") - set_global_trace_level(3) - req_context = TraceReqContext(0) - req_context.trace_req_start() - t1 = get_cur_time_ns() - time.sleep(1) - req_context.trace_event("event test", 1) - t2 = get_cur_time_ns() - time.sleep(1) - t3 = get_cur_time_ns() - slice1 = TraceSliceContext("slice A", t1, t2) - slice2 = TraceSliceContext("slice B", t2, t3) - req_context.trace_slice(slice1) - req_context.trace_slice(slice2, thread_finish_flag=True) - req_context.trace_req_finish() - - # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file. 
- time.sleep(10) - # check trace file - assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist" - assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty" - finally: - pass - - def test_trace_context_propagete(self): - def __process_work(): - process_tracing_init("0.0.0.0:4317", "test") - trace_set_thread_info("Sub Process") - - context = zmq.Context(2) - recv_from_main = get_zmq_socket( - context, zmq.PULL, "ipc:///tmp/zmq_test.ipc", True - ) - - try: - req = recv_from_main.recv_pyobj() - req.req_context.rebuild_thread_context() - req.req_context.trace_slice_start("work", level=1) - time.sleep(1) - req.req_context.trace_slice_end( - "work", level=1, thread_finish_flag=True - ) - finally: - recv_from_main.close() - context.term() - - self.__clear_trace_file() - assert self.__launch_otel_jaeger() - self.addCleanup(self.__stop_otel_jaeger) - - context = zmq.Context(2) - send_to_subproc = get_zmq_socket( - context, zmq.PUSH, "ipc:///tmp/zmq_test.ipc", False - ) - try: - process_tracing_init("0.0.0.0:4317", "test") - trace_set_thread_info("Main Process") - - subproc = mp.Process(target=__process_work) - subproc.start() - - # sleep for a few second to ensure subprocess init - time.sleep(1) - - req = Req(rid=0) - req.req_context = TraceReqContext(0) - req.req_context.trace_req_start() - req.req_context.trace_slice_start("dispatch", level=1) - time.sleep(1) - send_to_subproc.send_pyobj(req) - req.req_context.trace_slice_end("dispatch", level=1) - - subproc.join() - req.req_context.trace_req_finish() - - # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file. 
- time.sleep(10) - # check trace file - assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist" - assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty" - - finally: - send_to_subproc.close() - context.term() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/manual/test_whisper_cuda_graph.py b/test/manual/test_whisper_cuda_graph.py new file mode 100644 index 000000000000..72d6da16b068 --- /dev/null +++ b/test/manual/test_whisper_cuda_graph.py @@ -0,0 +1,161 @@ +""" +Test Whisper model with CUDA graph support. + +This test verifies that: +1. Whisper model works correctly with CUDA graph enabled (default) +2. Cross-attention KV cache is properly managed through RadixAttention +3. Output is consistent between CUDA graph and non-CUDA-graph modes + +Usage: + python test_whisper_cuda_graph.py + +Requires: + - A GPU with sufficient memory + - openai-whisper model (e.g., openai/whisper-large-v3) + - An audio file or URL for testing +""" + +import io +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +WHISPER_MODEL = "openai/whisper-large-v3" +TEST_AUDIO_URL = "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac" +TEST_AUDIO_LOCAL = "/tmp/test_whisper_audio.flac" + + +def get_audio_bytes(): + """Get audio bytes, downloading if necessary.""" + import os + + if os.path.exists(TEST_AUDIO_LOCAL): + with open(TEST_AUDIO_LOCAL, "rb") as f: + return f.read() + resp = requests.get(TEST_AUDIO_URL, timeout=30) + resp.raise_for_status() + with open(TEST_AUDIO_LOCAL, "wb") as f: + f.write(resp.content) + return resp.content + + +class TestWhisperCudaGraph(CustomTestCase): + """Test Whisper with CUDA graph enabled (default behavior).""" + + @classmethod + def setUpClass(cls): + cls.model = WHISPER_MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + 
cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--served-model-name", + "whisper", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def _transcribe(self, language="en"): + """Send a transcription request via OpenAI-compatible audio endpoint.""" + audio_bytes = get_audio_bytes() + response = requests.post( + self.base_url + "/v1/audio/transcriptions", + files={"file": ("audio.ogg", io.BytesIO(audio_bytes), "audio/ogg")}, + data={ + "model": "whisper", + "language": language, + }, + ) + self.assertEqual(response.status_code, 200, response.text) + return response.json() + + def test_basic_transcription(self): + """Test that basic transcription works with CUDA graph.""" + result = self._transcribe() + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"Transcription: {text}") + + def test_multiple_sequential_requests(self): + """Test multiple sequential requests to verify CUDA graph replay consistency.""" + results = [] + for i in range(3): + result = self._transcribe() + self.assertIn("text", result) + results.append(result["text"]) + print(f"Request {i+1}: {result['text'][:80]}...") + + # All transcriptions of the same audio should be identical + for i in range(1, len(results)): + self.assertEqual( + results[0], + results[i], + f"Transcription {i+1} differs from first transcription", + ) + + def test_transcription_quality(self): + """Test that transcription quality is reasonable (contains expected words).""" + result = self._transcribe() + text = result["text"].lower() + # The test audio is a LibriSpeech sample about stew for dinner + self.assertIn("stew", text, f"Expected 'stew' in transcription: {text}") + self.assertIn("dinner", text, f"Expected 'dinner' in transcription: {text}") + print(f"Quality check passed: {result['text'][:80]}...") + + +class 
TestWhisperNoCudaGraph(CustomTestCase): + """Test Whisper with CUDA graph explicitly disabled for comparison.""" + + @classmethod + def setUpClass(cls): + cls.model = WHISPER_MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--served-model-name", + "whisper", + "--disable-cuda-graph", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_basic_transcription_no_cuda_graph(self): + """Test that transcription works without CUDA graph (baseline).""" + audio_bytes = get_audio_bytes() + response = requests.post( + self.base_url + "/v1/audio/transcriptions", + files={"file": ("audio.ogg", io.BytesIO(audio_bytes), "audio/ogg")}, + data={ + "model": "whisper", + "language": "en", + }, + ) + self.assertEqual(response.status_code, 200, response.text) + result = response.json() + self.assertIn("text", result) + self.assertTrue(len(result["text"]) > 0) + print(f"No CUDA graph transcription: {result['text'][:80]}...") + + +if __name__ == "__main__": + unittest.main(verbosity=3) diff --git a/test/registered/4-gpu-models/test_deepseek_v3_cutedsl_4gpu.py b/test/registered/4-gpu-models/test_deepseek_v3_cutedsl_4gpu.py index 7babc15bd078..6f89bdfaacc3 100644 --- a/test/registered/4-gpu-models/test_deepseek_v3_cutedsl_4gpu.py +++ b/test/registered/4-gpu-models/test_deepseek_v3_cutedsl_4gpu.py @@ -4,7 +4,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -72,18 +72,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=512, - parallel=512, - max_new_tokens=512, - 
host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=512, + num_threads=512, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.92) + self.assertGreater(metrics["score"], 0.92) class TestDummyWithSBO(CustomTestCase): @@ -148,15 +148,16 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=512, + num_threads=512, num_shots=0, - data_path=None, - num_questions=512, - parallel=512, - max_new_tokens=16, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") diff --git a/test/registered/4-gpu-models/test_qwen35_models.py b/test/registered/4-gpu-models/test_qwen35_models.py index f088c8242915..562d201f21e6 100644 --- a/test/registered/4-gpu-models/test_qwen35_models.py +++ b/test/registered/4-gpu-models/test_qwen35_models.py @@ -149,7 +149,7 @@ def test_gsm8k(self): print(f"{metrics=}") self.assertGreaterEqual(metrics["score"], ACC_THRESHOLDS[self.model]["gsm8k"]) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -226,7 +226,7 @@ def test_gsm8k(self): print(f"{metrics=}") self.assertGreaterEqual(metrics["score"], ACC_THRESHOLDS[self.model]["gsm8k"]) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git 
a/test/registered/8-gpu-models/test_deepseek_v32_basic.py b/test/registered/8-gpu-models/test_deepseek_v32_basic.py index b1c3c3cf96b9..c3c8430de69f 100644 --- a/test/registered/8-gpu-models/test_deepseek_v32_basic.py +++ b/test/registered/8-gpu-models/test_deepseek_v32_basic.py @@ -3,7 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -49,22 +49,23 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1400, + num_threads=1400, num_shots=20, - data_path=None, - num_questions=1400, - parallel=1400, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( - f"### test_gsm8k (deepseek-v32)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (deepseek-v32)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.935) def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) @@ -106,22 +107,23 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1400, + num_threads=1400, num_shots=20, - data_path=None, - num_questions=1400, - parallel=1400, - max_new_tokens=512, - 
host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( - f"### test_gsm8k (deepseek-v32)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (deepseek-v32)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.935) def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) diff --git a/test/registered/8-gpu-models/test_deepseek_v32_mtp.py b/test/registered/8-gpu-models/test_deepseek_v32_mtp.py index 75498322f264..ef1e34b41a32 100644 --- a/test/registered/8-gpu-models/test_deepseek_v32_mtp.py +++ b/test/registered/8-gpu-models/test_deepseek_v32_mtp.py @@ -6,7 +6,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -64,18 +64,19 @@ def test_a_gsm8k( requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=500, num_shots=20, - data_path=None, - num_questions=500, - parallel=500, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -84,10 +85,10 @@ def test_a_gsm8k( if 
is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v32 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(metrics["score"], 0.94) self.assertGreater(avg_spec_accept_length, 2.7) def test_bs_1_speed(self): @@ -150,18 +151,19 @@ def test_a_gsm8k( requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=500, num_shots=20, - data_path=None, - num_questions=500, - parallel=500, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -170,10 +172,10 @@ def test_a_gsm8k( if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v32 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(metrics["score"], 0.94) self.assertGreater(avg_spec_accept_length, 2.7) def test_bs_1_speed(self): @@ -232,18 +234,19 @@ def test_a_gsm8k( requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=500, num_shots=20, - data_path=None, - num_questions=500, - parallel=500, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = 
requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -252,10 +255,10 @@ def test_a_gsm8k( if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v32 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(metrics["score"], 0.94) self.assertGreater(avg_spec_accept_length, 2.7) def test_bs_1_speed(self): @@ -315,18 +318,19 @@ def test_a_gsm8k( requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=500, num_shots=20, - data_path=None, - num_questions=500, - parallel=500, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -335,10 +339,10 @@ def test_a_gsm8k( if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v32 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(metrics["score"], 0.94) self.assertGreater(avg_spec_accept_length, 2.7) def test_bs_1_speed(self): diff --git a/test/registered/8-gpu-models/test_deepseek_v3_basic.py b/test/registered/8-gpu-models/test_deepseek_v3_basic.py index 08a683f52d6e..acbde5475051 100644 --- a/test/registered/8-gpu-models/test_deepseek_v3_basic.py +++ b/test/registered/8-gpu-models/test_deepseek_v3_basic.py 
@@ -3,7 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -47,22 +47,23 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1400, + num_threads=1400, num_shots=8, - data_path=None, - num_questions=1400, - parallel=1400, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( - f"### test_gsm8k (deepseek-v3)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (deepseek-v3)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.935) def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) diff --git a/test/registered/8-gpu-models/test_deepseek_v3_mtp.py b/test/registered/8-gpu-models/test_deepseek_v3_mtp.py index 31e99ab0819c..48e0847ac2f5 100644 --- a/test/registered/8-gpu-models/test_deepseek_v3_mtp.py +++ b/test/registered/8-gpu-models/test_deepseek_v3_mtp.py @@ -5,7 +5,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -61,18 +61,18 @@ def test_a_gsm8k( 
requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -81,10 +81,10 @@ def test_a_gsm8k( if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v3 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.935) self.assertGreater(avg_spec_accept_length, 2.8) def test_bs_1_speed(self): diff --git a/test/registered/8-gpu-models/test_kimi_k25.py b/test/registered/8-gpu-models/test_kimi_k25.py index 5f903f1020d3..a160e8211822 100644 --- a/test/registered/8-gpu-models/test_kimi_k25.py +++ b/test/registered/8-gpu-models/test_kimi_k25.py @@ -10,19 +10,13 @@ register_cuda_ci(est_time=3600, suite="nightly-8-gpu-common", nightly=True) KIMI_K25_MODEL_PATH = "moonshotai/Kimi-K2.5" -EAGLE3_DRAFT_MODEL_PATH = "AQ-MedAI/Kimi-K25-eagle3" class TestKimiK25(unittest.TestCase): """Unified test class for Kimi-K2.5 performance and accuracy. - Two variants: - - basic: TP=8 + tool/reasoning parsers - - eagle3: TP=8 + EAGLE3 speculative decoding with draft model - - Each variant runs BOTH: - - Performance test (using NightlyBenchmarkRunner) - - Accuracy test (using run_eval with gsm8k) + Runs TP=8 with tool/reasoning parsers. + Runs BOTH performance test and accuracy test (gsm8k). 
""" def test_kimi_k25(self): @@ -31,13 +25,6 @@ def test_kimi_k25(self): "--trust-remote-code", "--tool-call-parser=kimi_k2", "--reasoning-parser=kimi_k2", - ] - eagle3_args = [ - "--speculative-algorithm=EAGLE3", - f"--speculative-draft-model-path={EAGLE3_DRAFT_MODEL_PATH}", - "--speculative-num-steps=3", - "--speculative-eagle-topk=1", - "--speculative-num-draft-tokens=4", "--model-loader-extra-config", '{"enable_multithread_load": true, "num_threads": 64}', ] @@ -57,14 +44,8 @@ def test_kimi_k25(self): ModelLaunchSettings( KIMI_K25_MODEL_PATH, tp_size=8, - extra_args=base_args + eagle3_args, - variant="TP8+MTP", - ), - ModelLaunchSettings( - KIMI_K25_MODEL_PATH, - tp_size=8, - extra_args=base_args + dp_attn_args + eagle3_args, - variant="TP8+DP8+MTP", + extra_args=base_args + dp_attn_args, + variant="TP8+DP8", ), ] diff --git a/test/registered/8-gpu-models/test_mimo_models.py b/test/registered/8-gpu-models/test_mimo_models.py index f9dc8165d223..a32fd3ac5472 100644 --- a/test/registered/8-gpu-models/test_mimo_models.py +++ b/test/registered/8-gpu-models/test_mimo_models.py @@ -11,7 +11,7 @@ class TestMiMoV2Flash(GSM8KMixin, SpecDecodingMixin, DefaultServerBase): gsm8k_accuracy_thres = 0.75 gsm8k_num_questions = 1319 - gsm8k_parallel = 1319 + gsm8k_num_threads = 1319 model = "XiaomiMiMo/MiMo-V2-Flash" other_args = [ diff --git a/test/registered/8-gpu-models/test_ring_2_5_1t.py b/test/registered/8-gpu-models/test_ring_2_5_1t.py index aa211b790a26..c7a1395222d8 100644 --- a/test/registered/8-gpu-models/test_ring_2_5_1t.py +++ b/test/registered/8-gpu-models/test_ring_2_5_1t.py @@ -23,6 +23,8 @@ def test_ring_2_5_1t(self): "--trust-remote-code", "--model-loader-extra-config", '{"enable_multithread_load": true, "num_threads": 64}', + "--watchdog-timeout", + "1800", ] variants = [ diff --git a/test/registered/amd/accuracy/mi30x/test_deepseek_v32_mtp_eval_amd.py b/test/registered/amd/accuracy/mi30x/test_deepseek_v32_mtp_eval_amd.py index 6676ab612085..1ffa71a1f272 
100644 --- a/test/registered/amd/accuracy/mi30x/test_deepseek_v32_mtp_eval_amd.py +++ b/test/registered/amd/accuracy/mi30x/test_deepseek_v32_mtp_eval_amd.py @@ -106,7 +106,7 @@ def test_a_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/amd/accuracy/mi35x/test_deepseek_v32_mtp_eval_mi35x.py b/test/registered/amd/accuracy/mi35x/test_deepseek_v32_mtp_eval_mi35x.py index 09a012043416..dad040a302d7 100644 --- a/test/registered/amd/accuracy/mi35x/test_deepseek_v32_mtp_eval_mi35x.py +++ b/test/registered/amd/accuracy/mi35x/test_deepseek_v32_mtp_eval_mi35x.py @@ -108,7 +108,7 @@ def test_a_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/amd/accuracy/mi35x/test_glm47_fp8_eval_mi35x.py b/test/registered/amd/accuracy/mi35x/test_glm47_fp8_eval_mi35x.py new file mode 100644 index 000000000000..8ce31b900b34 --- /dev/null +++ b/test/registered/amd/accuracy/mi35x/test_glm47_fp8_eval_mi35x.py @@ -0,0 +1,59 @@ +"""MI35x GLM-4.7-FP8 GSM8K Accuracy Evaluation Test (8-GPU) + +Tests GLM-4.7-FP8 accuracy using GSM8K benchmark on MI35x. 
+ +Registry: nightly-amd-8-gpu-mi35x-glm47-fp8 suite +""" + +import os + +# Set HF cache for MI35x +os.environ.setdefault("HF_HOME", "/data2/models/huggingface") +os.environ.setdefault("HF_HUB_CACHE", "/data2/models/huggingface/hub") + +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_amd_ci +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +# Register for AMD CI - MI35x GLM-4.7-FP8 accuracy test (~30 min) +register_amd_ci( + est_time=1800, + suite="nightly-amd-8-gpu-mi35x-glm47-fp8", + nightly=True, +) + +GLM_4_7_FP8_MODEL_PATH = "zai-org/GLM-4.7-FP8" + + +class TestGLM47FP8EvalMI35x(unittest.TestCase): + """GLM-4.7-FP8 GSM8K Accuracy Evaluation Test for MI35x.""" + + def test_glm_47_fp8(self): + """Run accuracy test for GLM-4.7-FP8.""" + base_args = [ + "--trust-remote-code", + "--tool-call-parser=glm47", + "--reasoning-parser=glm45", + ] + + variants = [ + ModelLaunchSettings( + GLM_4_7_FP8_MODEL_PATH, + tp_size=8, + extra_args=base_args, + variant="TP8", + ), + ] + + run_combined_tests( + models=variants, + test_name="GLM-4.7-FP8", + accuracy_params=AccuracyTestParams(dataset="gsm8k", baseline_accuracy=0.92), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/amd/perf/mi30x/test_minimax_m25_perf_amd.py b/test/registered/amd/perf/mi30x/test_minimax_m25_perf_amd.py new file mode 100644 index 000000000000..ace3c8cef649 --- /dev/null +++ b/test/registered/amd/perf/mi30x/test_minimax_m25_perf_amd.py @@ -0,0 +1,140 @@ +"""Nightly performance benchmark for MiniMax-M2.5 on MI325/MI300X (8-GPU). + +This test benchmarks MiniMax-M2.5 with TP=8 + EP=8 configuration. + +The model path can be configured via MINIMAX_M25_MODEL_PATH environment variable. 
+ +Registry: nightly-perf-8-gpu-minimax-m25 suite + +Example usage: + python -m pytest test_minimax_m25_perf_amd.py -v +""" + +import os +import unittest +from typing import List + +from sglang.test.ci.ci_register import register_amd_ci +from sglang.test.nightly_bench_utils import BenchmarkResult +from sglang.test.nightly_utils import NightlyBenchmarkRunner +from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, _parse_int_list_env + +register_amd_ci(est_time=5400, suite="nightly-perf-8-gpu-minimax-m25", nightly=True) + + +def generate_simple_markdown_report(results: List[BenchmarkResult]) -> str: + """Generate a simplified markdown report without traces and cost columns. + + Skips the first result if it's a warmup run (duplicate batch_size). + """ + model_header = results[0].model_path + if results[0].run_name and results[0].run_name != "default": + model_header += f" ({results[0].run_name})" + + gpu_config = os.getenv("GPU_CONFIG", "MI325") + if gpu_config: + model_header += f" [{gpu_config}]" + + summary = f"### {model_header}\n" + summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | ITL (ms) |\n" + summary += "| ---------- | --------- | ----------- | ------------------------ | ------------------------- | -------- |\n" + + report_results = ( + results[1:] + if len(results) > 1 and results[0].batch_size == results[1].batch_size + else results + ) + + for result in report_results: + itl = 1 / (result.output_throughput / result.batch_size) * 1000 + summary += f"| {result.batch_size} | {result.input_len} | {result.latency:.2f} | {result.input_throughput:.2f} | {result.output_throughput:.2f} | {itl:.2f} |\n" + + return summary + + +MINIMAX_M25_MODEL_PATH = os.environ.get( + "MINIMAX_M25_MODEL_PATH", "MiniMaxAI/MiniMax-M2.5" +) +PROFILE_DIR = "performance_profiles_minimax_m25" + + +class TestNightlyMiniMaxM25Performance(unittest.TestCase): + """Nightly performance benchmark for MiniMax-M2.5 on MI325/MI300X. 
+ + Tests MiniMax-M2.5 with TP=8 + EP=8 configuration. + """ + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.batch_sizes = [1, 8, 16, 64] + cls.input_lens = tuple(_parse_int_list_env("NIGHTLY_INPUT_LENS", "4096")) + cls.output_lens = tuple(_parse_int_list_env("NIGHTLY_OUTPUT_LENS", "512")) + + cls.model_config = { + "name": "minimax-m25-tp8-ep8", + "model_path": MINIMAX_M25_MODEL_PATH, + "other_args": [ + "--trust-remote-code", + "--tp", + "8", + "--ep-size", + "8", + "--attention-backend", + "aiter", + "--mem-fraction-static", + "0.85", + "--model-loader-extra-config", + '{"enable_multithread_load": true}', + "--watchdog-timeout", + "1200", + ], + "env_vars": { + "SGLANG_USE_AITER": "1", + }, + } + + cls.runner = NightlyBenchmarkRunner(PROFILE_DIR, cls.__name__, cls.base_url) + cls.runner.setup_profile_directory() + cls.runner.full_report = f"## {cls.__name__}\n" + + def test_bench_minimax_m25(self): + """Run benchmark for MiniMax-M2.5.""" + old_env = {} + for key, value in self.model_config.get("env_vars", {}).items(): + old_env[key] = os.environ.get(key) + os.environ[key] = value + print(f"Setting env: {key}={value}") + + try: + result_tuple = self.runner.run_benchmark_for_model( + model_path=self.model_config["model_path"], + batch_sizes=self.batch_sizes, + input_lens=self.input_lens, + output_lens=self.output_lens, + other_args=self.model_config["other_args"], + variant=self.model_config["name"], + extra_bench_args=["--trust-remote-code"], + enable_profile=False, + timeout=5400, + ) + results = result_tuple[0] + success = result_tuple[1] + + if results: + self.runner.full_report += ( + generate_simple_markdown_report(results) + "\n" + ) + + self.assertTrue(success, "Benchmark failed for MiniMax-M2.5") + finally: + for key, value in old_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + self.runner.write_final_report() + + +if __name__ == "__main__": + unittest.main() diff --git 
a/test/registered/amd/perf/mi35x/test_minimax_m25_perf_mi35x.py b/test/registered/amd/perf/mi35x/test_minimax_m25_perf_mi35x.py new file mode 100644 index 000000000000..963a7d956e40 --- /dev/null +++ b/test/registered/amd/perf/mi35x/test_minimax_m25_perf_mi35x.py @@ -0,0 +1,146 @@ +"""MI35x Nightly performance benchmark for MiniMax-M2.5 (8-GPU). + +This test benchmarks MiniMax-M2.5 with TP=8 + EP=8 configuration on MI35x. + +The model path can be configured via MINIMAX_M25_MODEL_PATH environment variable. + +Registry: nightly-perf-8-gpu-mi35x-minimax-m25 suite + +Example usage: + python -m pytest test_minimax_m25_perf_mi35x.py -v +""" + +import os + +os.environ.setdefault("HF_HOME", "/data2/models/huggingface") +os.environ.setdefault("HF_HUB_CACHE", "/data2/models/huggingface/hub") + +import unittest +from typing import List + +from sglang.test.ci.ci_register import register_amd_ci +from sglang.test.nightly_bench_utils import BenchmarkResult +from sglang.test.nightly_utils import NightlyBenchmarkRunner +from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, _parse_int_list_env + +register_amd_ci( + est_time=5400, suite="nightly-perf-8-gpu-mi35x-minimax-m25", nightly=True +) + + +def generate_simple_markdown_report(results: List[BenchmarkResult]) -> str: + """Generate a simplified markdown report without traces and cost columns. + + Skips the first result if it's a warmup run (duplicate batch_size). 
+ """ + model_header = results[0].model_path + if results[0].run_name and results[0].run_name != "default": + model_header += f" ({results[0].run_name})" + + gpu_config = os.getenv("GPU_CONFIG", "MI35x") + if gpu_config: + model_header += f" [{gpu_config}]" + + summary = f"### {model_header}\n" + summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | ITL (ms) |\n" + summary += "| ---------- | --------- | ----------- | ------------------------ | ------------------------- | -------- |\n" + + report_results = ( + results[1:] + if len(results) > 1 and results[0].batch_size == results[1].batch_size + else results + ) + + for result in report_results: + itl = 1 / (result.output_throughput / result.batch_size) * 1000 + summary += f"| {result.batch_size} | {result.input_len} | {result.latency:.2f} | {result.input_throughput:.2f} | {result.output_throughput:.2f} | {itl:.2f} |\n" + + return summary + + +MINIMAX_M25_MODEL_PATH = os.environ.get( + "MINIMAX_M25_MODEL_PATH", "MiniMaxAI/MiniMax-M2.5" +) +PROFILE_DIR = "performance_profiles_minimax_m25_mi35x" + + +class TestNightlyMiniMaxM25PerformanceMI35x(unittest.TestCase): + """MI35x Nightly performance benchmark for MiniMax-M2.5. + + Tests MiniMax-M2.5 with TP=8 + EP=8 configuration. 
+ """ + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.batch_sizes = [1, 8, 16, 64] + cls.input_lens = tuple(_parse_int_list_env("NIGHTLY_INPUT_LENS", "4096")) + cls.output_lens = tuple(_parse_int_list_env("NIGHTLY_OUTPUT_LENS", "512")) + + cls.model_config = { + "name": "minimax-m25-tp8-ep8", + "model_path": MINIMAX_M25_MODEL_PATH, + "other_args": [ + "--trust-remote-code", + "--tp", + "8", + "--ep-size", + "8", + "--attention-backend", + "aiter", + "--mem-fraction-static", + "0.85", + "--model-loader-extra-config", + '{"enable_multithread_load": true}', + "--watchdog-timeout", + "1200", + ], + "env_vars": { + "SGLANG_USE_AITER": "1", + }, + } + + cls.runner = NightlyBenchmarkRunner(PROFILE_DIR, cls.__name__, cls.base_url) + cls.runner.setup_profile_directory() + cls.runner.full_report = f"## {cls.__name__}\n" + + def test_bench_minimax_m25(self): + """Run benchmark for MiniMax-M2.5.""" + old_env = {} + for key, value in self.model_config.get("env_vars", {}).items(): + old_env[key] = os.environ.get(key) + os.environ[key] = value + print(f"Setting env: {key}={value}") + + try: + result_tuple = self.runner.run_benchmark_for_model( + model_path=self.model_config["model_path"], + batch_sizes=self.batch_sizes, + input_lens=self.input_lens, + output_lens=self.output_lens, + other_args=self.model_config["other_args"], + variant=self.model_config["name"], + extra_bench_args=["--trust-remote-code"], + enable_profile=False, + timeout=5400, + ) + results = result_tuple[0] + success = result_tuple[1] + + if results: + self.runner.full_report += ( + generate_simple_markdown_report(results) + "\n" + ) + + self.assertTrue(success, "Benchmark failed for MiniMax-M2.5 on MI35x") + finally: + for key, value in old_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + self.runner.write_final_report() + + +if __name__ == "__main__": + unittest.main() diff --git 
a/test/registered/amd/test_deepseek_r1_mxfp4_8gpu.py b/test/registered/amd/test_deepseek_r1_mxfp4_8gpu.py index 28249e706b75..04d4f6efb7a7 100644 --- a/test/registered/amd/test_deepseek_r1_mxfp4_8gpu.py +++ b/test/registered/amd/test_deepseek_r1_mxfp4_8gpu.py @@ -135,7 +135,7 @@ def test_a_gsm8k( metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/amd/test_deepseek_v32_mtp.py b/test/registered/amd/test_deepseek_v32_mtp.py index 87e4e6923b38..69587bdf6e05 100644 --- a/test/registered/amd/test_deepseek_v32_mtp.py +++ b/test/registered/amd/test_deepseek_v32_mtp.py @@ -87,7 +87,7 @@ def test_a_gsm8k( metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -179,7 +179,7 @@ def test_a_gsm8k( metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/amd/test_deepseek_v3_mtp.py b/test/registered/amd/test_deepseek_v3_mtp.py index 29190947414b..0a9f94090e93 100644 --- a/test/registered/amd/test_deepseek_v3_mtp.py +++ b/test/registered/amd/test_deepseek_v3_mtp.py @@ -72,7 +72,7 @@ def test_a_gsm8k( metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ 
"avg_spec_accept_length" ] diff --git a/test/registered/amd/test_deepseek_v3_mtp_kv_fp8.py b/test/registered/amd/test_deepseek_v3_mtp_kv_fp8.py index a62eadf7a587..949b743485e6 100644 --- a/test/registered/amd/test_deepseek_v3_mtp_kv_fp8.py +++ b/test/registered/amd/test_deepseek_v3_mtp_kv_fp8.py @@ -76,7 +76,7 @@ def test_a_gsm8k( metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/amd/test_moriep_small.py b/test/registered/amd/test_moriep_small.py index 3eca8ce279e6..76ccb42d63f6 100644 --- a/test/registered/amd/test_moriep_small.py +++ b/test/registered/amd/test_moriep_small.py @@ -145,7 +145,7 @@ def test_gsm8k( print(f"{metrics=}") self.assertGreaterEqual(metrics["accuracy"], 0.92) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -397,7 +397,7 @@ def test_gsm8k( print(f"{metrics=}") self.assertGreaterEqual(metrics["accuracy"], 0.92) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -457,7 +457,7 @@ def test_gsm8k( print(f"{metrics=}") self.assertGreaterEqual(metrics["accuracy"], 0.92) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/amd/test_qwen3_coder_next_8gpu.py b/test/registered/amd/test_qwen3_coder_next_8gpu.py index a8631af6e8c2..b4181273d0d8 100644 --- 
a/test/registered/amd/test_qwen3_coder_next_8gpu.py +++ b/test/registered/amd/test_qwen3_coder_next_8gpu.py @@ -146,7 +146,7 @@ def test_a_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/srt/ascend/test_ascend_hicache_mha.py b/test/registered/ascend/basic_function/HiCache/test_npu_hicache_mha.py similarity index 91% rename from test/srt/ascend/test_ascend_hicache_mha.py rename to test/registered/ascend/basic_function/HiCache/test_npu_hicache_mha.py index 521537e05af7..829a5fee4fbe 100644 --- a/test/srt/ascend/test_ascend_hicache_mha.py +++ b/test/registered/ascend/basic_function/HiCache/test_npu_hicache_mha.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,6 +12,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct": { "accuracy": 0.85, diff --git a/test/srt/ascend/test_ascend_hicache_mla.py b/test/registered/ascend/basic_function/HiCache/test_npu_hicache_mla.py similarity index 91% rename from test/srt/ascend/test_ascend_hicache_mla.py rename to test/registered/ascend/basic_function/HiCache/test_npu_hicache_mla.py index 4bb355d3746b..140d590ddaaa 100644 --- a/test/srt/ascend/test_ascend_hicache_mla.py +++ b/test/registered/ascend/basic_function/HiCache/test_npu_hicache_mla.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse from sglang.srt.utils 
import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,6 +12,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-4-npu-a3", nightly=False) +register_npu_ci(est_time=400, suite="nightly-4-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": { "accuracy": 0.34, diff --git a/test/srt/ascend/test_ascend_sampling_backend.py b/test/registered/ascend/basic_function/backends/test_npu_sampling_backend.py similarity index 92% rename from test/srt/ascend/test_ascend_sampling_backend.py rename to test/registered/ascend/basic_function/backends/test_npu_sampling_backend.py index f0eee21a360a..7da4595f2c2c 100644 --- a/test/srt/ascend/test_ascend_sampling_backend.py +++ b/test/registered/ascend/basic_function/backends/test_npu_sampling_backend.py @@ -4,6 +4,7 @@ import requests from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,6 +13,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + class TestAscendSamplingBackend(CustomTestCase): @classmethod diff --git a/test/srt/ascend/test_llada2_mini_ascend.py b/test/registered/ascend/basic_function/dllm/test_npu_llada2_mini.py similarity index 92% rename from test/srt/ascend/test_llada2_mini_ascend.py rename to test/registered/ascend/basic_function/dllm/test_npu_llada2_mini.py index a1fdafcc54ea..a3c3d137c151 100644 --- a/test/srt/ascend/test_llada2_mini_ascend.py +++ b/test/registered/ascend/basic_function/dllm/test_npu_llada2_mini.py @@ -3,6 +3,7 @@ from types import 
SimpleNamespace from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( @@ -14,6 +15,9 @@ write_github_step_summary, ) +register_npu_ci(est_time=400, suite="stage-b-test-4-npu-a3", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + class TestLLaDA2Mini(CustomTestCase): @classmethod diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/registered/ascend/basic_function/optimization_debug/test_npu_compile_graph_tp1_bf16.py similarity index 92% rename from test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py rename to test/registered/ascend/basic_function/optimization_debug/test_npu_compile_graph_tp1_bf16.py index e3a51499b7c4..2a94826d6715 100644 --- a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py +++ b/test/registered/ascend/basic_function/optimization_debug/test_npu_compile_graph_tp1_bf16.py @@ -4,6 +4,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,6 +13,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct": { "accuracy": 0.84, diff --git a/test/srt/ascend/test_ascend_graph_tp1_bf16.py b/test/registered/ascend/basic_function/optimization_debug/test_npu_graph_tp1_bf16.py similarity index 91% rename from test/srt/ascend/test_ascend_graph_tp1_bf16.py rename to 
test/registered/ascend/basic_function/optimization_debug/test_npu_graph_tp1_bf16.py index 4f8d4b4aa87a..916e3965d987 100644 --- a/test/srt/ascend/test_ascend_graph_tp1_bf16.py +++ b/test/registered/ascend/basic_function/optimization_debug/test_npu_graph_tp1_bf16.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,6 +12,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct": { "accuracy": 0.85, diff --git a/test/srt/ascend/test_ascend_graph_tp2_bf16.py b/test/registered/ascend/basic_function/optimization_debug/test_npu_graph_tp2_bf16.py similarity index 91% rename from test/srt/ascend/test_ascend_graph_tp2_bf16.py rename to test/registered/ascend/basic_function/optimization_debug/test_npu_graph_tp2_bf16.py index d4bf902a20a5..37bb7fd22a71 100644 --- a/test/srt/ascend/test_ascend_graph_tp2_bf16.py +++ b/test/registered/ascend/basic_function/optimization_debug/test_npu_graph_tp2_bf16.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,6 +12,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-2-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-2-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct": { "accuracy": 0.85, diff --git 
a/test/srt/ascend/test_ascend_piecewise_graph_prefill.py b/test/registered/ascend/basic_function/optimization_debug/test_npu_piecewise_graph_prefill.py similarity index 92% rename from test/srt/ascend/test_ascend_piecewise_graph_prefill.py rename to test/registered/ascend/basic_function/optimization_debug/test_npu_piecewise_graph_prefill.py index 939486582891..6db3aeb87a44 100644 --- a/test/srt/ascend/test_ascend_piecewise_graph_prefill.py +++ b/test/registered/ascend/basic_function/optimization_debug/test_npu_piecewise_graph_prefill.py @@ -2,6 +2,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,6 +13,9 @@ run_bench_one_batch, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + MODEL = "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct" GSM8K_EXP_ACCURACY = 0.84 EXP_PREFILL_LATENCY = 0.045 diff --git a/test/srt/ascend/test_ascend_deepep.py b/test/registered/ascend/basic_function/parallel_strategy/expert_parallelism/test_npu_deepep.py similarity index 93% rename from test/srt/ascend/test_ascend_deepep.py rename to test/registered/ascend/basic_function/parallel_strategy/expert_parallelism/test_npu_deepep.py index 19330b862751..209cdcf956b9 100644 --- a/test/srt/ascend/test_ascend_deepep.py +++ b/test/registered/ascend/basic_function/parallel_strategy/expert_parallelism/test_npu_deepep.py @@ -4,6 +4,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, @@ -11,6 +12,9 @@ popen_launch_server, ) 
+register_npu_ci(est_time=400, suite="stage-b-test-16-npu-a3", nightly=False) +register_npu_ci(est_time=400, suite="nightly-16-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8": { "accuracy": 0.95, diff --git a/test/registered/ascend/basic_function/parameter/test_npu_warmups.py b/test/registered/ascend/basic_function/parameter/test_npu_warmups.py index 7b1df16af9db..da678037402b 100644 --- a/test/registered/ascend/basic_function/parameter/test_npu_warmups.py +++ b/test/registered/ascend/basic_function/parameter/test_npu_warmups.py @@ -64,7 +64,7 @@ def tearDownClass(cls): def test_warmups_with_voice_chat(self): # Call the get_server_info API to verify that the warmups parameter configuration takes effect. - response = requests.get(f"{DEFAULT_URL_FOR_TEST}/get_server_info") + response = requests.get(f"{DEFAULT_URL_FOR_TEST}/server_info") self.assertEqual(response.status_code, 200) self.assertEqual("voice_chat", response.json().get("warmups")) diff --git a/test/srt/ascend/test_ascend_autoround_dense.py b/test/registered/ascend/basic_function/quant/test_npu_autoround_dense.py similarity index 86% rename from test/srt/ascend/test_ascend_autoround_dense.py rename to test/registered/ascend/basic_function/quant/test_npu_autoround_dense.py index 475311b6cadb..87575681b85a 100644 --- a/test/srt/ascend/test_ascend_autoround_dense.py +++ b/test/registered/ascend/basic_function/quant/test_npu_autoround_dense.py @@ -4,6 +4,8 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ascend.test_ascend_utils import QWEN3_8B_INT4_AUTOROUND_WEIGHTS_PATH +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,10 +14,13 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) 
+register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + logger = logging.getLogger(__name__) TEST_MODEL_MATRIX = { - "/root/.cache/modelscope/hub/models/Intel/Qwen3-8B-int4-AutoRound": { + QWEN3_8B_INT4_AUTOROUND_WEIGHTS_PATH: { "accuracy": 0.85, }, } diff --git a/test/srt/ascend/test_ascend_autoround_moe.py b/test/registered/ascend/basic_function/quant/test_npu_autoround_moe.py similarity index 85% rename from test/srt/ascend/test_ascend_autoround_moe.py rename to test/registered/ascend/basic_function/quant/test_npu_autoround_moe.py index b0b8f6960a07..1864ec6ee646 100644 --- a/test/srt/ascend/test_ascend_autoround_moe.py +++ b/test/registered/ascend/basic_function/quant/test_npu_autoround_moe.py @@ -4,6 +4,10 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ascend.test_ascend_utils import ( + QWEN3_30B_A3B_INSTRUCT_2507_INT4_AUTOROUND_WEIGHTS_PATH, +) +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,10 +16,13 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + logger = logging.getLogger(__name__) TEST_MODEL_MATRIX = { - "/root/.cache/modelscope/hub/models/Intel/Qwen3-30B-A3B-Instruct-2507-int4-AutoRound": { + QWEN3_30B_A3B_INSTRUCT_2507_INT4_AUTOROUND_WEIGHTS_PATH: { "accuracy": 0.85, }, } diff --git a/test/srt/ascend/test_ascend_gptq_moe.py b/test/registered/ascend/basic_function/quant/test_npu_gptq_moe.py similarity index 86% rename from test/srt/ascend/test_ascend_gptq_moe.py rename to test/registered/ascend/basic_function/quant/test_npu_gptq_moe.py index 22b9543795ff..686f5daa1690 100644 --- a/test/srt/ascend/test_ascend_gptq_moe.py +++ b/test/registered/ascend/basic_function/quant/test_npu_gptq_moe.py @@ -4,6 
+4,10 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ascend.test_ascend_utils import ( + QWEN3_30B_A3B_GPTQ_2507_INT4_WEIGHTS_PATH, +) +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,10 +16,13 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + logger = logging.getLogger(__name__) TEST_MODEL_MATRIX = { - "/root/.cache/modelscope/hub/models/Qwen/Qwen3-30B-A3B-GPTQ-Int4": { + QWEN3_30B_A3B_GPTQ_2507_INT4_WEIGHTS_PATH: { "accuracy": 0.85, }, } diff --git a/test/srt/ascend/test_ascend_w4a4_quantization.py b/test/registered/ascend/basic_function/quant/test_npu_w4a4_quantization.py similarity index 93% rename from test/srt/ascend/test_ascend_w4a4_quantization.py rename to test/registered/ascend/basic_function/quant/test_npu_w4a4_quantization.py index 22d3f0615181..e395ec4c8b8e 100644 --- a/test/srt/ascend/test_ascend_w4a4_quantization.py +++ b/test/registered/ascend/basic_function/quant/test_npu_w4a4_quantization.py @@ -12,6 +12,7 @@ import requests from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -21,6 +22,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-4-npu-a3", nightly=False) +register_npu_ci(est_time=400, suite="nightly-4-npu-a3", nightly=True) + if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1,2,3" DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( diff --git a/test/srt/ascend/test_ascend_w8a8_quantization.py b/test/registered/ascend/basic_function/quant/test_npu_w8a8_quantization.py 
similarity index 93% rename from test/srt/ascend/test_ascend_w8a8_quantization.py rename to test/registered/ascend/basic_function/quant/test_npu_w8a8_quantization.py index e0b3545701c6..96bea7efb1af 100644 --- a/test/srt/ascend/test_ascend_w8a8_quantization.py +++ b/test/registered/ascend/basic_function/quant/test_npu_w8a8_quantization.py @@ -12,6 +12,7 @@ import requests from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -21,6 +22,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( diff --git a/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py b/test/registered/ascend/basic_function/runtime_opts/test_npu_mla_fia_w8a8int8.py similarity index 92% rename from test/srt/ascend/test_ascend_mla_fia_w8a8int8.py rename to test/registered/ascend/basic_function/runtime_opts/test_npu_mla_fia_w8a8int8.py index 4001df6f69b5..0b49028379dc 100644 --- a/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py +++ b/test/registered/ascend/basic_function/runtime_opts/test_npu_mla_fia_w8a8int8.py @@ -4,6 +4,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,6 +13,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-2-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-2-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { 
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": { "accuracy": 0.34, diff --git a/test/srt/ascend/test_ascend_mla_w8a8int8.py b/test/registered/ascend/basic_function/runtime_opts/test_npu_mla_w8a8int8.py similarity index 91% rename from test/srt/ascend/test_ascend_mla_w8a8int8.py rename to test/registered/ascend/basic_function/runtime_opts/test_npu_mla_w8a8int8.py index 177af099a77a..c50bee071d38 100644 --- a/test/srt/ascend/test_ascend_mla_w8a8int8.py +++ b/test/registered/ascend/basic_function/runtime_opts/test_npu_mla_w8a8int8.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,6 +12,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-4-npu-a3", nightly=False) +register_npu_ci(est_time=400, suite="nightly-4-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": { "accuracy": 0.34, diff --git a/test/srt/ascend/test_ascend_tp1_bf16.py b/test/registered/ascend/basic_function/runtime_opts/test_npu_tp1_bf16.py similarity index 91% rename from test/srt/ascend/test_ascend_tp1_bf16.py rename to test/registered/ascend/basic_function/runtime_opts/test_npu_tp1_bf16.py index abc9609530d3..b01510dc7857 100644 --- a/test/srt/ascend/test_ascend_tp1_bf16.py +++ b/test/registered/ascend/basic_function/runtime_opts/test_npu_tp1_bf16.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,6 +12,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, 
suite="stage-b-test-1-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct": { "accuracy": 0.84, diff --git a/test/srt/ascend/test_ascend_tp2_bf16.py b/test/registered/ascend/basic_function/runtime_opts/test_npu_tp2_bf16.py similarity index 91% rename from test/srt/ascend/test_ascend_tp2_bf16.py rename to test/registered/ascend/basic_function/runtime_opts/test_npu_tp2_bf16.py index e1f736e9bee2..8f85a16c019f 100644 --- a/test/srt/ascend/test_ascend_tp2_bf16.py +++ b/test/registered/ascend/basic_function/runtime_opts/test_npu_tp2_bf16.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,6 +12,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-2-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-2-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct": { "accuracy": 0.85, diff --git a/test/srt/ascend/test_ascend_tp2_fia_bf16.py b/test/registered/ascend/basic_function/runtime_opts/test_npu_tp2_fia_bf16.py similarity index 92% rename from test/srt/ascend/test_ascend_tp2_fia_bf16.py rename to test/registered/ascend/basic_function/runtime_opts/test_npu_tp2_fia_bf16.py index 5f82bb47f4db..54f3db7d7e2a 100644 --- a/test/srt/ascend/test_ascend_tp2_fia_bf16.py +++ b/test/registered/ascend/basic_function/runtime_opts/test_npu_tp2_fia_bf16.py @@ -4,6 +4,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import 
( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,6 +13,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-2-npu-a2", nightly=False) +register_npu_ci(est_time=400, suite="nightly-2-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct": { "accuracy": 0.85, diff --git a/test/srt/ascend/test_ascend_tp4_bf16.py b/test/registered/ascend/basic_function/runtime_opts/test_npu_tp4_bf16.py similarity index 91% rename from test/srt/ascend/test_ascend_tp4_bf16.py rename to test/registered/ascend/basic_function/runtime_opts/test_npu_tp4_bf16.py index 79efa2445374..85873ad7b8fd 100644 --- a/test/srt/ascend/test_ascend_tp4_bf16.py +++ b/test/registered/ascend/basic_function/runtime_opts/test_npu_tp4_bf16.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_npu_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, @@ -10,6 +11,9 @@ popen_launch_server, ) +register_npu_ci(est_time=400, suite="stage-b-test-4-npu-a3", nightly=False) +register_npu_ci(est_time=400, suite="nightly-4-npu-a3", nightly=True) + TEST_MODEL_MATRIX = { "Qwen/Qwen3-30B-A3B-Instruct-2507": { "accuracy": 0.90, diff --git a/test/registered/ascend/embedding_models/test_npu_bge_large_en_v1_5.py b/test/registered/ascend/embedding_models/test_npu_bge_large_en_v1_5.py index 8e2869d70544..5da11e7a4552 100644 --- a/test/registered/ascend/embedding_models/test_npu_bge_large_en_v1_5.py +++ b/test/registered/ascend/embedding_models/test_npu_bge_large_en_v1_5.py @@ -11,7 +11,7 @@ register_npu_ci( est_time=400, - suite="nightly-1-npu-a3", + suite="full-1-npu-a3", nightly=True, disabled="embeddings are not all close", ) diff --git a/test/registered/ascend/llm_models/test_npu_afm_4_5b.py b/test/registered/ascend/llm_models/test_npu_afm_4_5b.py index ce905093feea..9f83350f9245 
100644 --- a/test/registered/ascend/llm_models/test_npu_afm_4_5b.py +++ b/test/registered/ascend/llm_models/test_npu_afm_4_5b.py @@ -5,7 +5,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=True) class TestAFM(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/llm_models/test_npu_c4ai_command_r_v01.py b/test/registered/ascend/llm_models/test_npu_c4ai_command_r_v01.py index c4ee782b46bd..150572c4d6c1 100644 --- a/test/registered/ascend/llm_models/test_npu_c4ai_command_r_v01.py +++ b/test/registered/ascend/llm_models/test_npu_c4ai_command_r_v01.py @@ -8,7 +8,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-2-npu-a3", nightly=False) +register_npu_ci(est_time=400, suite="full-2-npu-a3", nightly=False) class TestC4AI(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/llm_models/test_npu_exaone_3.py b/test/registered/ascend/llm_models/test_npu_exaone_3.py index 23e72d4cd943..ed676dd5b5a8 100644 --- a/test/registered/ascend/llm_models/test_npu_exaone_3.py +++ b/test/registered/ascend/llm_models/test_npu_exaone_3.py @@ -5,7 +5,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=False) +register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=False) class TestEXAONE(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/llm_models/test_npu_granite_3_0_3b_a800m.py b/test/registered/ascend/llm_models/test_npu_granite_3_0_3b_a800m.py index 00d3b2a6cc54..9552e3ad9bee 100644 --- a/test/registered/ascend/llm_models/test_npu_granite_3_0_3b_a800m.py +++ 
b/test/registered/ascend/llm_models/test_npu_granite_3_0_3b_a800m.py @@ -7,7 +7,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=True) class TestGranite(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/llm_models/test_npu_granite_3_1_8b.py b/test/registered/ascend/llm_models/test_npu_granite_3_1_8b.py index ac665572a515..1ef751dd01cc 100644 --- a/test/registered/ascend/llm_models/test_npu_granite_3_1_8b.py +++ b/test/registered/ascend/llm_models/test_npu_granite_3_1_8b.py @@ -5,7 +5,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=True) class TestGranite(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/llm_models/test_npu_grok_2.py b/test/registered/ascend/llm_models/test_npu_grok_2.py index 3eff4c19480d..0f75ecc6b375 100644 --- a/test/registered/ascend/llm_models/test_npu_grok_2.py +++ b/test/registered/ascend/llm_models/test_npu_grok_2.py @@ -6,7 +6,7 @@ register_npu_ci( est_time=400, - suite="nightly-16-npu-a3", + suite="full-16-npu-a3", nightly=False, disabled="https://github.com/Ascend/sglang/issues/25", ) diff --git a/test/registered/ascend/llm_models/test_npu_ling_lite.py b/test/registered/ascend/llm_models/test_npu_ling_lite.py index 0dc2a7809e32..2f24064e8fc9 100644 --- a/test/registered/ascend/llm_models/test_npu_ling_lite.py +++ b/test/registered/ascend/llm_models/test_npu_ling_lite.py @@ -5,7 +5,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-2-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-2-npu-a3", 
nightly=True) class TestLingLite(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/llm_models/test_npu_mimo_7b_rl.py b/test/registered/ascend/llm_models/test_npu_mimo_7b_rl.py index 2fe9f802b891..e4d354b8bc1e 100644 --- a/test/registered/ascend/llm_models/test_npu_mimo_7b_rl.py +++ b/test/registered/ascend/llm_models/test_npu_mimo_7b_rl.py @@ -5,7 +5,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=True) class TestMiMo7BRL(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/llm_models/test_npu_persimmon_8b_chat.py b/test/registered/ascend/llm_models/test_npu_persimmon_8b_chat.py index 9958edc25e4a..4a4a95782e73 100644 --- a/test/registered/ascend/llm_models/test_npu_persimmon_8b_chat.py +++ b/test/registered/ascend/llm_models/test_npu_persimmon_8b_chat.py @@ -8,7 +8,7 @@ register_npu_ci( est_time=400, - suite="nightly-1-npu-a3", + suite="full-1-npu-a3", nightly=False, ) diff --git a/test/registered/ascend/llm_models/test_npu_smollm_1_7b.py b/test/registered/ascend/llm_models/test_npu_smollm_1_7b.py index cfe3722f73ee..dbe1bcc1c185 100644 --- a/test/registered/ascend/llm_models/test_npu_smollm_1_7b.py +++ b/test/registered/ascend/llm_models/test_npu_smollm_1_7b.py @@ -5,7 +5,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=True) class TestSmolLM(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/llm_models/test_npu_stablelm_2_1_6b.py b/test/registered/ascend/llm_models/test_npu_stablelm_2_1_6b.py index 07c71b07006b..f61f3ac88f48 100644 --- a/test/registered/ascend/llm_models/test_npu_stablelm_2_1_6b.py +++ 
b/test/registered/ascend/llm_models/test_npu_stablelm_2_1_6b.py @@ -5,7 +5,7 @@ from sglang.test.ci.ci_register import register_npu_ci from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=True) class TestStablelm(GSM8KAscendMixin, CustomTestCase): diff --git a/test/registered/ascend/reward_models/test_npu_gemma_2_27b_v0_2.py b/test/registered/ascend/reward_models/test_npu_gemma_2_27b_v0_2.py index 16772b0ff750..086a823422da 100644 --- a/test/registered/ascend/reward_models/test_npu_gemma_2_27b_v0_2.py +++ b/test/registered/ascend/reward_models/test_npu_gemma_2_27b_v0_2.py @@ -10,7 +10,7 @@ from sglang.test.test_utils import CustomTestCase logger = logging.getLogger(__name__) -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=True) MODELS = [ ( diff --git a/test/registered/ascend/reward_models/test_npu_internlm2_7b_reward.py b/test/registered/ascend/reward_models/test_npu_internlm2_7b_reward.py index 7386e4df9471..a0877d30ff69 100644 --- a/test/registered/ascend/reward_models/test_npu_internlm2_7b_reward.py +++ b/test/registered/ascend/reward_models/test_npu_internlm2_7b_reward.py @@ -13,8 +13,8 @@ register_npu_ci( est_time=400, - suite="nightly-4-npu-a3", - nightly=False, + suite="full-4-npu-a3", + nightly=True, ) PROMPT = ( diff --git a/test/registered/ascend/reward_models/test_npu_llama_3_1_8b_v0_2.py b/test/registered/ascend/reward_models/test_npu_llama_3_1_8b_v0_2.py index 2f23aaf59102..c702e8d0b179 100644 --- a/test/registered/ascend/reward_models/test_npu_llama_3_1_8b_v0_2.py +++ b/test/registered/ascend/reward_models/test_npu_llama_3_1_8b_v0_2.py @@ -10,7 +10,7 @@ from sglang.test.runners import HFRunner, SRTRunner from sglang.test.test_utils import CustomTestCase -register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=False) 
+register_npu_ci(est_time=400, suite="full-1-npu-a3", nightly=True) MODELS = [ (SKYWORK_REWARD_LLAMA_3_1_8B_V0_2_WEIGHTS_PATH, 1, 4e-2), diff --git a/test/registered/ascend/vlm_models/test_ascend_glm_4_5v.py b/test/registered/ascend/vlm_models/test_ascend_glm_4_5v.py new file mode 100644 index 000000000000..9b7ba83ea56f --- /dev/null +++ b/test/registered/ascend/vlm_models/test_ascend_glm_4_5v.py @@ -0,0 +1,33 @@ +import unittest + +from sglang.test.ascend.vlm_utils import TestVLMModels +from sglang.test.ci.ci_register import register_npu_ci + +register_npu_ci(est_time=400, suite="nightly-8-npu-a3", nightly=True) + + +class TestGLM4Models(TestVLMModels): + model = "/root/.cache/modelscope/hub/models/ZhipuAI/GLM-4.5V" + mmmu_accuracy = 0.2 + other_args = [ + "--trust-remote-code", + "--cuda-graph-max-bs", + "32", + "--enable-multimodal", + "--mem-fraction-static", + 0.7, + "--log-level", + "info", + "--attention-backend", + "ascend", + "--disable-cuda-graph", + "--tp-size", + 8, + ] + + def test_vlm_mmmu_benchmark(self): + self._run_vlm_mmmu_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/ascend/vlm_models/test_npu_qwen3_vl_4b_instruct.py b/test/registered/ascend/vlm_models/test_npu_qwen3_vl_4b_instruct.py index 802b99d93cdf..33f5b447c5bc 100644 --- a/test/registered/ascend/vlm_models/test_npu_qwen3_vl_4b_instruct.py +++ b/test/registered/ascend/vlm_models/test_npu_qwen3_vl_4b_instruct.py @@ -4,7 +4,7 @@ from sglang.test.ascend.vlm_utils import TestVLMModels from sglang.test.ci.ci_register import register_npu_ci -register_npu_ci(est_time=400, suite="nightly-4-npu-a3", nightly=True) +register_npu_ci(est_time=400, suite="full-4-npu-a3", nightly=True) class TestQwen3VL4B(TestVLMModels): diff --git a/test/registered/attention/test_chunk_gated_delta_rule.py b/test/registered/attention/test_chunk_gated_delta_rule.py index e5d53c32a6b6..5daeb6a1bb0a 100644 --- a/test/registered/attention/test_chunk_gated_delta_rule.py +++ 
b/test/registered/attention/test_chunk_gated_delta_rule.py @@ -8,7 +8,7 @@ ) from sglang.test.ci.ci_register import register_cuda_ci -register_cuda_ci(est_time=60, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=8, suite="stage-b-test-1-gpu-large") @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") diff --git a/test/registered/attention/test_fa3.py b/test/registered/attention/test_fa3.py index 56969b9c806e..472a2099ae3f 100644 --- a/test/registered/attention/test_fa3.py +++ b/test/registered/attention/test_fa3.py @@ -6,7 +6,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_DRAFT_MODEL_EAGLE3, DEFAULT_MODEL_NAME_FOR_TEST, @@ -36,7 +36,6 @@ GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl", } - if OFFLINE_MODE: DEFAULT_MODEL_NAME_FOR_TEST = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST] DEFAULT_DRAFT_MODEL_EAGLE3 = OFFLINE_PATH_DICT[DEFAULT_DRAFT_MODEL_EAGLE3] @@ -46,7 +45,6 @@ ] GSM_DATASET_PATH = OFFLINE_PATH_DICT[GSM_DATASET_PATH] - # Default server arguments shared across all tests DEFAULT_SERVER_ARGS = [ "--trust-remote-code", @@ -99,19 +97,21 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=100, + num_threads=128, num_shots=4, - num_questions=100, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - data_path=GSM_DATASET_PATH, + gsm8k_data_path=GSM_DATASET_PATH, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") # Use the appropriate metric key based on the test class - metric_key = "accuracy" + 
metric_key = "score" self.assertGreater(metrics[metric_key], self.accuracy_threshold) if self.speculative_decode: diff --git a/test/registered/attention/test_flash_attention_4.py b/test/registered/attention/test_flash_attention_4.py index 656309820599..3c9c4242bbf1 100644 --- a/test/registered/attention/test_flash_attention_4.py +++ b/test/registered/attention/test_flash_attention_4.py @@ -4,7 +4,7 @@ from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -42,18 +42,18 @@ def tearDownClass(cls): def test_gsm8k(self): parsed_url = urlparse(self.base_url) args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=200, - host=f"{parsed_url.scheme}://{parsed_url.hostname}", - port=parsed_url.port, + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1319, + num_threads=200, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.89) + self.assertGreater(metrics["score"], 0.89) if __name__ == "__main__": diff --git a/test/registered/attention/test_hybrid_attn_backend.py b/test/registered/attention/test_hybrid_attn_backend.py index 1c70ee03217d..cb41c0cbdb0a 100644 --- a/test/registered/attention/test_hybrid_attn_backend.py +++ b/test/registered/attention/test_hybrid_attn_backend.py @@ -6,7 +6,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( 
DEFAULT_DRAFT_MODEL_EAGLE, DEFAULT_MODEL_NAME_FOR_TEST, @@ -20,7 +20,7 @@ # Hybrid attention backend tests (FA3 prefill + FlashInfer decode, requires SM 90+ / H100) # Multiple test classes: base, MLA, TorchCompile, SpecDecode variants -register_cuda_ci(est_time=200, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=350, suite="stage-b-test-1-gpu-large") GSM_DATASET_PATH = None @@ -76,24 +76,23 @@ def tearDownClass(cls): def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") + model = DEFAULT_TARGET_MODEL_EAGLE if self.speculative_decode else self.model args = SimpleNamespace( - num_shots=4, - num_questions=100, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - data_path=GSM_DATASET_PATH, + base_url=self.base_url, + model=model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=100, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - # Use the appropriate metric key based on the test class - metric_key = "accuracy" - self.assertGreater(metrics[metric_key], self.accuracy_threshold) + self.assertGreater(metrics["score"], self.accuracy_threshold) if self.speculative_decode: - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/attention/test_local_attn.py b/test/registered/attention/test_local_attn.py index d1abe1b75a91..229260c19778 100644 --- a/test/registered/attention/test_local_attn.py +++ b/test/registered/attention/test_local_attn.py @@ -6,7 +6,7 @@ from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from 
sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -56,19 +56,20 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=100, + num_threads=128, num_shots=4, - num_questions=100, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - data_path=None, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") # Use the appropriate metric key based on the test class - metric_key = "accuracy" + metric_key = "score" self.assertGreater(metrics[metric_key], self.accuracy_threshold) diff --git a/test/registered/backends/test_deepseek_r1_fp8_trtllm_backend.py b/test/registered/backends/test_deepseek_r1_fp8_trtllm_backend.py index b822bf3a48a1..74e89060ff4e 100644 --- a/test/registered/backends/test_deepseek_r1_fp8_trtllm_backend.py +++ b/test/registered/backends/test_deepseek_r1_fp8_trtllm_backend.py @@ -3,7 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, CustomTestCase, @@ -72,18 +72,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=512, - parallel=512, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=512, + num_threads=512, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.92) + 
self.assertGreater(metrics["score"], 0.92) if __name__ == "__main__": diff --git a/test/registered/backends/test_deepseek_v3_fp4_cutlass_moe.py b/test/registered/backends/test_deepseek_v3_fp4_cutlass_moe.py index c3a509efa68a..b547409c1814 100644 --- a/test/registered/backends/test_deepseek_v3_fp4_cutlass_moe.py +++ b/test/registered/backends/test_deepseek_v3_fp4_cutlass_moe.py @@ -3,7 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, CustomTestCase, @@ -52,23 +52,24 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1319, + num_threads=1319, num_shots=8, - data_path=None, - num_questions=1319, - parallel=1319, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v3-fp4-cutlass-moe)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.935) if __name__ == "__main__": diff --git a/test/registered/backends/test_flashinfer_trtllm_gen_attn_backend.py b/test/registered/backends/test_flashinfer_trtllm_gen_attn_backend.py index 42164bc2bfca..11aed30fa79a 100644 --- a/test/registered/backends/test_flashinfer_trtllm_gen_attn_backend.py +++ b/test/registered/backends/test_flashinfer_trtllm_gen_attn_backend.py @@ -4,7 +4,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from 
sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -48,17 +48,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) if __name__ == "__main__": diff --git a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py index b8e76570ca6b..b63447a60cd5 100644 --- a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py +++ b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py @@ -4,7 +4,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -49,17 +49,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) class FlashinferTrtllmGenMoeBackendBF16Base: @@ -97,17 +97,17 @@ def 
tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) class FlashinferTrtllmGenMoeBackendMXFP8Base: @@ -144,17 +144,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) class TestFlashinferTrtllmGenMoeBackendFP8( diff --git a/test/registered/backends/test_qwen3_fp4_trtllm_gen_moe.py b/test/registered/backends/test_qwen3_fp4_trtllm_gen_moe.py index f215af49b413..4011d3b2a796 100644 --- a/test/registered/backends/test_qwen3_fp4_trtllm_gen_moe.py +++ b/test/registered/backends/test_qwen3_fp4_trtllm_gen_moe.py @@ -3,7 +3,7 @@ from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -51,17 +51,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1319, + num_threads=1319, 
num_shots=8, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=1319, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.88) + self.assertGreater(metrics["score"], 0.88) if __name__ == "__main__": diff --git a/test/registered/bench_fn/test_bench_serving_functionality.py b/test/registered/bench_fn/test_bench_serving_functionality.py index 564fdee29252..f573319c7e84 100644 --- a/test/registered/bench_fn/test_bench_serving_functionality.py +++ b/test/registered/bench_fn/test_bench_serving_functionality.py @@ -8,6 +8,7 @@ from sglang.bench_serving import run_benchmark from sglang.benchmark.utils import parse_custom_headers +from sglang.srt.constants import HEALTH_CHECK_RID_PREFIX from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.test_utils import ( @@ -80,7 +81,7 @@ def _verify_multi_turn_logs(self, content: str): continue text = obj.get("obj", {}).get("text") rid = obj.get("rid", "") - if text and not rid.startswith("HEALTH_CHECK"): + if text and not rid.startswith(HEALTH_CHECK_RID_PREFIX): reqs.append(text) self.assertGreaterEqual(len(reqs), NUM_CONVERSATIONS * NUM_TURNS) diff --git a/test/registered/core/test_srt_endpoint.py b/test/registered/core/test_srt_endpoint.py index 46c5853a674e..22e3b468fff7 100644 --- a/test/registered/core/test_srt_endpoint.py +++ b/test/registered/core/test_srt_endpoint.py @@ -500,7 +500,7 @@ def send_and_check_cached_tokens(input_ids): self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000) def test_get_server_info(self): - response = requests.get(self.base_url + "/get_server_info") + response = requests.get(self.base_url + "/server_info") response_json = response.json() max_total_num_tokens = response_json["max_total_num_tokens"] @@ -630,7 +630,7 @@ def test_get_server_info_concurrent(self): tp = 
ThreadPoolExecutor(max_workers=30) def s(): - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") server_info.json() futures = [] diff --git a/test/registered/cp/test_deepseek_v32_cp_single_node.py b/test/registered/cp/test_deepseek_v32_cp_single_node.py index 8b3255777df5..55595eff5809 100644 --- a/test/registered/cp/test_deepseek_v32_cp_single_node.py +++ b/test/registered/cp/test_deepseek_v32_cp_single_node.py @@ -4,7 +4,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -14,7 +14,7 @@ write_github_step_summary, ) -register_cuda_ci(est_time=360, suite="stage-c-test-8-gpu-h200") +register_cuda_ci(est_time=640, suite="stage-c-test-deepep-8-gpu-h200") DEEPSEEK_V32_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2" @@ -68,23 +68,24 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=32, num_shots=20, - data_path=None, - num_questions=500, - parallel=32, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( f"### test_a_gsm8k (deepseek-v32-cp-in-seq-split)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.935) class TestDeepseekV32CPRoundRobinSplit(CustomTestCase): @@ -134,23 +135,24 @@ def test_a_gsm8k( 
self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=32, num_shots=20, - data_path=None, - num_questions=500, - parallel=32, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( f"### test_a_gsm8k (deepseek-v32-cp-in-seq-split)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.935) if __name__ == "__main__": diff --git a/test/registered/debug_utils/test_engine_dumper_comparator_e2e.py b/test/registered/debug_utils/test_engine_dumper_comparator_e2e.py index b16c139829bb..98ef9d7a09b9 100644 --- a/test/registered/debug_utils/test_engine_dumper_comparator_e2e.py +++ b/test/registered/debug_utils/test_engine_dumper_comparator_e2e.py @@ -120,7 +120,9 @@ # All sub-axes (attn_tp, moe_tp, attn_dp) are uniquely determined by tp_rank, # so only tp:replicated is needed — sub-axes are auto-resolved as implicitly replicated. # - # Attn tensors are NOT TP-sharded, mlp_output is already all-reduced. + # Attn tensors are NOT TP-sharded (attn_tp_size=1). + # mlp_output is still moe_tp:partial — the reduce-scatter happens in + # postprocess_layer(), after the dump point. # layer_input is dumped after prepare_attn which DP-distributes tokens, # so it needs dp:=attn_dp to filter to the non-empty DP rank. 
- target: sglang.srt.models.qwen3_moe.Qwen3MoeDecoderLayer.forward @@ -152,7 +154,7 @@ hidden_states = self.mlp( hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter ) - append: "dumper.dump('mlp_output', hidden_states, dims='t h # tp:replicated')" + append: "dumper.dump('mlp_output', hidden_states, dims='t h[moe_tp:partial] # tp:replicated')" # --- attention internals --- - target: sglang.srt.models.qwen3_moe.Qwen3MoeAttention.forward_core @@ -184,14 +186,33 @@ def test_dp_attention(self, tmp_path: Path) -> None: """TP=2 baseline vs TP=2+DP=2+dp-attention target. In dp-attention mode (attn_tp_size=1, attn_dp_size=2), attention - tensors are NOT TP-sharded and mlp_output is already all-reduced. - A separate patch config with corrected dims is used for the target. + tensors are NOT TP-sharded and mlp_output is still moe_tp:partial + (the reduce-scatter happens in postprocess_layer, after the dump + point). A separate patch config with corrected dims is used for + the target. + + Comparison is limited to step 0 (prefill) because the decode + step has tokens on both DP ranks, which breaks the dp:=attn_dp + single-rank assumption and causes comparator errors. + + mlp_output is allowed to fail because the FusedMoE dispatcher + combine path may include an implicit all-reduce that makes the + dumped value differ from the raw partial expert output. All + other tensors (layer_input, attn_output, attn_pre_o_proj, + pre_mlp_residual, moe_router_logits, moe_expert_output) must + pass at step 0. 
""" _run_e2e_scenario( tmp_path=tmp_path, target_tp=BASELINE_TP, extra_target_server_args=["--dp", "2", "--enable-dp-attention"], target_patch_config_yaml=PATCH_CONFIG_DP_ATTENTION_YAML, + extra_comparator_args=[ + "--end-step", + "0", + "--allow-failed-pattern", + "mlp_output", + ], ) @@ -204,6 +225,7 @@ def _run_e2e_scenario( target_tp: int, extra_target_server_args: Optional[list[str]] = None, target_patch_config_yaml: Optional[str] = None, + extra_comparator_args: Optional[list[str]] = None, ) -> None: """Full e2e: write patch config -> baseline run -> target run -> compare.""" base_url: str = DEFAULT_URL_FOR_TEST @@ -249,6 +271,8 @@ def _run_e2e_scenario( "--allow-skipped-pattern", "input_ids|positions", ] + if extra_comparator_args: + cmd.extend(extra_comparator_args) result: subprocess.CompletedProcess[str] = subprocess.run( cmd, diff --git a/test/registered/disaggregation/test_disaggregation_basic.py b/test/registered/disaggregation/test_disaggregation_basic.py index 6b621e3e0dc3..6b03753eff21 100644 --- a/test/registered/disaggregation/test_disaggregation_basic.py +++ b/test/registered/disaggregation/test_disaggregation_basic.py @@ -8,14 +8,14 @@ from transformers import AutoTokenizer from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import ( PDDisaggregationServerBase, ) from sglang.test.test_utils import ( - DEFAULT_DRAFT_MODEL_EAGLE, + DEFAULT_DRAFT_MODEL_EAGLE3, DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_TARGET_MODEL_EAGLE, + DEFAULT_TARGET_MODEL_EAGLE3, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, popen_launch_pd_server, ) @@ -81,18 +81,17 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + 
base_url=f"http://{self.base_host}:{self.lb_port}", + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["score"], 0.62) def test_logprob(self): prompt = "The capital of france is " @@ -260,18 +259,17 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=f"http://{self.base_host}:{self.lb_port}", + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) # Expect lots of failure but the server cannot crash try: - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") except Exception as e: print(f"Test encountered expected errors: {e}") @@ -291,8 +289,8 @@ class TestDisaggregationMooncakeSpec(PDDisaggregationServerBase): @classmethod def setUpClass(cls): super().setUpClass() - cls.model = DEFAULT_TARGET_MODEL_EAGLE - cls.draft_model = DEFAULT_DRAFT_MODEL_EAGLE + cls.model = DEFAULT_TARGET_MODEL_EAGLE3 + cls.draft_model = DEFAULT_DRAFT_MODEL_EAGLE3 cls.spec_args = [ "--speculative-algorithm", "EAGLE", @@ -306,6 +304,7 @@ def setUpClass(cls): "16", "--cuda-graph-max-bs", "8", + "--dtype=float16", ] print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") @@ -361,18 +360,17 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=2, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=f"http://{self.base_host}:{self.lb_port}", + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - 
metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.20) + self.assertGreater(metrics["score"], 0.74) class TestDisaggregationSimulatedRetract(PDDisaggregationServerBase): @@ -439,18 +437,17 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=f"http://{self.base_host}:{self.lb_port}", + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["score"], 0.62) if __name__ == "__main__": diff --git a/test/registered/disaggregation/test_disaggregation_decode_offload.py b/test/registered/disaggregation/test_disaggregation_decode_offload.py index 9742c989686b..3bc2108031a2 100644 --- a/test/registered/disaggregation/test_disaggregation_decode_offload.py +++ b/test/registered/disaggregation/test_disaggregation_decode_offload.py @@ -12,16 +12,18 @@ from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - is_in_ci, popen_launch_pd_server, ) # Registering the test for CUDA CI with appropriate parameters # Increasing estimated time since we run evaluation twice -register_cuda_ci(est_time=600, suite="stage-b-test-2-gpu-large") +register_cuda_ci( + est_time=600, + suite="stage-b-test-2-gpu-large", + disabled="Temporarily disable the flaky test.", +) -@unittest.skipIf(is_in_ci(), "Temporarily disable the flaky test.") class TestDisaggregationDecodeOffload(PDDisaggregationServerBase): """ Test class for verifying KV cache offloading on the decode side in a diff --git a/test/registered/distributed/test_data_parallelism.py 
b/test/registered/distributed/test_data_parallelism.py index 70323f5634d5..bb3eb29c74e0 100644 --- a/test/registered/distributed/test_data_parallelism.py +++ b/test/registered/distributed/test_data_parallelism.py @@ -5,7 +5,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.kits.eval_accuracy_kit import MMLUMixin +from sglang.test.kits.eval_accuracy_kit import GSM8KMixin from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -18,10 +18,8 @@ register_amd_ci(est_time=73, suite="stage-b-test-2-gpu-large-amd") -class TestDataParallelism(CustomTestCase, MMLUMixin): - mmlu_score_threshold = 0.65 - mmlu_num_examples = 64 - mmlu_num_threads = 32 +class TestDataParallelism(CustomTestCase, GSM8KMixin): + gsm8k_accuracy_thres = 0.7 @classmethod def setUpClass(cls): @@ -59,13 +57,13 @@ def test_update_weight(self): assert response.status_code == 200 def test_get_memory_pool_size(self): - # use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info` - response = requests.get(self.base_url + "/get_server_info") + # use `server_info` instead since `get_memory_pool_size` is merged into `server_info` + response = requests.get(self.base_url + "/server_info") assert response.status_code == 200 time.sleep(1) - response = requests.get(self.base_url + "/get_server_info") + response = requests.get(self.base_url + "/server_info") assert response.status_code == 200 diff --git a/test/registered/distributed/test_disaggregation_aarch64.py b/test/registered/distributed/test_disaggregation_aarch64.py index be9ad8b5856d..c5a6f3e93c8a 100644 --- a/test/registered/distributed/test_disaggregation_aarch64.py +++ b/test/registered/distributed/test_disaggregation_aarch64.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as 
run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import ( PDDisaggregationServerBase, ) @@ -82,18 +82,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["score"], 0.62) if __name__ == "__main__": diff --git a/test/registered/distributed/test_disaggregation_different_tp.py b/test/registered/distributed/test_disaggregation_different_tp.py index 4f5654f70b45..bbf9f75aacc4 100644 --- a/test/registered/distributed/test_disaggregation_different_tp.py +++ b/test/registered/distributed/test_disaggregation_different_tp.py @@ -3,7 +3,7 @@ from sglang.srt.environ import envs from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import ( PDDisaggregationServerBase, ) @@ -79,18 +79,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.60) + 
self.assertGreater(metrics["score"], 0.60) class TestDisaggregationMooncakeDecodeLargerTP(PDDisaggregationServerBase): @@ -154,18 +154,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestDisaggregationMooncakeMHAPrefillLargerTP(PDDisaggregationServerBase): @@ -229,18 +229,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestDisaggregationMooncakeMHADecodeLargerTP(PDDisaggregationServerBase): @@ -304,18 +304,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - 
self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) if __name__ == "__main__": diff --git a/test/registered/distributed/test_disaggregation_dp_attention.py b/test/registered/distributed/test_disaggregation_dp_attention.py index b6d52fee61da..ba88a48d796b 100644 --- a/test/registered/distributed/test_disaggregation_dp_attention.py +++ b/test/registered/distributed/test_disaggregation_dp_attention.py @@ -4,7 +4,7 @@ from sglang.bench_serving import run_benchmark from sglang.srt.environ import envs from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import ( PDDisaggregationServerBase, ) @@ -94,23 +94,22 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1400, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1400, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestDisaggregationDPAttentionRoundRobin(TestDisaggregationDPAttention): LOAD_BALANCE_METHOD = "round_robin" - # TODO: add test for other load balance methods # TODO: add a balancedness metric def test_bench_serving(self): @@ -130,6 +129,48 @@ def test_bench_serving(self): self.assertEqual(result["completed"], 1000) +class TestDisaggregationDPAttentionTotalRequests(TestDisaggregationDPAttention): + LOAD_BALANCE_METHOD = "total_requests" + test_gsm8k = unittest.skip( + "Covered by base class; this class targets total_requests path." 
+ )(TestDisaggregationDPAttention.test_gsm8k) + + def test_bench_serving(self): + args = get_benchmark_args( + base_url=f"http://{self.base_host}:{self.lb_port}", + dataset_name="random", + tokenizer=self.model, + num_prompts=256, + random_input_len=2048, + random_output_len=512, + request_rate=float("inf"), + max_concurrency=128, + ) + result = run_benchmark(args) + self.assertEqual(result["completed"], 256) + + +class TestDisaggregationDPAttentionTotalTokens(TestDisaggregationDPAttention): + LOAD_BALANCE_METHOD = "total_tokens" + test_gsm8k = unittest.skip( + "Covered by base class; this class targets total_tokens path." + )(TestDisaggregationDPAttention.test_gsm8k) + + def test_bench_serving(self): + args = get_benchmark_args( + base_url=f"http://{self.base_host}:{self.lb_port}", + dataset_name="random", + tokenizer=self.model, + num_prompts=256, + random_input_len=2048, + random_output_len=512, + request_rate=float("inf"), + max_concurrency=128, + ) + result = run_benchmark(args) + self.assertEqual(result["completed"], 256) + + @unittest.skip( "Skip this test until new testing logic in mini-lb has been updated in docker image." 
) diff --git a/test/registered/distributed/test_disaggregation_hybrid_attention.py b/test/registered/distributed/test_disaggregation_hybrid_attention.py index febb67646829..87eb64f2f905 100644 --- a/test/registered/distributed/test_disaggregation_hybrid_attention.py +++ b/test/registered/distributed/test_disaggregation_hybrid_attention.py @@ -2,7 +2,7 @@ from types import SimpleNamespace from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import ( PDDisaggregationServerBase, ) @@ -74,18 +74,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) class TestDisaggregationHybridAttentionMambaExtraBuffer(PDDisaggregationServerBase): @@ -150,18 +150,19 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.93) + # TODO: Fix PD disaggregation accuracy issue (https://github.com/sgl-project/sglang/issues/21744) and increase the 
threshold back to 0.93. + self.assertGreater(metrics["score"], 0.90) class TestDisaggregationHybridAttentionMambaDPDecode(PDDisaggregationServerBase): @@ -228,18 +229,19 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Evaluation metrics: {metrics}") - self.assertGreater(metrics["accuracy"], 0.93) + # TODO: Fix PD disaggregation accuracy issue (https://github.com/sgl-project/sglang/issues/21744) and increase the threshold back to 0.93. + self.assertGreater(metrics["score"], 0.90) if __name__ == "__main__": diff --git a/test/registered/distributed/test_disaggregation_pp.py b/test/registered/distributed/test_disaggregation_pp.py index 7c441b3c8946..683da93a68f4 100644 --- a/test/registered/distributed/test_disaggregation_pp.py +++ b/test/registered/distributed/test_disaggregation_pp.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import ( PDDisaggregationServerBase, ) @@ -78,18 +78,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.24) + 
self.assertGreater(metrics["score"], 0.24) # Wait a little bit so that the memory check happens. time.sleep(5) @@ -156,18 +156,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.24) + self.assertGreater(metrics["score"], 0.24) # Wait a little bit so that the memory check happens. time.sleep(5) @@ -235,18 +235,18 @@ def start_decode(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{self.base_host}", - port=int(self.lb_port), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.24) + self.assertGreater(metrics["score"], 0.24) # Wait a little bit so that the memory check happens. 
time.sleep(5) diff --git a/test/registered/distributed/test_dp_attention.py b/test/registered/distributed/test_dp_attention.py index a2a8832278fa..077a0327aa59 100644 --- a/test/registered/distributed/test_dp_attention.py +++ b/test/registered/distributed/test_dp_attention.py @@ -7,12 +7,12 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.kits.ebnf_constrained_kit import EBNFConstrainedMixin -from sglang.test.kits.eval_accuracy_kit import MGSMEnMixin +from sglang.test.kits.eval_accuracy_kit import GSM8KMixin from sglang.test.kits.json_constrained_kit import JSONConstrainedMixin from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test from sglang.test.kits.regex_constrained_kit import RegexConstrainedMixin +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_IMAGE_URL, DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -30,16 +30,16 @@ class TestDPAttentionDP2TP2( CustomTestCase, - MGSMEnMixin, + GSM8KMixin, JSONConstrainedMixin, EBNFConstrainedMixin, RegexConstrainedMixin, ): - mgsm_en_score_threshold = 0.8 + gsm8k_accuracy_thres = 0.6 @classmethod def setUpClass(cls): - cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.base_url = DEFAULT_URL_FOR_TEST cls._env_override = envs.SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP.override( True @@ -154,26 +154,26 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = 
run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] print( f"###test_gsm8k (deepseek-v3 mtp + dp):\n" - f"accuracy={metrics['accuracy']=:.3f}\n" + f"accuracy={metrics['score']=:.3f}\n" f"{avg_spec_accept_length=:.3f}\n" ) self.assertGreater(avg_spec_accept_length, 2.5) diff --git a/test/registered/distributed/test_dp_attention_large.py b/test/registered/distributed/test_dp_attention_large.py index 00e3f18e85b3..48cdee862f8a 100644 --- a/test/registered/distributed/test_dp_attention_large.py +++ b/test/registered/distributed/test_dp_attention_large.py @@ -6,7 +6,6 @@ from sglang.lang.chat_template import get_chat_template_by_model_path from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.kits.ebnf_constrained_kit import EBNFConstrainedMixin from sglang.test.kits.json_constrained_kit import JSONConstrainedMixin from sglang.test.kits.regex_constrained_kit import RegexConstrainedMixin @@ -115,26 +114,26 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = requests.get(self.base_url + 
"/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] print( f"###test_gsm8k (deepseek-v3 mtp + dp):\n" - f"accuracy={metrics['accuracy']=:.3f}\n" + f"accuracy={metrics['score']=:.3f}\n" f"{avg_spec_accept_length=:.3f}\n" ) self.assertGreater(avg_spec_accept_length, 2.5) diff --git a/test/registered/distributed/test_load_weights_from_remote_instance.py b/test/registered/distributed/test_load_weights_from_remote_instance.py index f1080caeb258..00dc8454d325 100644 --- a/test/registered/distributed/test_load_weights_from_remote_instance.py +++ b/test/registered/distributed/test_load_weights_from_remote_instance.py @@ -38,7 +38,7 @@ mp.set_start_method("spawn", force=True) -register_cuda_ci(est_time=72, suite="stage-b-test-2-gpu-large") +register_cuda_ci(est_time=130, suite="stage-b-test-2-gpu-large") register_amd_ci(est_time=72, suite="stage-b-test-2-gpu-large-amd") @@ -228,6 +228,8 @@ def init_process_dst( "--remote-instance-weight-loader-backend", remote_instance_loader_backend, "--remote-instance-weight-loader-start-seed-via-transfer-engine", + "--engine-info-bootstrap-port", + str(6789 + rank), ), ) torch.cuda.synchronize() diff --git a/test/registered/distributed/test_pp_single_node.py b/test/registered/distributed/test_pp_single_node.py index 7efcabebb672..76e1c068d7f1 100644 --- a/test/registered/distributed/test_pp_single_node.py +++ b/test/registered/distributed/test_pp_single_node.py @@ -16,7 +16,6 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -60,22 +59,22 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - 
num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=DEFAULT_MODEL_NAME_FOR_TEST, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_amd_ci(): # AMD triton backend produces slightly lower accuracy than FA3 on NVIDIA - self.assertGreater(metrics["accuracy"], 0.70) + self.assertGreater(metrics["score"], 0.70) else: - self.assertGreater(metrics["accuracy"], 0.74) + self.assertGreater(metrics["score"], 0.74) # Wait a little bit so that the memory check happens. time.sleep(4) @@ -169,18 +168,18 @@ def setUpClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.65) + self.assertGreaterEqual(metrics["score"], 0.65) # Wait a little bit so that the memory check happens. 
time.sleep(4) @@ -223,15 +222,15 @@ def run_gsm8k_test(self, pp_size): try: args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=512, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model_name, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=512, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) time.sleep(5) return metrics finally: @@ -244,13 +243,13 @@ def test_pp_consistency(self): print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}") - self.assertGreaterEqual(baseline["accuracy"], 0.74) + self.assertGreaterEqual(baseline["score"], 0.74) self.assertGreaterEqual( - pp_metrics["accuracy"], - baseline["accuracy"] - 0.02, + pp_metrics["score"], + baseline["score"] - 0.02, msg=( f"PP accuracy dropped more than 2% compared to baseline. " - f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}" + f"Baseline: {baseline['score']:.2%}, PP: {pp_metrics['score']:.2%}" ), ) @@ -279,15 +278,15 @@ def run_gsm8k_test(self, pp_size): try: args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=512, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model_name, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=512, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) time.sleep(5) return metrics finally: @@ -299,13 +298,13 @@ def test_pp_consistency(self): print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}") - self.assertGreaterEqual(baseline["accuracy"], 0.38) + self.assertGreaterEqual(baseline["score"], 0.38) self.assertGreaterEqual( - pp_metrics["accuracy"], - baseline["accuracy"] - 0.02, + pp_metrics["score"], + baseline["score"] - 0.02, msg=( f"PP accuracy dropped more 
than 2% compared to baseline. " - f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}" + f"Baseline: {baseline['score']:.2%}, PP: {pp_metrics['score']:.2%}" ), ) @@ -331,15 +330,15 @@ def run_gsm8k_test(self, pp_size): try: args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=512, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model_name, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=512, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) time.sleep(5) return metrics finally: @@ -351,13 +350,13 @@ def test_pp_consistency(self): print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}") - self.assertGreaterEqual(baseline["accuracy"], 0.74) + self.assertGreaterEqual(baseline["score"], 0.74) self.assertGreaterEqual( - pp_metrics["accuracy"], - baseline["accuracy"] - 0.02, + pp_metrics["score"], + baseline["score"] - 0.02, msg=( f"PP accuracy dropped more than 2% compared to baseline. 
" - f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}" + f"Baseline: {baseline['score']:.2%}, PP: {pp_metrics['score']:.2%}" ), ) @@ -390,15 +389,15 @@ def run_gsm8k_test(self, tp_size, pp_size): try: args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=512, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model_name, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=512, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) time.sleep(5) return metrics finally: @@ -410,13 +409,13 @@ def test_pp_consistency(self): print(f"[Qwen35 PP Comparison] Baseline: {baseline} | PP: {pp_metrics}") - self.assertGreaterEqual(baseline["accuracy"], 0.83) + self.assertGreaterEqual(baseline["score"], 0.83) self.assertGreaterEqual( - pp_metrics["accuracy"], - baseline["accuracy"] - 0.05, + pp_metrics["score"], + baseline["score"] - 0.05, msg=( f"PP accuracy dropped more than 5% compared to baseline. 
" - f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}" + f"Baseline: {baseline['score']:.2%}, PP: {pp_metrics['score']:.2%}" ), ) diff --git a/test/registered/dllm/test_llada2_mini.py b/test/registered/dllm/test_llada2_mini.py index 3bec211ea397..28588b0fb354 100644 --- a/test/registered/dllm/test_llada2_mini.py +++ b/test/registered/dllm/test_llada2_mini.py @@ -7,7 +7,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -58,18 +58,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.88) + self.assertGreater(metrics["score"], 0.88) if is_in_amd_ci(): self.assertGreater(metrics["output_throughput"], 80) else: diff --git a/test/registered/dllm/test_llada2_mini_amd.py b/test/registered/dllm/test_llada2_mini_amd.py index 68e0cfec985e..396ed1df07d4 100644 --- a/test/registered/dllm/test_llada2_mini_amd.py +++ b/test/registered/dllm/test_llada2_mini_amd.py @@ -9,7 +9,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( 
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -55,19 +55,17 @@ def tearDownClass(cls): def test_gsm8k(self): """Test GSM8K accuracy with DLLM on AMD.""" args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") # Relaxed thresholds for AMD - may need adjustment - self.assertGreater(metrics["accuracy"], 0.80) + self.assertGreater(metrics["score"], 0.80) self.assertGreater(metrics["output_throughput"], 50) def test_bs_1_speed(self): diff --git a/test/registered/embedding/test_embedding_models.py b/test/registered/embedding/test_embedding_models.py index 4a9c43a95ca4..3060ae3048f0 100644 --- a/test/registered/embedding/test_embedding_models.py +++ b/test/registered/embedding/test_embedding_models.py @@ -35,7 +35,7 @@ suite="stage-b-test-1-gpu-small-amd", disabled="see https://github.com/sgl-project/sglang/issues/11127", ) -register_cuda_ci(est_time=73, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=200, suite="stage-b-test-1-gpu-small") MODEL_TO_CONFIG = { "Alibaba-NLP/gte-Qwen2-1.5B-instruct": (1, 1e-5), diff --git a/test/registered/ep/test_deepep_large.py b/test/registered/ep/test_deepep_large.py index f107b4c40b3b..a8bb9b8a038c 100644 --- a/test/registered/ep/test_deepep_large.py +++ b/test/registered/ep/test_deepep_large.py @@ -5,7 +5,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_DEEPEP_MODEL_NAME_FOR_TEST, @@ -67,18 +67,18 @@ def tearDownClass(cls): def 
test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1200, - parallel=1200, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1200, + num_threads=1200, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.92) + self.assertGreater(metrics["score"], 0.92) class TestDeepseekMTP(CustomTestCase): @@ -135,26 +135,26 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1200, - parallel=1200, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1200, + num_threads=1200, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.92) + self.assertGreater(metrics["score"], 0.92) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] print( f"###test_gsm8k:\n" - f"accuracy={metrics['accuracy']=:.3f}\n" + f"accuracy={metrics['score']=:.3f}\n" f"{avg_spec_accept_length=:.3f}\n" ) self.assertGreater(avg_spec_accept_length, 1.85) @@ -195,17 +195,17 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1200, - parallel=1200, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + 
eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1200, + num_threads=1200, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.92) + self.assertGreater(metrics["score"], 0.92) def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) diff --git a/test/registered/ep/test_deepep_small.py b/test/registered/ep/test_deepep_small.py index 911915aaff8a..d896705280bd 100644 --- a/test/registered/ep/test_deepep_small.py +++ b/test/registered/ep/test_deepep_small.py @@ -6,7 +6,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, @@ -52,18 +52,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestHybridDPTP(CustomTestCase): @@ -97,18 +97,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = 
run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestTP(CustomTestCase): @@ -139,18 +139,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) @unittest.skip("covered in test_deepep_large.py") @@ -188,18 +188,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestTBO(CustomTestCase): @@ -240,18 +240,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 
0.60) class TestTBOWithTPAttn(CustomTestCase): @@ -289,18 +289,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) # There exists bug when using MTP + TBO + attn_tp_size > 1, currently skip that case. @@ -342,18 +342,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) @unittest.skip("covered in TestMTPWithTBO") @@ -399,26 +399,26 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = 
requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] print( f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n" - f"accuracy={metrics['accuracy']=:.3f}\n" + f"accuracy={metrics['score']=:.3f}\n" f"{avg_spec_accept_length=:.3f}\n" ) self.assertGreater(avg_spec_accept_length, 2.1) @@ -473,26 +473,26 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] print( f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n" - f"accuracy={metrics['accuracy']=:.3f}\n" + f"accuracy={metrics['score']=:.3f}\n" f"{avg_spec_accept_length=:.3f}\n" ) self.assertGreater(avg_spec_accept_length, 2.1) @@ -549,26 +549,26 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = 
requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] print( f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n" - f"accuracy={metrics['accuracy']=:.3f}\n" + f"accuracy={metrics['score']=:.3f}\n" f"{avg_spec_accept_length=:.3f}\n" ) self.assertGreater(avg_spec_accept_length, 2.1) diff --git a/test/registered/ep/test_mooncake_ep_small.py b/test/registered/ep/test_mooncake_ep_small.py index ce87eed14981..5ef70914c698 100644 --- a/test/registered/ep/test_mooncake_ep_small.py +++ b/test/registered/ep/test_mooncake_ep_small.py @@ -4,7 +4,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.disaggregation_fixture import get_rdma_devices_args from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, @@ -15,7 +15,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=660, suite="stage-c-test-deepep-4-gpu-h100") +register_cuda_ci(est_time=200, suite="stage-c-test-deepep-4-gpu-h100") ib_devices = get_rdma_devices_args() @@ -69,20 +69,21 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) +@unittest.skipIf(is_in_ci(), "Skip since mooncake-ep fault-tolerant test is flaky.") class TestPureDP(TestTP): extra_args = [ 
"--enable-dp-attention", diff --git a/test/registered/eval/test_eval_accuracy_large.py b/test/registered/eval/test_eval_accuracy_large.py index dab7bb6be55f..97b441f41d76 100644 --- a/test/registered/eval/test_eval_accuracy_large.py +++ b/test/registered/eval/test_eval_accuracy_large.py @@ -16,7 +16,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=300, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=580, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=420, suite="stage-b-test-1-gpu-small-amd") diff --git a/test/registered/eval/test_moe_eval_accuracy_large.py b/test/registered/eval/test_moe_eval_accuracy_large.py deleted file mode 100644 index edfb48fd2554..000000000000 --- a/test/registered/eval/test_moe_eval_accuracy_large.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Usage: -python -m unittest test_moe_eval_accuracy_large.TestMoEEvalAccuracyLarge.test_mmlu -""" - -import os -import unittest - -from sglang.srt.utils import kill_process_tree -from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.kits.eval_accuracy_kit import HumanEvalMixin, MGSMEnMixin, MMLUMixin -from sglang.test.test_utils import ( - DEFAULT_MOE_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - is_in_amd_ci, - popen_launch_server, -) - -register_cuda_ci(est_time=500, suite="stage-b-test-2-gpu-large") -register_amd_ci(est_time=500, suite="stage-b-test-2-gpu-large-amd") - - -class TestMoEEvalAccuracyLarge(CustomTestCase, MMLUMixin, HumanEvalMixin, MGSMEnMixin): - mmlu_score_threshold = 0.62 - humaneval_score_threshold = 0.40 - mgsm_en_score_threshold = 0.61 - - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MOE_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - - # Disable AITER for AMD CI to ensure consistent results - env = None - if is_in_amd_ci(): - env = os.environ.copy() - env["SGLANG_USE_AITER"] = "0" - env["SGLANG_USE_AITER_AR"] = "0" - 
env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" - - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--log-level-http", - "warning", - "--tp", - "2", - ], - env=env, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/function_call/test_kimik2_detector.py b/test/registered/function_call/test_kimik2_detector.py index 158223958c59..68b7c98a980b 100644 --- a/test/registered/function_call/test_kimik2_detector.py +++ b/test/registered/function_call/test_kimik2_detector.py @@ -11,7 +11,7 @@ from sglang.srt.parser.reasoning_parser import KimiK2Detector as KimiK2ReasoningDetector from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(1.0, "stage-a-test-cpu") +register_cpu_ci(5, "stage-a-test-cpu") def _make_tool(name, parameters=None): diff --git a/test/registered/gb300/test_deepseek_v32.py b/test/registered/gb300/test_deepseek_v32.py new file mode 100644 index 000000000000..0f9ff25cdf7d --- /dev/null +++ b/test/registered/gb300/test_deepseek_v32.py @@ -0,0 +1,79 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +register_cuda_ci(est_time=7200, suite="nightly-4-gpu-gb300", nightly=True) + +MODEL_PATH = "deepseek-ai/DeepSeek-V3.2" + +COMMON_ARGS = [ + "--trust-remote-code", + "--reasoning-parser=deepseek-v3", + "--tool-call-parser=deepseekv32", + "--mem-fraction-static=0.8", + "--enable-metrics", +] + +MTP_ARGS = [ + "--speculative-algorithm=EAGLE", + "--speculative-num-steps=3", + "--speculative-eagle-topk=1", + "--speculative-num-draft-tokens=4", +] + + +class TestDeepseekV32(unittest.TestCase): + 
"""DeepSeek V3.2 on GB300 (4x B200 NVL4, tp=4).""" + + def test_deepseek_v32(self): + variants = [ + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS, + variant="TP4", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + + [ + "--dp-size=4", + "--ep-size=4", + "--enable-dp-attention", + ], + variant="TP4+DP4+DPA", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + + [ + "--dp-size=4", + "--ep-size=4", + "--enable-dp-attention", + ] + + MTP_ARGS, + variant="TP4+DP4+DPA+MTP", + env={"SGLANG_ENABLE_SPEC_V2": "1"}, + ), + ] + + run_combined_tests( + models=variants, + test_name="DeepSeek-V3.2", + accuracy_params=AccuracyTestParams( + dataset="gsm8k", baseline_accuracy=0.935 + ), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_gb300", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/gb300/test_deepseek_v32_nvfp4.py b/test/registered/gb300/test_deepseek_v32_nvfp4.py new file mode 100644 index 000000000000..f6be6f94afae --- /dev/null +++ b/test/registered/gb300/test_deepseek_v32_nvfp4.py @@ -0,0 +1,82 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +register_cuda_ci(est_time=7200, suite="nightly-4-gpu-gb300", nightly=True) + +MODEL_PATH = "nvidia/DeepSeek-V3.2-NVFP4" + +COMMON_ARGS = [ + "--trust-remote-code", + "--reasoning-parser=deepseek-v3", + "--tool-call-parser=deepseekv32", + "--quantization=modelopt_fp4", + "--moe-runner-backend=flashinfer_trtllm", + "--kv-cache-dtype=bfloat16", + "--mem-fraction-static=0.8", + "--enable-metrics", +] + +MTP_ARGS = [ + "--speculative-algorithm=EAGLE", + "--speculative-num-steps=3", + 
"--speculative-eagle-topk=1", + "--speculative-num-draft-tokens=4", +] + + +class TestDeepseekV32Nvfp4(unittest.TestCase): + """DeepSeek V3.2 NVFP4 on GB300 (4x B200 NVL4, tp=4).""" + + def test_deepseek_v32_nvfp4(self): + variants = [ + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS, + variant="TP4", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + + [ + "--dp-size=4", + "--ep-size=4", + "--enable-dp-attention", + ], + variant="TP4+DP4+DPA", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + + [ + "--dp-size=4", + "--ep-size=4", + "--enable-dp-attention", + ] + + MTP_ARGS, + variant="TP4+DP4+DPA+MTP", + env={"SGLANG_ENABLE_SPEC_V2": "1"}, + ), + ] + + run_combined_tests( + models=variants, + test_name="DeepSeek-V3.2-NVFP4", + accuracy_params=AccuracyTestParams( + dataset="gsm8k", baseline_accuracy=0.935 + ), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_gb300", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/gb300/test_glm5_fp8.py b/test/registered/gb300/test_glm5_fp8.py new file mode 100644 index 000000000000..e429e58731dc --- /dev/null +++ b/test/registered/gb300/test_glm5_fp8.py @@ -0,0 +1,68 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +register_cuda_ci(est_time=7200, suite="nightly-4-gpu-gb300", nightly=True) + +MODEL_PATH = "zai-org/GLM-5-FP8" + +COMMON_ARGS = [ + "--trust-remote-code", + "--reasoning-parser=glm45", + "--tool-call-parser=glm47", + "--mem-fraction-static=0.9", + "--enable-metrics", +] + +MTP_ARGS = [ + "--speculative-algorithm=EAGLE", + "--speculative-num-steps=3", + "--speculative-eagle-topk=1", 
+ "--speculative-num-draft-tokens=4", +] + + +class TestGlm5Fp8(unittest.TestCase): + """GLM-5 FP8 on GB300 (4x B200 NVL4, tp=4).""" + + def test_glm5_fp8(self): + variants = [ + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS, + variant="TP4", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + ["--dp-size=4", "--enable-dp-attention"], + variant="TP4+DP4+DPA", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + + ["--dp-size=4", "--enable-dp-attention"] + + MTP_ARGS, + variant="TP4+DP4+DPA+MTP", + env={"SGLANG_ENABLE_SPEC_V2": "1"}, + ), + ] + + run_combined_tests( + models=variants, + test_name="GLM-5-FP8", + accuracy_params=AccuracyTestParams(dataset="gsm8k", baseline_accuracy=0.92), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_gb300", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/gb300/test_glm5_nvfp4.py b/test/registered/gb300/test_glm5_nvfp4.py new file mode 100644 index 000000000000..595276c689fb --- /dev/null +++ b/test/registered/gb300/test_glm5_nvfp4.py @@ -0,0 +1,71 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +register_cuda_ci(est_time=7200, suite="nightly-4-gpu-gb300", nightly=True) + +MODEL_PATH = "nvidia/GLM-5-NVFP4" + +COMMON_ARGS = [ + "--trust-remote-code", + "--reasoning-parser=glm45", + "--tool-call-parser=glm47", + "--quantization=modelopt_fp4", + "--moe-runner-backend=flashinfer_trtllm", + "--kv-cache-dtype=bfloat16", + "--mem-fraction-static=0.9", + "--enable-metrics", +] + +MTP_ARGS = [ + "--speculative-algorithm=EAGLE", + "--speculative-num-steps=3", + "--speculative-eagle-topk=1", + 
"--speculative-num-draft-tokens=4", +] + + +class TestGlm5Nvfp4(unittest.TestCase): + """GLM-5 NVFP4 on GB300 (4x B200 NVL4, tp=4).""" + + def test_glm5_nvfp4(self): + variants = [ + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS, + variant="TP4", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + ["--dp-size=4", "--enable-dp-attention"], + variant="TP4+DP4+DPA", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + + ["--dp-size=4", "--enable-dp-attention"] + + MTP_ARGS, + variant="TP4+DP4+DPA+MTP", + env={"SGLANG_ENABLE_SPEC_V2": "1"}, + ), + ] + + run_combined_tests( + models=variants, + test_name="GLM-5-NVFP4", + accuracy_params=AccuracyTestParams(dataset="gsm8k", baseline_accuracy=0.92), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_gb300", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/gb300/test_kimi_k25.py b/test/registered/gb300/test_kimi_k25.py new file mode 100644 index 000000000000..47beb0b19997 --- /dev/null +++ b/test/registered/gb300/test_kimi_k25.py @@ -0,0 +1,58 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +register_cuda_ci(est_time=7200, suite="nightly-4-gpu-gb300", nightly=True) + +MODEL_PATH = "moonshotai/Kimi-K2.5" + +COMMON_ARGS = [ + "--trust-remote-code", + "--reasoning-parser=kimi_k2", + "--tool-call-parser=kimi_k2", + "--mem-fraction-static=0.8", + "--enable-multimodal", + "--enable-metrics", +] + + +class TestKimiK25(unittest.TestCase): + """Kimi-K2.5 (native INT4) on GB300 (4x B200 NVL4, tp=4). + + No EAGLE/MTP support for Kimi-K2.5 — only TP and TP+DP+DPA variants. 
+ """ + + def test_kimi_k25(self): + variants = [ + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS, + variant="TP4", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + ["--dp-size=4", "--enable-dp-attention"], + variant="TP4+DP4+DPA", + ), + ] + + run_combined_tests( + models=variants, + test_name="Kimi-K2.5", + accuracy_params=AccuracyTestParams( + dataset="mmmu-pro", baseline_accuracy=0.69, repeat=1, max_tokens=32768 + ), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_gb300", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/gb300/test_kimi_k25_nvfp4.py b/test/registered/gb300/test_kimi_k25_nvfp4.py new file mode 100644 index 000000000000..7faf6c92baba --- /dev/null +++ b/test/registered/gb300/test_kimi_k25_nvfp4.py @@ -0,0 +1,61 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +register_cuda_ci(est_time=7200, suite="nightly-4-gpu-gb300", nightly=True) + +MODEL_PATH = "nvidia/Kimi-K2.5-NVFP4" + +COMMON_ARGS = [ + "--trust-remote-code", + "--reasoning-parser=kimi_k2", + "--tool-call-parser=kimi_k2", + "--quantization=modelopt_fp4", + "--attention-backend=trtllm_mla", + "--moe-runner-backend=flashinfer_trtllm", + "--mem-fraction-static=0.8", + "--enable-multimodal", + "--enable-metrics", +] + + +class TestKimiK25Nvfp4(unittest.TestCase): + """Kimi-K2.5 NVFP4 on GB300 (4x B200 NVL4, tp=4). + + No EAGLE/MTP support for Kimi-K2.5 — only TP and TP+DP+DPA variants. 
+ """ + + def test_kimi_k25_nvfp4(self): + variants = [ + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS, + variant="TP4", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + ["--dp-size=4", "--enable-dp-attention"], + variant="TP4+DP4+DPA", + ), + ] + + run_combined_tests( + models=variants, + test_name="Kimi-K2.5-NVFP4", + accuracy_params=AccuracyTestParams( + dataset="mmmu-pro", baseline_accuracy=0.69, repeat=1, max_tokens=32768 + ), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_gb300", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/gb300/test_qwen35_fp8.py b/test/registered/gb300/test_qwen35_fp8.py new file mode 100644 index 000000000000..1121b1a81cf0 --- /dev/null +++ b/test/registered/gb300/test_qwen35_fp8.py @@ -0,0 +1,75 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +register_cuda_ci(est_time=7200, suite="nightly-4-gpu-gb300", nightly=True) + +MODEL_PATH = "Qwen/Qwen3.5-397B-A17B-FP8" + +COMMON_ARGS = [ + "--trust-remote-code", + "--reasoning-parser=qwen3", + "--tool-call-parser=qwen3_coder", + "--enable-flashinfer-allreduce-fusion", + "--attention-backend=trtllm_mha", + "--mem-fraction-static=0.8", + "--enable-multimodal", + "--enable-metrics", +] + +MTP_ARGS = [ + "--speculative-algorithm=EAGLE", + "--speculative-num-steps=3", + "--speculative-eagle-topk=1", + "--speculative-num-draft-tokens=4", + "--mamba-scheduler-strategy=extra_buffer", + "--page-size=64", +] + + +class TestQwen35Fp8(unittest.TestCase): + """Qwen3.5-397B FP8 on GB300 (4x B200 NVL4, tp=4).""" + + def test_qwen35_fp8(self): + variants = [ + ModelLaunchSettings( + MODEL_PATH, 
+ tp_size=4, + extra_args=COMMON_ARGS, + variant="TP4", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + ["--dp-size=4", "--enable-dp-attention"], + variant="TP4+DP4+DPA", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + + ["--dp-size=4", "--enable-dp-attention"] + + MTP_ARGS, + variant="TP4+DP4+DPA+MTP", + env={"SGLANG_ENABLE_SPEC_V2": "1"}, + ), + ] + + run_combined_tests( + models=variants, + test_name="Qwen3.5-397B-FP8", + accuracy_params=AccuracyTestParams( + dataset="mmmu-pro", baseline_accuracy=0.78, repeat=1, max_tokens=32768 + ), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_gb300", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/gb300/test_qwen35_nvfp4.py b/test/registered/gb300/test_qwen35_nvfp4.py new file mode 100644 index 000000000000..f48ad701c25a --- /dev/null +++ b/test/registered/gb300/test_qwen35_nvfp4.py @@ -0,0 +1,79 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +register_cuda_ci(est_time=7200, suite="nightly-4-gpu-gb300", nightly=True) + +MODEL_PATH = "nvidia/Qwen3.5-397B-A17B-NVFP4" + +COMMON_ARGS = [ + "--trust-remote-code", + "--reasoning-parser=qwen3", + "--tool-call-parser=qwen3_coder", + "--quantization=modelopt_fp4", + "--fp4-gemm-backend=flashinfer_cutlass", + "--moe-runner-backend=flashinfer_trtllm", + "--kv-cache-dtype=fp8_e4m3", + "--enable-flashinfer-allreduce-fusion", + "--attention-backend=trtllm_mha", + "--mem-fraction-static=0.8", + "--enable-multimodal", + "--enable-metrics", +] + +MTP_ARGS = [ + "--speculative-algorithm=EAGLE", + "--speculative-num-steps=3", + "--speculative-eagle-topk=1", + 
"--speculative-num-draft-tokens=4", + "--mamba-scheduler-strategy=extra_buffer", + "--page-size=64", +] + + +class TestQwen35Nvfp4(unittest.TestCase): + """Qwen3.5-397B NVFP4 on GB300 (4x B200 NVL4, tp=4).""" + + def test_qwen35_nvfp4(self): + variants = [ + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS, + variant="TP4", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + ["--dp-size=4", "--enable-dp-attention"], + variant="TP4+DP4+DPA", + ), + ModelLaunchSettings( + MODEL_PATH, + tp_size=4, + extra_args=COMMON_ARGS + + ["--dp-size=4", "--enable-dp-attention"] + + MTP_ARGS, + variant="TP4+DP4+DPA+MTP", + env={"SGLANG_ENABLE_SPEC_V2": "1"}, + ), + ] + + run_combined_tests( + models=variants, + test_name="Qwen3.5-397B-NVFP4", + accuracy_params=AccuracyTestParams( + dataset="mmmu-pro", baseline_accuracy=0.78, repeat=1, max_tokens=32768 + ), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_gb300", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/hicache/test_hicache_storage_file_backend.py b/test/registered/hicache/test_hicache_storage_file_backend.py index 554dd5b6f56e..12f779412c9c 100644 --- a/test/registered/hicache/test_hicache_storage_file_backend.py +++ b/test/registered/hicache/test_hicache_storage_file_backend.py @@ -19,7 +19,7 @@ from sglang.benchmark.utils import get_tokenizer from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, @@ -295,15 +295,14 @@ def run_eval_accuracy_test(test_instance, accuracy_threshold: float = 0.03): # First evaluation - populate cache print("Phase 1: Running initial GSM8K evaluation to populate cache...") args_initial = 
SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=50, - max_new_tokens=512, - parallel=10, - host=f"http://{test_instance.base_host}", - port=int(test_instance.base_port), + base_url=f"http://{test_instance.base_host}:{test_instance.base_port}", + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=64, ) - metrics_initial = run_eval_few_shot_gsm8k(args_initial) + metrics_initial = run_eval(args_initial) # Flush cache to force remote storage access print("Phase 2: Flushing device cache...") @@ -311,18 +310,18 @@ def run_eval_accuracy_test(test_instance, accuracy_threshold: float = 0.03): # Second evaluation - should use remote cache print("Phase 3: Running second GSM8K evaluation using remote cache...") - metrics_cached = run_eval_few_shot_gsm8k(args_initial) + metrics_cached = run_eval(args_initial) # Verify accuracy consistency - accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"]) + accuracy_diff = abs(metrics_initial["score"] - metrics_cached["score"]) print(f"Accuracy difference: {accuracy_diff:.4f}") # Assertions test_instance.assertGreater( - metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable" + metrics_initial["score"], 0.6, "Initial accuracy should be reasonable" ) test_instance.assertGreater( - metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable" + metrics_cached["score"], 0.6, "Cached accuracy should be reasonable" ) test_instance.assertLess( accuracy_diff, diff --git a/test/registered/kernels/test_nsa_indexer.py b/test/registered/kernels/test_nsa_indexer.py index 789baed00414..77c007dad470 100644 --- a/test/registered/kernels/test_nsa_indexer.py +++ b/test/registered/kernels/test_nsa_indexer.py @@ -24,7 +24,7 @@ from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler from sglang.test.test_utils import CustomTestCase -register_cuda_ci(est_time=2, suite="stage-b-test-1-gpu-large") 
+register_cuda_ci(est_time=15, suite="stage-b-test-1-gpu-large") # Global configuration for all indexer tests DEFAULT_CONFIG = { diff --git a/test/registered/language/test_srt_backend.py b/test/registered/language/test_srt_backend.py index 99620e2d2591..a459cad89534 100644 --- a/test/registered/language/test_srt_backend.py +++ b/test/registered/language/test_srt_backend.py @@ -32,6 +32,7 @@ def setUpClass(cls): model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4, mem_fraction_static=0.7, + log_level="info", ) sgl.set_default_backend(cls.backend) diff --git a/test/registered/lora/test_fused_moe_lora_kernel.py b/test/registered/lora/test_fused_moe_lora_kernel.py index 8c0bde3f3e6d..44179a5fa7e2 100644 --- a/test/registered/lora/test_fused_moe_lora_kernel.py +++ b/test/registered/lora/test_fused_moe_lora_kernel.py @@ -15,7 +15,7 @@ # ============================================================================== -register_cuda_ci(est_time=120, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=25, suite="stage-b-test-1-gpu-large") def round_up(x, base): diff --git a/test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py b/test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py new file mode 100644 index 000000000000..e3a5e9dd6c4b --- /dev/null +++ b/test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py @@ -0,0 +1,151 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +""" +Regression test for gpt-oss-20b LoRA logprob accuracy. + +Compares SGLang LoRA logprobs against reference training logprobs from a +pre-computed dataset. The LoRA adapter and reference data are downloaded from: +https://huggingface.co/datasets/yushengsu/lora-diff-gpt-oss-20b + +Usage: + python -m unittest test_lora_gpt_oss_20b_logprob_diff +""" + +import multiprocessing as mp +import os +import unittest + +import torch +from huggingface_hub import snapshot_download + +import sglang as sgl +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci( + est_time=300, + suite="stage-c-test-4-gpu-b200", +) + +BASE_MODEL = "lmsys/gpt-oss-20b-bf16" +LORA_HF_REPO = "yushengsu/lora-diff-gpt-oss-20b" +LORA_BACKEND = "triton" +MAX_LORA_RANK = 32 +TP_SIZE = 4 +DISABLE_CUDA_GRAPH = True +MOE_RUNNER_BACKEND = "triton" +EXPERTS_SHARED_OUTER_LORAS = True +PREFILL_ATTENTION_BACKEND = "fa4" +DECODE_ATTENTION_BACKEND = "fa4" + +KL_THRESHOLD = 5e-3 + + +def kl_v2(a, b): + a = torch.tensor(a) if not torch.is_tensor(a) else a + b = torch.tensor(b) if not torch.is_tensor(b) else b + return (((a - b) ** 2) * 0.5).mean().item() + + +def get_prompt_logprobs(engine, input_ids, lora_path): + out = engine.generate( + input_ids=input_ids, + sampling_params={"max_new_tokens": 0, "temperature": 0.0}, + return_logprob=True, + logprob_start_len=0, + lora_path=lora_path, + ) + return [logprob for logprob, _, _ in out["meta_info"]["input_token_logprobs"]][1:] + + +class TestLoRAGptOss20BLogprobDiff(CustomTestCase): + + def test_lora_gpt_oss_20b_logprob_accuracy(self): + adapter_path = snapshot_download( + LORA_HF_REPO, + repo_type="dataset", + ) + + engine = sgl.Engine( + model_path=BASE_MODEL, + tp_size=TP_SIZE, + enable_lora=True, + max_lora_rank=MAX_LORA_RANK, + lora_paths={"my_lora": adapter_path}, + lora_backend=LORA_BACKEND, + 
attention_backend="flashinfer", + disable_cuda_graph=DISABLE_CUDA_GRAPH, + moe_runner_backend=MOE_RUNNER_BACKEND, + experts_shared_outer_loras=EXPERTS_SHARED_OUTER_LORAS, + prefill_attention_backend=PREFILL_ATTENTION_BACKEND, + decode_attention_backend=DECODE_ATTENTION_BACKEND, + ) + + try: + cdata = torch.load( + os.path.join(adapter_path, "compare_sample_train_data.pt"), + weights_only=False, + ) + + base_logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path=None) + logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path="my_lora") + + base_t = torch.tensor(base_logprobs) + lora_t = torch.tensor(logprobs) + diff = (base_t - lora_t).abs() + print( + f"[VERIFY] base vs lora: mean_diff={diff.mean().item():.6f}, " + f"max_diff={diff.max().item():.6f}, " + f"identical={torch.equal(base_t, lora_t)}" + ) + + self.assertFalse( + torch.equal(base_t, lora_t), + "LoRA logprobs should differ from base model logprobs", + ) + + kl_sglang_trainer = kl_v2(cdata["training_logprobs"], logprobs) + kl_orig_trainer = kl_v2( + cdata["training_logprobs"], cdata["sampling_logprobs"] + ) + kl_sglang_orig = kl_v2(logprobs, cdata["sampling_logprobs"]) + + print(f"KL(orig_sampler, trainer) = {kl_orig_trainer:.6e}") + print(f"KL(sglang, trainer) = {kl_sglang_trainer:.6e}") + print(f"KL(sglang, orig_sampler) = {kl_sglang_orig:.6e}") + + self.assertLessEqual( + kl_sglang_trainer, + KL_THRESHOLD, + f"KL(sglang, trainer) = {kl_sglang_trainer:.6e} exceeds " + f"threshold {KL_THRESHOLD}", + ) + + finally: + engine.shutdown() + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index 6926f1d89a58..7a7b47bbc76f 100644 --- 
a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -25,7 +25,7 @@ from sglang.test.runners import SRTRunner register_cuda_ci( - est_time=25, + est_time=50, suite="stage-b-test-1-gpu-large", ) @@ -332,7 +332,7 @@ def test_sglang_moe_parity_strict(self): ref = REFERENCE_STATS[i] # Epsilon to allow room for different, but correct, implementations - eps = 1e-4 + eps = 2e-4 # Assertions self.assertEqual(v_text, s_text, f"String mismatch on prompt {i}") diff --git a/test/registered/lora/test_lora_qwen3_30b_a3b_instruct_2507_logprob_diff.py b/test/registered/lora/test_lora_qwen3_30b_a3b_instruct_2507_logprob_diff.py new file mode 100644 index 000000000000..a729407f6986 --- /dev/null +++ b/test/registered/lora/test_lora_qwen3_30b_a3b_instruct_2507_logprob_diff.py @@ -0,0 +1,151 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Regression test for Qwen3-30B-A3B-Instruct-2507 LoRA logprob accuracy. + +Compares SGLang LoRA logprobs against reference training logprobs from a +pre-computed dataset. 
The LoRA adapter and reference data are downloaded from: +https://huggingface.co/datasets/yushengsu/lora-diff-Qwen3-30B-A3B-Instruct-2507 + +Usage: + python -m unittest test_lora_qwen3_30b_a3b_instruct_2507_logprob_diff +""" + +import multiprocessing as mp +import os +import unittest + +import torch +from huggingface_hub import snapshot_download + +import sglang as sgl +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci( + est_time=160, + suite="stage-c-test-4-gpu-b200", +) + +BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507" +LORA_HF_REPO = "yushengsu/lora-diff-Qwen3-30B-A3B-Instruct-2507" +LORA_BACKEND = "triton" +MAX_LORA_RANK = 32 +TP_SIZE = 4 +DISABLE_CUDA_GRAPH = True +MOE_RUNNER_BACKEND = "triton" +EXPERTS_SHARED_OUTER_LORAS = True +PREFILL_ATTENTION_BACKEND = "fa4" +DECODE_ATTENTION_BACKEND = "fa4" + +KL_THRESHOLD = 5e-3 + + +def kl_v2(a, b): + a = torch.tensor(a) if not torch.is_tensor(a) else a + b = torch.tensor(b) if not torch.is_tensor(b) else b + return (((a - b) ** 2) * 0.5).mean().item() + + +def get_prompt_logprobs(engine, input_ids, lora_path): + out = engine.generate( + input_ids=input_ids, + sampling_params={"max_new_tokens": 0, "temperature": 0.0}, + return_logprob=True, + logprob_start_len=0, + lora_path=lora_path, + ) + return [logprob for logprob, _, _ in out["meta_info"]["input_token_logprobs"]][1:] + + +class TestLoRAQwen3_30B_A3B_Instruct_2507_LogprobDiff(CustomTestCase): + + def test_lora_qwen3_30b_a3b_instruct_2507_logprob_accuracy(self): + adapter_path = snapshot_download( + LORA_HF_REPO, + repo_type="dataset", + ) + + engine = sgl.Engine( + model_path=BASE_MODEL, + tp_size=TP_SIZE, + enable_lora=True, + max_lora_rank=MAX_LORA_RANK, + lora_paths={"my_lora": adapter_path}, + lora_backend=LORA_BACKEND, + attention_backend="flashinfer", + disable_cuda_graph=DISABLE_CUDA_GRAPH, + moe_runner_backend=MOE_RUNNER_BACKEND, + 
experts_shared_outer_loras=EXPERTS_SHARED_OUTER_LORAS, + prefill_attention_backend=PREFILL_ATTENTION_BACKEND, + decode_attention_backend=DECODE_ATTENTION_BACKEND, + ) + + try: + cdata = torch.load( + os.path.join(adapter_path, "compare_sample_train_data.pt"), + weights_only=False, + ) + + base_logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path=None) + logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path="my_lora") + + base_t = torch.tensor(base_logprobs) + lora_t = torch.tensor(logprobs) + diff = (base_t - lora_t).abs() + print( + f"[VERIFY] base vs lora: mean_diff={diff.mean().item():.6f}, " + f"max_diff={diff.max().item():.6f}, " + f"identical={torch.equal(base_t, lora_t)}" + ) + + self.assertFalse( + torch.equal(base_t, lora_t), + "LoRA logprobs should differ from base model logprobs", + ) + + kl_sglang_trainer = kl_v2(cdata["training_logprobs"], logprobs) + kl_orig_trainer = kl_v2( + cdata["training_logprobs"], cdata["sampling_logprobs"] + ) + kl_sglang_orig = kl_v2(logprobs, cdata["sampling_logprobs"]) + + print(f"KL(orig_sampler, trainer) = {kl_orig_trainer:.6e}") + print(f"KL(sglang, trainer) = {kl_sglang_trainer:.6e}") + print(f"KL(sglang, orig_sampler) = {kl_sglang_orig:.6e}") + + self.assertLessEqual( + kl_sglang_trainer, + KL_THRESHOLD, + f"KL(sglang, trainer) = {kl_sglang_trainer:.6e} exceeds " + f"threshold {KL_THRESHOLD}", + ) + + finally: + engine.shutdown() + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py b/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py new file mode 100644 index 000000000000..c2b9039a2ccf --- /dev/null +++ b/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py @@ -0,0 +1,202 @@ +# Copyright 2023-2025 SGLang Team +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Regression test for Qwen3-8B LoRA logprob accuracy. + +Compares SGLang LoRA logprobs against reference training logprobs from a +pre-computed dataset. The LoRA adapter and reference data are downloaded from: +https://huggingface.co/datasets/yushengsu/lora-diff-Qwen3-8B + +Usage: + python -m unittest test_lora_qwen3_8b_logprob_diff +""" + +import multiprocessing as mp +import os +import unittest +from unittest.mock import patch + +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download + +import sglang as sgl +from sglang.srt.lora.utils import auto_detect_lora_target_modules +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci( + est_time=40, + suite="stage-b-test-1-gpu-large", +) + +BASE_MODEL = "Qwen/Qwen3-8B" +LORA_HF_REPO = "yushengsu/lora-diff-Qwen3-8B" +LORA_BACKEND = "triton" +MAX_LORA_RANK = 32 +TP_SIZE = 1 +DISABLE_CUDA_GRAPH = True +PREFILL_ATTENTION_BACKEND = "fa4" +DECODE_ATTENTION_BACKEND = "fa4" + +KL_THRESHOLD = 5e-3 + + +def kl_v2(a, b): + a = torch.tensor(a) if not torch.is_tensor(a) else a + b = torch.tensor(b) if not torch.is_tensor(b) else b + return (((a - b) ** 2) * 0.5).mean().item() + + +def get_prompt_logprobs(engine, input_ids, lora_path): + out = engine.generate( + input_ids=input_ids, + sampling_params={"max_new_tokens": 
0, "temperature": 0.0}, + return_logprob=True, + logprob_start_len=0, + lora_path=lora_path, + ) + return [logprob for logprob, _, _ in out["meta_info"]["input_token_logprobs"]][1:] + + +class _MockLinearBase(nn.Module): + pass + + +class _MockFusedMoE(nn.Module): + pass + + +class _MockParallelLMHead(nn.Module): + pass + + +def _build_qwen3_mock(): + """Build a lightweight nn.Module tree that mirrors Qwen3-8B's named modules.""" + model = nn.Module() + inner = nn.Module() + layer = nn.Module() + + attn = nn.Module() + attn.qkv_proj = _MockLinearBase() + attn.o_proj = _MockLinearBase() + layer.self_attn = attn + + mlp = nn.Module() + mlp.gate_up_proj = _MockLinearBase() + mlp.down_proj = _MockLinearBase() + layer.mlp = mlp + + inner.layers = nn.ModuleList([layer]) + inner.embed_tokens = nn.Embedding(10, 8) # not a LinearBase — should be excluded + model.model = inner + model.lm_head = _MockParallelLMHead() + return model + + +class TestLoRAQwen3_8BLogprobDiff(CustomTestCase): + + def test_auto_detect_lora_target_modules(self): + """Verify auto_detect_lora_target_modules returns the expected module + set for a Qwen3-8B-like (dense) architecture. 
Catches silent renames + of internal param names that would break LoRA auto-detection.""" + model = _build_qwen3_mock() + + with patch("sglang.srt.layers.linear.LinearBase", _MockLinearBase), patch( + "sglang.srt.layers.moe.fused_moe_triton.layer.FusedMoE", _MockFusedMoE + ), patch( + "sglang.srt.layers.vocab_parallel_embedding.ParallelLMHead", + _MockParallelLMHead, + ): + detected = auto_detect_lora_target_modules(model) + + expected = {"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "lm_head"} + self.assertEqual(detected, expected) + + def test_lora_qwen3_8b_logprob_accuracy(self): + adapter_path = snapshot_download( + LORA_HF_REPO, + repo_type="dataset", + ) + + engine = sgl.Engine( + model_path=BASE_MODEL, + tp_size=TP_SIZE, + enable_lora=True, + max_lora_rank=MAX_LORA_RANK, + lora_paths={"my_lora": adapter_path}, + lora_backend=LORA_BACKEND, + attention_backend="flashinfer", + disable_cuda_graph=DISABLE_CUDA_GRAPH, + prefill_attention_backend=PREFILL_ATTENTION_BACKEND, + decode_attention_backend=DECODE_ATTENTION_BACKEND, + ) + + try: + cdata = torch.load( + os.path.join(adapter_path, "compare_sample_train_data.pt"), + weights_only=False, + ) + + base_logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path=None) + logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path="my_lora") + + base_t = torch.tensor(base_logprobs) + lora_t = torch.tensor(logprobs) + diff = (base_t - lora_t).abs() + print( + f"[VERIFY] base vs lora: mean_diff={diff.mean().item():.6f}, " + f"max_diff={diff.max().item():.6f}, " + f"identical={torch.equal(base_t, lora_t)}" + ) + + self.assertFalse( + torch.equal(base_t, lora_t), + "LoRA logprobs should differ from base model logprobs", + ) + + kl_sglang_trainer = kl_v2(cdata["training_logprobs"], logprobs) + kl_orig_trainer = kl_v2( + cdata["training_logprobs"], cdata["sampling_logprobs"] + ) + kl_sglang_orig = kl_v2(logprobs, cdata["sampling_logprobs"]) + + print(f"KL(orig_sampler, trainer) = {kl_orig_trainer:.6e}") + 
print(f"KL(sglang, trainer) = {kl_sglang_trainer:.6e}") + print(f"KL(sglang, orig_sampler) = {kl_sglang_orig:.6e}") + + self.assertLessEqual( + kl_sglang_trainer, + KL_THRESHOLD, + f"KL(sglang, trainer) = {kl_sglang_trainer:.6e} exceeds " + f"threshold {KL_THRESHOLD}", + ) + + finally: + engine.shutdown() + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py b/test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py new file mode 100644 index 000000000000..176d16919cd6 --- /dev/null +++ b/test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py @@ -0,0 +1,151 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Regression test for Qwen3-VL-30B-A3B-Instruct LoRA logprob accuracy. + +Compares SGLang LoRA logprobs against reference training logprobs from a +pre-computed dataset. 
The LoRA adapter and reference data are downloaded from: +https://huggingface.co/datasets/yushengsu/lora-diff-Qwen3-VL-30B-A3B-Instruct + +Usage: + python -m unittest test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff +""" + +import multiprocessing as mp +import os +import unittest + +import torch +from huggingface_hub import snapshot_download + +import sglang as sgl +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci( + est_time=160, + suite="stage-c-test-4-gpu-b200", +) + +BASE_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct" +LORA_HF_REPO = "yushengsu/lora-diff-Qwen3-VL-30B-A3B-Instruct" +LORA_BACKEND = "triton" +MAX_LORA_RANK = 32 +TP_SIZE = 4 +DISABLE_CUDA_GRAPH = True +MOE_RUNNER_BACKEND = "triton" +EXPERTS_SHARED_OUTER_LORAS = True +PREFILL_ATTENTION_BACKEND = "fa4" +DECODE_ATTENTION_BACKEND = "fa4" + +KL_THRESHOLD = 5e-3 + + +def kl_v2(a, b): + a = torch.tensor(a) if not torch.is_tensor(a) else a + b = torch.tensor(b) if not torch.is_tensor(b) else b + return (((a - b) ** 2) * 0.5).mean().item() + + +def get_prompt_logprobs(engine, input_ids, lora_path): + out = engine.generate( + input_ids=input_ids, + sampling_params={"max_new_tokens": 0, "temperature": 0.0}, + return_logprob=True, + logprob_start_len=0, + lora_path=lora_path, + ) + return [logprob for logprob, _, _ in out["meta_info"]["input_token_logprobs"]][1:] + + +class TestLoRAQwen3VL_30B_A3B_Instruct_LogprobDiff(CustomTestCase): + + def test_lora_qwen3_vl_30b_a3b_instruct_logprob_accuracy(self): + adapter_path = snapshot_download( + LORA_HF_REPO, + repo_type="dataset", + ) + + engine = sgl.Engine( + model_path=BASE_MODEL, + tp_size=TP_SIZE, + enable_lora=True, + max_lora_rank=MAX_LORA_RANK, + lora_paths={"my_lora": adapter_path}, + lora_backend=LORA_BACKEND, + attention_backend="flashinfer", + disable_cuda_graph=DISABLE_CUDA_GRAPH, + moe_runner_backend=MOE_RUNNER_BACKEND, + experts_shared_outer_loras=EXPERTS_SHARED_OUTER_LORAS, + 
prefill_attention_backend=PREFILL_ATTENTION_BACKEND, + decode_attention_backend=DECODE_ATTENTION_BACKEND, + ) + + try: + cdata = torch.load( + os.path.join(adapter_path, "compare_sample_train_data.pt"), + weights_only=False, + ) + + base_logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path=None) + logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path="my_lora") + + base_t = torch.tensor(base_logprobs) + lora_t = torch.tensor(logprobs) + diff = (base_t - lora_t).abs() + print( + f"[VERIFY] base vs lora: mean_diff={diff.mean().item():.6f}, " + f"max_diff={diff.max().item():.6f}, " + f"identical={torch.equal(base_t, lora_t)}" + ) + + self.assertFalse( + torch.equal(base_t, lora_t), + "LoRA logprobs should differ from base model logprobs", + ) + + kl_sglang_trainer = kl_v2(cdata["training_logprobs"], logprobs) + kl_orig_trainer = kl_v2( + cdata["training_logprobs"], cdata["sampling_logprobs"] + ) + kl_sglang_orig = kl_v2(logprobs, cdata["sampling_logprobs"]) + + print(f"KL(orig_sampler, trainer) = {kl_orig_trainer:.6e}") + print(f"KL(sglang, trainer) = {kl_sglang_trainer:.6e}") + print(f"KL(sglang, orig_sampler) = {kl_sglang_orig:.6e}") + + self.assertLessEqual( + kl_sglang_trainer, + KL_THRESHOLD, + f"KL(sglang, trainer) = {kl_sglang_trainer:.6e} exceeds " + f"threshold {KL_THRESHOLD}", + ) + + finally: + engine.shutdown() + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/test/registered/lora/test_lora_tp.py b/test/registered/lora/test_lora_tp.py index 017b4da53d5c..32c4352889da 100644 --- a/test/registered/lora/test_lora_tp.py +++ b/test/registered/lora/test_lora_tp.py @@ -30,8 +30,8 @@ from sglang.test.test_utils import CustomTestCase, is_in_ci register_cuda_ci( - est_time=116, - suite="stage-b-test-2-gpu-large", + est_time=190, 
+ suite="stage-c-test-8-gpu-h200", ) register_amd_ci( est_time=116, @@ -65,6 +65,7 @@ def _run_tp_on_model_cases( max_new_tokens=32, enable_lora_overlap_loading=enable_lora_overlap_loading, test_tag=f"tp={tp_size}, enable_lora_overlap_loading={enable_lora_overlap_loading}", + attention_backend="fa3", ) def test_ci_lora_models(self): diff --git a/test/registered/mla/test_flashmla.py b/test/registered/mla/test_flashmla.py index b270358c83d8..c5f084e42540 100644 --- a/test/registered/mla/test_flashmla.py +++ b/test/registered/mla/test_flashmla.py @@ -11,7 +11,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -53,18 +53,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestFlashMLAMTP(CustomTestCase): @@ -112,18 +112,18 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) 
print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) server_info = requests.get(self.base_url + "/server_info").json() avg_spec_accept_length = server_info["internal_states"][0][ diff --git a/test/registered/mla/test_mla_deepseek_v3.py b/test/registered/mla/test_mla_deepseek_v3.py index 392154e3d6ce..e1cf2ba83b6f 100644 --- a/test/registered/mla/test_mla_deepseek_v3.py +++ b/test/registered/mla/test_mla_deepseek_v3.py @@ -6,7 +6,7 @@ from sglang.srt.utils import is_cuda, is_hip, kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -45,18 +45,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["score"], 0.62) @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") @@ -82,18 +82,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - 
self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["score"], 0.62) @unittest.skipIf(is_hip(), "FA is not available.") @@ -133,18 +133,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) class TestDeepseekV3MTP(CustomTestCase): @@ -186,20 +186,20 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/mla/test_mla_flashinfer.py b/test/registered/mla/test_mla_flashinfer.py index 555a54e5e768..5bd6658057f8 100644 --- a/test/registered/mla/test_mla_flashinfer.py +++ b/test/registered/mla/test_mla_flashinfer.py @@ -6,7 +6,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as 
run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -47,18 +47,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.615) + self.assertGreater(metrics["score"], 0.615) class TestFlashinferMLAMTP(CustomTestCase): @@ -102,20 +102,20 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = requests.get(self.base_url + "/get_server_info").json() + server_info = requests.get(self.base_url + "/server_info").json() avg_spec_accept_length = server_info["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/mla/test_mla_int8_deepseek_v3.py b/test/registered/mla/test_mla_int8_deepseek_v3.py index d7acb80403bb..1faf4846e118 100644 --- a/test/registered/mla/test_mla_int8_deepseek_v3.py +++ b/test/registered/mla/test_mla_int8_deepseek_v3.py @@ -6,7 +6,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from 
sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -22,7 +22,7 @@ class TestMLADeepseekV3ChannelInt8(CustomTestCase): @classmethod def setUpClass(cls): - cls.model = "sgl-project/sglang-ci-dsv3-channel-int8-test" + cls.model = "lmsys/sglang-ci-dsv3-channel-int8-test" cls.base_url = DEFAULT_URL_FOR_TEST other_args = ["--trust-remote-code"] if torch.cuda.is_available() and torch.version.cuda: @@ -48,25 +48,25 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreaterEqual(metrics["accuracy"], 0.61) + self.assertGreaterEqual(metrics["score"], 0.61) @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") class TestDeepseekV3MTPChannelInt8(CustomTestCase): @classmethod def setUpClass(cls): - cls.model = "sgl-project/sglang-ci-dsv3-channel-int8-test" + cls.model = "lmsys/sglang-ci-dsv3-channel-int8-test" cls.base_url = DEFAULT_URL_FOR_TEST other_args = ["--trust-remote-code"] if torch.cuda.is_available() and torch.version.cuda: @@ -80,7 +80,7 @@ def setUpClass(cls): "--speculative-algorithm", "EAGLE", "--speculative-draft-model-path", - "sgl-project/sglang-ci-dsv3-channel-int8-test-NextN", + "lmsys/sglang-ci-dsv3-channel-int8-test-NextN", "--speculative-num-steps", "2", "--speculative-eagle-topk", @@ -104,20 +104,20 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - 
parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -129,7 +129,7 @@ def test_gsm8k(self): class TestMLADeepseekV3BlockInt8(CustomTestCase): @classmethod def setUpClass(cls): - cls.model = "sgl-project/sglang-ci-dsv3-block-int8-test" + cls.model = "lmsys/sglang-ci-dsv3-block-int8-test" cls.base_url = DEFAULT_URL_FOR_TEST other_args = ["--trust-remote-code"] if torch.cuda.is_available() and torch.version.cuda: @@ -155,24 +155,24 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["score"], 0.62) class TestDeepseekV3MTPBlockInt8(CustomTestCase): @classmethod def setUpClass(cls): - cls.model = "sgl-project/sglang-ci-dsv3-block-int8-test" + cls.model = "lmsys/sglang-ci-dsv3-block-int8-test" cls.base_url = DEFAULT_URL_FOR_TEST other_args = ["--trust-remote-code"] if torch.cuda.is_available() and torch.version.cuda: @@ -208,20 +208,20 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = 
SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.60) + self.assertGreater(metrics["score"], 0.60) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/model_loading/test_runai_model_loader.py b/test/registered/model_loading/test_runai_model_loader.py new file mode 100644 index 000000000000..8db3a8bba36b --- /dev/null +++ b/test/registered/model_loading/test_runai_model_loader.py @@ -0,0 +1,50 @@ +import unittest + +import sglang as sgl +from sglang.srt.environ import temp_set_env +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=120, suite="stage-b-test-1-gpu-small") + +TEST_GCS_MODEL = "gs://vertex-model-garden-public-us/codegemma/codegemma-2b/" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +class TestRunaiModelLoader(CustomTestCase): + @classmethod + def setUpClass(cls): + with temp_set_env( + GOOGLE_CLOUD_PROJECT="fake-project", + RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS="true", + CLOUD_STORAGE_EMULATOR_ENDPOINT="https://storage.googleapis.com", + ): + cls.engine = sgl.Engine( + model_path=TEST_GCS_MODEL, + load_format="runai_streamer", + cuda_graph_max_bs=1, + max_total_tokens=64, + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "engine") and cls.engine: + 
cls.engine.shutdown() + + def test_generate_produces_output(self): + outputs = self.engine.generate(PROMPTS) + self.assertEqual(len(outputs), len(PROMPTS)) + for i, output in enumerate(outputs): + text = output["text"] + self.assertIsInstance(text, str) + self.assertGreater(len(text), 0, f"Prompt {i} produced empty output") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/models/test_compressed_tensors_models.py b/test/registered/models/test_compressed_tensors_models.py index 77e751669755..3ffc328b25a6 100644 --- a/test/registered/models/test_compressed_tensors_models.py +++ b/test/registered/models/test_compressed_tensors_models.py @@ -5,7 +5,7 @@ from sglang.srt.utils import is_hip, kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -35,21 +35,21 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") if is_hip(): # Lower threshold for AMD because FP8 dtype differs (fp8_fnuz) - self.assertGreaterEqual(metrics["accuracy"], 0.40) + self.assertGreaterEqual(metrics["score"], 0.40) else: - self.assertGreaterEqual(metrics["accuracy"], 0.45) + self.assertGreaterEqual(metrics["score"], 0.45) if __name__ == "__main__": diff --git a/test/registered/models/test_gpt_oss_models_pcg.py b/test/registered/models/test_gpt_oss_models_pcg.py deleted file mode 100644 index 438f127c2a4f..000000000000 --- a/test/registered/models/test_gpt_oss_models_pcg.py +++ 
/dev/null @@ -1,72 +0,0 @@ -""" -GPT-OSS piecewise CUDA graph tests. -""" - -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import kill_process_tree -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - -register_cuda_ci( - est_time=400, - suite="stage-b-test-2-gpu-large", -) - -GPT_OSS_MODEL = "openai/gpt-oss-120b" - -ACC_THRESHOLDS = { - GPT_OSS_MODEL: {"gsm8k": 0.81}, -} - - -class TestGptOssPiecewiseCudaGraph(CustomTestCase): - - @classmethod - def setUpClass(cls): - cls.model = GPT_OSS_MODEL - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--tp", - "2", - "--trust-remote-code", - "--reasoning-parser", - "gpt-oss", - "--enable-piecewise-cuda-graph", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreaterEqual( - metrics["accuracy"], ACC_THRESHOLDS[self.model]["gsm8k"] - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/models/test_kimi_linear_models.py b/test/registered/models/test_kimi_linear_models.py index c63f6ac32120..a97a3af8c08c 100644 --- a/test/registered/models/test_kimi_linear_models.py +++ b/test/registered/models/test_kimi_linear_models.py @@ -3,7 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from 
sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -11,7 +11,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=90, suite="stage-b-test-2-gpu-large") +register_cuda_ci(est_time=180, suite="stage-b-test-2-gpu-large") class TestKimiLinear(CustomTestCase): @@ -32,17 +32,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.88) + self.assertGreater(metrics["score"], 0.88) if __name__ == "__main__": diff --git a/test/registered/models/test_kimi_linear_models_pcg.py b/test/registered/models/test_kimi_linear_models_pcg.py deleted file mode 100644 index 32314f2b80f7..000000000000 --- a/test/registered/models/test_kimi_linear_models_pcg.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -Kimi-Linear piecewise CUDA graph tests. 
-""" - -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import kill_process_tree -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - -register_cuda_ci( - est_time=100, - suite="stage-b-test-2-gpu-large", -) - -KIMI_LINEAR_MODEL = "moonshotai/Kimi-Linear-48B-A3B-Instruct" - -ACC_THRESHOLDS = { - KIMI_LINEAR_MODEL: {"gsm8k": 0.88}, -} - - -class TestKimiLinearPiecewiseCudaGraph(CustomTestCase): - - @classmethod - def setUpClass(cls): - cls.model = KIMI_LINEAR_MODEL - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--tp", - "2", - "--trust-remote-code", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreaterEqual( - metrics["accuracy"], ACC_THRESHOLDS[self.model]["gsm8k"] - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/models/test_ministral4_models.py b/test/registered/models/test_ministral4_models.py new file mode 100644 index 000000000000..875e0a75e511 --- /dev/null +++ b/test/registered/models/test_ministral4_models.py @@ -0,0 +1,32 @@ +import unittest + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kits.eval_accuracy_kit import GSM8KMixin +from sglang.test.kits.mmmu_vlm_kit import MMMUMixin +from sglang.test.server_fixtures.default_fixture import DefaultServerBase +from sglang.test.server_fixtures.mmmu_fixture import MMMUServerBase + 
+register_cuda_ci( + est_time=200, + suite="stage-b-test-2-gpu-large", +) + +MODEL = "mistralai/Mistral-Small-4-119B-2603" + + +class TestMistralSmall4TextOnly(GSM8KMixin, DefaultServerBase): + gsm8k_accuracy_thres = 0.9 + model = MODEL + other_args = ["--tp-size", "2", "--trust-remote-code"] + + +class TestMistralSmall4MMMU(MMMUMixin, MMMUServerBase): + accuracy = 0.45 + model = MODEL + other_args = ["--tp-size", "2", "--trust-remote-code"] + mmmu_args = ["--limit=0.1"] + """`--limit=0.1`: 10 percent of each task - this is fine for testing since the nominal result isn't interesting - this run is just to prevent relative regressions.""" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/models/test_nvidia_nemotron_3_nano.py b/test/registered/models/test_nvidia_nemotron_3_nano.py index 0580979c49f3..1d26abd21c1a 100644 --- a/test/registered/models/test_nvidia_nemotron_3_nano.py +++ b/test/registered/models/test_nvidia_nemotron_3_nano.py @@ -4,7 +4,7 @@ from sglang.test.kits.lm_eval_kit import LMEvalMixin from sglang.test.server_fixtures.default_fixture import DefaultServerBase -register_cuda_ci(est_time=180, suite="stage-b-test-2-gpu-large") +register_cuda_ci(est_time=660, suite="stage-b-test-2-gpu-large") NEMOTRON_3_NANO_THINKING_ARGS = [ "--trust-remote-code", diff --git a/test/registered/models/test_nvidia_nemotron_nano_v2.py b/test/registered/models/test_nvidia_nemotron_nano_v2.py index e5b0150acfb7..b12395a6a3f0 100644 --- a/test/registered/models/test_nvidia_nemotron_nano_v2.py +++ b/test/registered/models/test_nvidia_nemotron_nano_v2.py @@ -5,7 +5,7 @@ from sglang.test.kits.eval_accuracy_kit import GSM8KMixin from sglang.test.server_fixtures.default_fixture import DefaultServerBase -register_cuda_ci(est_time=132, suite="stage-b-test-2-gpu-large") +register_cuda_ci(est_time=240, suite="stage-b-test-2-gpu-large") class TestNvidiaNemotronNanoV2BF16(GSM8KMixin, DefaultServerBase): diff --git 
a/test/registered/models/test_qwen3_next_models_pcg.py b/test/registered/models/test_qwen3_next_models_pcg.py deleted file mode 100644 index 593b6d70079e..000000000000 --- a/test/registered/models/test_qwen3_next_models_pcg.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -Qwen3 Next piecewise CUDA graph tests. -""" - -import unittest - -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.kits.eval_accuracy_kit import GSM8KMixin -from sglang.test.server_fixtures.default_fixture import DefaultServerBase - -register_cuda_ci( - est_time=400, - suite="stage-c-test-4-gpu-h100", -) - -QWEN3_NEXT_MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" - - -class TestQwen3NextPiecewiseCudaGraph(GSM8KMixin, DefaultServerBase): - model = QWEN3_NEXT_MODEL - gsm8k_accuracy_thres = 0.93 - other_args = [ - "--tp", - "4", - ] - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/models/test_qwen_models.py b/test/registered/models/test_qwen_models.py index 45b2aeb540a3..24817c816c4c 100644 --- a/test/registered/models/test_qwen_models.py +++ b/test/registered/models/test_qwen_models.py @@ -5,7 +5,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -35,17 +35,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.78) + self.assertGreater(metrics["score"], 0.78) class 
TestQwen2FP8(CustomTestCase): @@ -66,17 +66,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.78) + self.assertGreater(metrics["score"], 0.78) if __name__ == "__main__": diff --git a/test/registered/models/test_transformers_backend_eval.py b/test/registered/models/test_transformers_backend_eval.py new file mode 100644 index 000000000000..665696cff0fe --- /dev/null +++ b/test/registered/models/test_transformers_backend_eval.py @@ -0,0 +1,43 @@ +"""A small end-to-end eval coverage for the transformers modeling backend.""" + +import unittest +from types import SimpleNamespace + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.server_fixtures.default_fixture import DefaultServerBase + +register_cuda_ci(est_time=180, suite="stage-b-test-small-1-gpu") + + +class TestTransformersBackendEval(DefaultServerBase): + model = "HuggingFaceTB/SmolLM3-3B" + gsm8k_num_questions = 30 + gsm8k_accuracy_thres = 0.5 + gsm8k_parallel = 30 + other_args = [ + "--model-impl", + "transformers", + "--enable-torch-compile", + "--torch-compile-max-bs", + "4", + "--disable-cuda-graph", + ] + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=self.gsm8k_num_questions, + max_new_tokens=512, + parallel=self.gsm8k_parallel, + host="127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreaterEqual(metrics["accuracy"], self.gsm8k_accuracy_thres) + + +if __name__ == "__main__": + unittest.main() diff 
--git a/test/registered/models/test_transformers_models.py b/test/registered/models/test_transformers_models.py index 325db6e2967a..42416ad5481f 100644 --- a/test/registered/models/test_transformers_models.py +++ b/test/registered/models/test_transformers_models.py @@ -10,6 +10,7 @@ from sglang.srt.utils import is_hip, kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.run_eval import run_eval from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner, check_close_model_outputs from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -20,7 +21,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=245, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=450, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=320, suite="stage-b-test-1-gpu-small-amd") @@ -50,26 +51,22 @@ def test_mmlu(self): num_examples=64, num_threads=32, ) - from sglang.test.run_eval import run_eval - metrics = run_eval(args) self.assertGreaterEqual(metrics["score"], self.mmlu_lower_bound) def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - from sglang.test.few_shot_gsm8k import run_eval - metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], self.gsm8k_lower_bound) + self.assertGreater(metrics["score"], self.gsm8k_lower_bound) @unittest.skipIf(is_hip(), "TorchAO int4wo quantization is not supported on AMD GPUs") diff --git a/test/registered/moe/test_cutedsl_moe.py b/test/registered/moe/test_cutedsl_moe.py index 3846334d0f02..76b3b4a8723f 100644 --- a/test/registered/moe/test_cutedsl_moe.py +++ b/test/registered/moe/test_cutedsl_moe.py @@ -12,7 +12,7 @@ from 
sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.test.ci.ci_register import register_cuda_ci -register_cuda_ci(est_time=300, suite="stage-c-test-4-gpu-b200") +register_cuda_ci(est_time=20, suite="stage-c-test-4-gpu-b200") SKIP_TEST = torch.cuda.get_device_capability() < (10, 0) SKIP_REASON = "Nvfp4 Requires compute capability of 10 or above." diff --git a/test/registered/moe/test_glm4_moe_models.py b/test/registered/moe/test_glm4_moe_models.py index 59003f5ba40d..010556b8368c 100644 --- a/test/registered/moe/test_glm4_moe_models.py +++ b/test/registered/moe/test_glm4_moe_models.py @@ -3,7 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -35,17 +35,17 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=100, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=100, + num_threads=128, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.8) + self.assertGreater(metrics["score"], 0.8) if __name__ == "__main__": diff --git a/test/registered/moe/test_moe_ep.py b/test/registered/moe/test_moe_ep.py index 155ce767ca3e..1d27d3c209c3 100644 --- a/test/registered/moe/test_moe_ep.py +++ b/test/registered/moe/test_moe_ep.py @@ -5,20 +5,20 @@ from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, 
popen_launch_server, ) -register_cuda_ci(est_time=140, suite="stage-b-test-2-gpu-large") +register_cuda_ci(est_time=250, suite="stage-b-test-2-gpu-large") class TestEp(CustomTestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, @@ -37,23 +37,25 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_en(self): + def test_gsm8k(self): args = SimpleNamespace( base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.8) + print(metrics) + + self.assertGreater(metrics["score"], 0.60) class TestEpDeepGEMM(CustomTestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, @@ -76,17 +78,19 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_en(self): + def test_gsm8k(self): args = SimpleNamespace( base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.8) + print(metrics) + + self.assertGreater(metrics["score"], 0.60) if __name__ == "__main__": diff --git a/test/registered/moe/test_triton_fused_moe.py b/test/registered/moe/test_triton_fused_moe.py index e20ed609fb19..2255d64bfc46 100644 --- a/test/registered/moe/test_triton_fused_moe.py +++ b/test/registered/moe/test_triton_fused_moe.py @@ -12,7 +12,7 @@ from 
sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import CustomTestCase -register_cuda_ci(est_time=89, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=30, suite="stage-b-test-1-gpu-large") class TestFusedMOE(CustomTestCase): diff --git a/test/registered/metrics/test_metrics.py b/test/registered/observability/test_metrics.py similarity index 76% rename from test/registered/metrics/test_metrics.py rename to test/registered/observability/test_metrics.py index 4aafbabae38b..d1f0f381bee4 100644 --- a/test/registered/metrics/test_metrics.py +++ b/test/registered/observability/test_metrics.py @@ -20,7 +20,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=32, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=95, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=32, suite="stage-b-test-1-gpu-small-amd") _MODEL_NAME = "Qwen/Qwen3-0.6B" @@ -32,6 +32,17 @@ def test_metrics_1gpu(self): self._execute_core( other_args=[], verify_metrics_extra=None, + expect_mfu_metrics=True, + enable_mfu_metrics=True, + ) + + def test_mfu_metrics_gate_disabled(self): + """MFU metrics should not be emitted when the gate is disabled.""" + self._execute_core( + other_args=[], + verify_metrics_extra=None, + expect_mfu_metrics=False, + enable_mfu_metrics=False, ) def test_metrics_2gpu(self): @@ -71,19 +82,30 @@ def _verify_metrics_extra(metrics): self._execute_core( other_args=["--tp", "2", "--dp", "2", "--enable-dp-attention"], verify_metrics_extra=_verify_metrics_extra, + expect_mfu_metrics=True, + enable_mfu_metrics=True, ) - def _execute_core(self, other_args, verify_metrics_extra): + def _execute_core( + self, + other_args, + verify_metrics_extra, + expect_mfu_metrics: bool, + enable_mfu_metrics: bool, + ): with ( envs.SGLANG_ENABLE_METRICS_DP_ATTENTION.override(True), envs.SGLANG_ENABLE_METRICS_DEVICE_TIMER.override(True), envs.SGLANG_TEST_RETRACT.override(True), ): + launch_args = ["--enable-metrics", "--cuda-graph-max-bs", 2, 
*other_args] + if enable_mfu_metrics: + launch_args.insert(1, "--enable-mfu-metrics") process = popen_launch_server( _MODEL_NAME, DEFAULT_URL_FOR_TEST, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-metrics", "--cuda-graph-max-bs", 2, *other_args], + other_args=launch_args, ) try: @@ -125,13 +147,13 @@ def _execute_core(self, other_args, verify_metrics_extra): print(f"metrics_text=\n{metrics_text}") metrics = _parse_prometheus_metrics(metrics_text) - self._verify_metrics_common(metrics_text, metrics) + self._verify_metrics_common(metrics_text, metrics, expect_mfu_metrics) if verify_metrics_extra is not None: verify_metrics_extra(metrics) finally: kill_process_tree(process.pid) - def _verify_metrics_common(self, metrics_text, metrics): + def _verify_metrics_common(self, metrics_text, metrics, expect_mfu_metrics: bool): essential_metrics = [ "sglang:num_running_reqs", "sglang:num_used_tokens", @@ -154,6 +176,13 @@ def _verify_metrics_common(self, metrics_text, metrics): "sglang:routing_key_running_req_count", "sglang:routing_key_all_req_count", ] + mfu_metrics = [ + "sglang:estimated_flops_per_gpu_total", + "sglang:estimated_read_bytes_per_gpu_total", + "sglang:estimated_write_bytes_per_gpu_total", + ] + if expect_mfu_metrics: + essential_metrics.extend(mfu_metrics) for metric in essential_metrics: self.assertIn(metric, metrics_text, f"Missing metric: {metric}") @@ -186,6 +215,39 @@ def _verify_metrics_common(self, metrics_text, metrics): ] _check_metrics_positive(self, metrics, metrics_to_check) + if expect_mfu_metrics: + # Estimated perf metrics may have multiple series (e.g., by rank). Ensure + # that at least one series for this model has a positive accumulated value. 
+ for metric_name in mfu_metrics: + values = [ + sample.value + for sample in metrics.get(metric_name, []) + if sample.labels.get("model_name") == _MODEL_NAME + ] + self.assertTrue( + values, f"{metric_name}: no samples for model {_MODEL_NAME}" + ) + self.assertGreater( + sum(values), + 0, + f"{metric_name}: expected positive total for model {_MODEL_NAME}", + ) + else: + # With only --enable-metrics (without --enable-mfu-metrics), MFU + # counters should not emit positive values. + for metric_name in mfu_metrics: + values = [ + sample.value + for sample in metrics.get(metric_name, []) + if sample.labels.get("model_name") == _MODEL_NAME + ] + if values: + self.assertEqual( + sum(values), + 0, + f"{metric_name}: expected no positive samples with MFU metrics gate disabled", + ) + def _parse_prometheus_metrics(metrics_text: str) -> Dict[str, List[Sample]]: result = {} diff --git a/test/registered/metrics/test_priority_metrics.py b/test/registered/observability/test_priority_metrics.py similarity index 100% rename from test/registered/metrics/test_priority_metrics.py rename to test/registered/observability/test_priority_metrics.py diff --git a/test/registered/observability/test_tracing.py b/test/registered/observability/test_tracing.py new file mode 100644 index 000000000000..995c3196a7b4 --- /dev/null +++ b/test/registered/observability/test_tracing.py @@ -0,0 +1,794 @@ +"""Integration tests for tracing with a lightweight in-process OTLP collector. + +This module implements a minimal OTLP collector that receives traces via gRPC +and stores them in memory for test assertions, eliminating the need for +Docker-based opentelemetry-collector and file I/O. 
+""" + +import os + +# Configure OTLP exporter for faster test execution +# Must be set before importing sglang trace module +os.environ.setdefault("SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", "50") +os.environ.setdefault("SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", "4") + +import json +import logging +import multiprocessing as mp +import threading +import time +import unittest +from concurrent import futures +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Union + +import requests +import zmq + +from sglang import Engine +from sglang.srt.observability.req_time_stats import RequestStage +from sglang.srt.observability.trace import ( + TraceReqContext, + TraceSliceContext, + get_cur_time_ns, + process_tracing_init, + set_global_trace_level, + trace_set_thread_info, +) +from sglang.srt.utils import kill_process_tree +from sglang.srt.utils.network import get_zmq_socket +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +logger = logging.getLogger(__name__) + +# CI registration +register_cuda_ci(est_time=120, suite="stage-b-test-1-gpu-small") + + +# ============================================================================ +# Lightweight OTLP Collector (replaces Docker-based otel-collector) +# ============================================================================ + + +@dataclass +class Span: + """Represents a single span extracted from OTLP trace data.""" + + name: str + trace_id: str = "" + span_id: str = "" + parent_span_id: str = "" + start_time_ns: int = 0 + end_time_ns: int = 0 + attributes: Dict[str, Any] = field(default_factory=dict) + events: List[Dict[str, Any]] = field(default_factory=list) + + +class LightweightOtlpCollector: + """A minimal OTLP collector that stores traces in memory for test assertions. 
+ + This replaces the Docker-based opentelemetry-collector for testing purposes. + It listens on a gRPC port for OTLP trace data and stores spans in memory, + allowing tests to verify specific spans based on trace level. + """ + + def __init__(self, port: int = 4317): + self.port = port + self._server = None + self._thread = None + self._running = False + self._lock = threading.Lock() + # In-memory storage for collected spans + self._spans: List[Span] = [] + self._raw_traces: List[Dict[str, Any]] = [] + + def _try_grpc_server(self): + """Try to start gRPC server with full OTLP protocol.""" + try: + from grpc import server as grpc_server + from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceResponse, + ) + from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( + TraceServiceServicer, + add_TraceServiceServicer_to_server, + ) + + class TraceServicer(TraceServiceServicer): + def __init__(self, collector): + self.collector = collector + + def Export(self, request, context): + self.collector._handle_trace_request(request) + return ExportTraceServiceResponse() + + self._server = grpc_server(futures.ThreadPoolExecutor(max_workers=4)) + add_TraceServiceServicer_to_server(TraceServicer(self), self._server) + self._server.add_insecure_port(f"127.0.0.1:{self.port}") + return True + except ImportError: + logger.warning("Full gRPC OTLP not available, using HTTP fallback") + return False + + def _handle_trace_request(self, request): + """Handle incoming trace request and extract spans to memory.""" + with self._lock: + try: + trace_data = self._protobuf_to_dict(request) + self._raw_traces.append(trace_data) + # Extract spans from the trace data + self._extract_spans(trace_data) + except Exception as e: + logger.error(f"Failed to process trace: {e}") + + def _protobuf_to_dict(self, proto_obj) -> Dict[str, Any]: + """Convert protobuf message to dict.""" + result = {} + for field, value in proto_obj.ListFields(): + if 
field.message_type: + type_name = type(value).__name__ + if "Repeated" in type_name: + result[field.name] = [self._protobuf_to_dict(v) for v in value] + else: + result[field.name] = self._protobuf_to_dict(value) + else: + result[field.name] = value + return result + + def _extract_spans(self, trace_data: Dict[str, Any]): + """Extract Span objects from OTLP trace data structure.""" + resource_spans = trace_data.get("resource_spans", []) + for rs in resource_spans: + scope_spans = rs.get("scope_spans", []) + for ss in scope_spans: + spans = ss.get("spans", []) + for span_data in spans: + span = Span( + name=span_data.get("name", ""), + trace_id=span_data.get("trace_id", ""), + span_id=span_data.get("span_id", ""), + parent_span_id=span_data.get("parent_span_id", ""), + start_time_ns=span_data.get("start_time_unix_nano", 0), + end_time_ns=span_data.get("end_time_unix_nano", 0), + attributes=span_data.get("attributes", {}), + events=span_data.get("events", []), + ) + self._spans.append(span) + + def _http_server_loop(self): + """Fallback HTTP server for OTLP HTTP protocol.""" + from http.server import BaseHTTPRequestHandler, HTTPServer + + class OTLPHandler(BaseHTTPRequestHandler): + def __init__(self, request, client_address, server): + self.collector = server.collector + super().__init__(request, client_address, server) + + def do_POST(self): + if self.path in ["/v1/traces", "/v1/traces/"]: + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + try: + data = json.loads(body) + with self.collector._lock: + self.collector._raw_traces.append(data) + self.collector._extract_spans_http(data) + self.send_response(200) + self.end_headers() + except Exception as e: + logger.error(f"HTTP trace handling error: {e}") + self.send_response(500) + self.end_headers() + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + pass # Suppress HTTP server logs + + class 
CollectorHTTPServer(HTTPServer): + def __init__(self, server_address, collector): + self.collector = collector + super().__init__( + server_address, + lambda r, a, s: OTLPHandler(r, a, s), + ) + + server = CollectorHTTPServer(("127.0.0.1", 4318), self) + server.timeout = 0.5 + while self._running: + server.handle_request() + + def _extract_spans_http(self, data: Dict[str, Any]): + """Extract Span objects from OTLP HTTP JSON format.""" + resource_spans = data.get("resourceSpans", []) + for rs in resource_spans: + scope_spans = rs.get("scopeSpans", []) + for ss in scope_spans: + spans = ss.get("spans", []) + for span_data in spans: + span = Span( + name=span_data.get("name", ""), + trace_id=span_data.get("traceId", ""), + span_id=span_data.get("spanId", ""), + parent_span_id=span_data.get("parentSpanId", ""), + start_time_ns=span_data.get("startTimeUnixNano", 0), + end_time_ns=span_data.get("endTimeUnixNano", 0), + attributes=span_data.get("attributes", {}), + events=span_data.get("events", []), + ) + self._spans.append(span) + + def start(self): + """Start the collector server.""" + self._running = True + self._spans.clear() + self._raw_traces.clear() + if self._try_grpc_server(): + self._server.start() + logger.info(f"OTLP gRPC collector started on port {self.port}") + else: + # Fallback to HTTP server in a thread + self._thread = threading.Thread(target=self._http_server_loop, daemon=True) + self._thread.start() + logger.info("OTLP HTTP collector started on port 4318") + + def stop(self): + """Stop the collector server.""" + self._running = False + if self._server: + self._server.stop(1) + self._server = None + logger.info("OTLP collector stopped") + + # ======================================================================== + # Public API for test assertions + # ======================================================================== + + def get_spans(self) -> List[Span]: + """Get all collected spans.""" + with self._lock: + return list(self._spans) + + def 
get_span_names(self) -> Set[str]: + """Get all unique span names.""" + with self._lock: + return {s.name for s in self._spans} + + def has_span(self, name: str) -> bool: + """Check if a span with the given name exists.""" + return name in self.get_span_names() + + def has_any_span(self, names: List[str]) -> bool: + """Check if any of the given span names exist.""" + span_names = self.get_span_names() + return any(name in span_names for name in names) + + def has_all_spans(self, names: List[str]) -> bool: + """Check if all of the given span names exist.""" + span_names = self.get_span_names() + return all(name in span_names for name in names) + + def get_spans_by_name(self, name: str) -> List[Span]: + """Get all spans with the given name.""" + with self._lock: + return [s for s in self._spans if s.name == name] + + def count_spans(self) -> int: + """Get total count of collected spans.""" + with self._lock: + return len(self._spans) + + def clear(self): + """Clear all collected spans.""" + with self._lock: + self._spans.clear() + self._raw_traces.clear() + + +# ============================================================================ +# Test Helper Functions +# ============================================================================ + + +def _get_span_names_by_level(level: int) -> List[str]: + """Get expected span names for a given trace level. 
+ + Based on RequestStage definitions in req_time_stats.py: + - Each RequestStage has a level attribute indicating minimum trace level required + - Spans with level <= current trace level will be exported + """ + span_names = [] + # RequestStage is a class with class attributes that are RequestStageConfig instances + for attr_name in dir(RequestStage): + if attr_name.startswith("_"): + continue + attr = getattr(RequestStage, attr_name) + # Check if it's a RequestStageConfig (has stage_name and level attributes) + if hasattr(attr, "stage_name") and hasattr(attr, "level"): + if attr.level <= level and attr.stage_name: + span_names.append(attr.stage_name) + return span_names + + +# Pre-computed span names by level for efficiency +SPAN_NAMES_LEVEL_1 = _get_span_names_by_level(1) +SPAN_NAMES_LEVEL_2 = _get_span_names_by_level(2) +SPAN_NAMES_LEVEL_3 = _get_span_names_by_level(3) + +# Common span names expected in typical inference requests +# Level 1: Basic request lifecycle +EXPECTED_SPANS_LEVEL_1 = [ + RequestStage.PREFILL_FORWARD.stage_name, + RequestStage.DECODE_FORWARD.stage_name, +] + +# Level 2: More detailed including dispatch +EXPECTED_SPANS_LEVEL_2 = EXPECTED_SPANS_LEVEL_1 + [ + RequestStage.REQUEST_PROCESS.stage_name, +] + +# Level 3: Most detailed including internal operations +EXPECTED_SPANS_LEVEL_3 = EXPECTED_SPANS_LEVEL_2 + [ + RequestStage.DECODE_LOOP.stage_name, +] + + +@dataclass +class Req: + rid: int + req_context: Optional[Union[TraceReqContext]] = None + + +def _subprocess_worker(): + """Worker function for subprocess trace context propagation test. + Must be at module level for pickle compatibility with spawn. 
+ """ + process_tracing_init("127.0.0.1:4317", "test") + trace_set_thread_info("Sub Process") + + context = zmq.Context(2) + recv_from_main = get_zmq_socket(context, zmq.PULL, "ipc:///tmp/zmq_test.ipc", True) + + try: + req = recv_from_main.recv_pyobj() + req.req_context.rebuild_thread_context() + req.req_context.trace_slice_start("work", level=1) + time.sleep(0.2) + req.req_context.trace_slice_end("work", level=1, thread_finish_flag=True) + finally: + recv_from_main.close() + context.term() + + +# ============================================================================ +# Test Cases +# ============================================================================ + + +class TestTracePackage(CustomTestCase): + """Unit tests for tracing package API without server/engine.""" + + def setUp(self): + self.collector = None + + def tearDown(self): + if self.collector: + self.collector.stop() + self.collector = None + + def _start_collector(self): + """Start the lightweight OTLP collector.""" + self.collector = LightweightOtlpCollector() + self.collector.start() + time.sleep(0.2) + + def test_slice_simple(self): + """Unit test: simple slice trace API.""" + self._start_collector() + + try: + process_tracing_init("127.0.0.1:4317", "test") + trace_set_thread_info("Test") + set_global_trace_level(3) + req_context = TraceReqContext(0) + req_context.trace_req_start() + req_context.trace_slice_start("test slice", level=1) + time.sleep(0.1) + req_context.trace_slice_end("test slice", level=1) + req_context.trace_req_finish() + + time.sleep(0.3) + + self.assertTrue( + self.collector.has_span("test slice"), + f"Expected span 'test slice', got {self.collector.get_span_names()}", + ) + finally: + pass + + def test_slice_complex(self): + """Unit test: complex slice trace with events.""" + self._start_collector() + + try: + process_tracing_init("127.0.0.1:4317", "test") + trace_set_thread_info("Test") + set_global_trace_level(3) + req_context = TraceReqContext(0) + 
req_context.trace_req_start() + + t1 = get_cur_time_ns() + time.sleep(0.1) + req_context.trace_event("event test", 1) + t2 = get_cur_time_ns() + time.sleep(0.1) + t3 = get_cur_time_ns() + + slice1 = TraceSliceContext("slice A", t1, t2) + slice2 = TraceSliceContext("slice B", t2, t3) + req_context.trace_slice(slice1) + req_context.trace_slice(slice2, thread_finish_flag=True) + req_context.trace_req_finish() + + time.sleep(0.3) + + self.assertTrue( + self.collector.has_all_spans(["slice A", "slice B"]), + f"Expected spans 'slice A' and 'slice B', got {self.collector.get_span_names()}", + ) + finally: + pass + + def test_context_propagate(self): + """Unit test: trace context propagation across processes via ZMQ.""" + self._start_collector() + + ctx = mp.get_context("spawn") + + context = zmq.Context(2) + send_to_subproc = get_zmq_socket( + context, zmq.PUSH, "ipc:///tmp/zmq_test.ipc", False + ) + + try: + process_tracing_init("127.0.0.1:4317", "test") + trace_set_thread_info("Main Process") + + subproc = ctx.Process(target=_subprocess_worker) + subproc.start() + + time.sleep(0.3) + + req = Req(rid=0) + req.req_context = TraceReqContext(0) + req.req_context.trace_req_start() + req.req_context.trace_slice_start("dispatch", level=1) + time.sleep(0.2) + send_to_subproc.send_pyobj(req) + req.req_context.trace_slice_end("dispatch", level=1) + + subproc.join() + req.req_context.trace_req_finish() + + time.sleep(0.5) + + self.assertTrue( + self.collector.has_all_spans(["dispatch", "work"]), + f"Expected spans 'dispatch' and 'work', got {self.collector.get_span_names()}", + ) + finally: + send_to_subproc.close() + context.term() + + +class TestTraceServer(CustomTestCase): + """Integration tests for tracing with server - starts server once for all tests.""" + + @classmethod + def setUpClass(cls): + """Start collector and server once for all tests.""" + cls.collector = LightweightOtlpCollector() + cls.collector.start() + time.sleep(0.2) + + cls.process = popen_launch_server( + 
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-trace", + "--otlp-traces-endpoint", + "127.0.0.1:4317", + ], + ) + + response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate") + assert response.status_code == 200 + + # Wait for warmup spans to be exported + cls.collector.clear() + + @classmethod + def tearDownClass(cls): + if cls.process: + kill_process_tree(cls.process.pid) + if cls.collector: + cls.collector.stop() + + def setUp(self): + """Wait for spans to be drained before each test.""" + max_wait_seconds = 10 + check_interval = 0.2 + elapsed = 0 + consecutive_zero_count = 0 + required_consecutive_zeros = 3 + + while elapsed < max_wait_seconds: + span_count = self.collector.count_spans() + if span_count == 0: + consecutive_zero_count += 1 + if consecutive_zero_count >= required_consecutive_zeros: + break + else: + consecutive_zero_count = 0 + self.collector.clear() + time.sleep(check_interval) + elapsed += check_interval + else: + raise RuntimeError( + f"Timeout waiting for spans to drain after {max_wait_seconds}s. 
" + f"Remaining spans: {self.collector.count_spans()}" + ) + + def _send_request_and_wait( + self, text, max_new_tokens=32, stream=True, trace_level=None + ): + """Helper to send a request and wait for spans.""" + if trace_level is not None: + response = requests.get( + f"{DEFAULT_URL_FOR_TEST}/set_trace_level?level={trace_level}" + ) + self.assertEqual(response.status_code, 200) + self.collector.clear() + + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": text, + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "stream": stream, + }, + stream=stream, + ) + if stream: + for _ in response.iter_lines(decode_unicode=False): + pass + else: + self.assertEqual(response.status_code, 200) + + time.sleep(1) + + def test_trace_level_0(self): + """Test trace level 0 does not export any spans.""" + self._send_request_and_wait("Hello world", max_new_tokens=5, trace_level=0) + self.assertEqual( + self.collector.count_spans(), + 0, + f"Spans collected but expected none: {sorted(self.collector.get_span_names())}", + ) + + def test_trace_level_1(self): + """Test trace level 1 exports basic request lifecycle spans.""" + self._send_request_and_wait("The capital of France is", trace_level=1) + + self.assertGreater( + self.collector.count_spans(), + 0, + "No spans collected but expected some", + ) + + span_names = self.collector.get_span_names() + matched = [name for name in EXPECTED_SPANS_LEVEL_1 if name in span_names] + self.assertGreater( + len(matched), + 0, + f"No expected spans found. Expected any of {EXPECTED_SPANS_LEVEL_1}, " + f"got {sorted(span_names)}", + ) + + def test_trace_level_2(self): + """Test trace level 2 exports more detailed spans.""" + self._send_request_and_wait("What is AI?", trace_level=2) + + span_names = self.collector.get_span_names() + matched = [name for name in EXPECTED_SPANS_LEVEL_2 if name in span_names] + self.assertGreater( + len(matched), + 0, + f"No expected spans found. 
Expected any of {EXPECTED_SPANS_LEVEL_2}, " + f"got {sorted(span_names)}", + ) + + def test_trace_level_3(self): + """Test trace level 3 exports most detailed spans.""" + self._send_request_and_wait("Explain quantum computing", trace_level=3) + + span_names = self.collector.get_span_names() + matched = [name for name in EXPECTED_SPANS_LEVEL_3 if name in span_names] + self.assertGreater( + len(matched), + 0, + f"No expected spans found. Expected any of {EXPECTED_SPANS_LEVEL_3}, " + f"got {sorted(span_names)}", + ) + + def test_batch_request(self): + """Test tracing with batch requests (multiple prompts in one request).""" + response = requests.get(f"{DEFAULT_URL_FOR_TEST}/set_trace_level?level=1") + self.assertEqual(response.status_code, 200) + self.collector.clear() + + batch_size = 4 + prompts = ["The capital of France is"] * batch_size + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": prompts, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 10, + }, + "stream": False, + }, + ) + self.assertEqual(response.status_code, 200) + + time.sleep(0.5) + + self.assertGreater( + self.collector.count_spans(), + 0, + "No spans collected from batch request", + ) + + all_spans = self.collector.get_spans() + request_spans = [ + s for s in all_spans if s.name == RequestStage.PREFILL_FORWARD.stage_name + ] + self.assertEqual( + len(request_spans), + batch_size, + f"Expected {batch_size} prefill_forward spans, got {len(request_spans)}", + ) + + def test_parallel_sample(self): + """Test tracing with parallel sampling (n > 1 in sampling_params).""" + response = requests.get(f"{DEFAULT_URL_FOR_TEST}/set_trace_level?level=1") + self.assertEqual(response.status_code, 200) + self.collector.clear() + + # parallel_sample_num is controlled by 'n' in sampling_params + parallel_num = 4 + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0.5, # 
Need non-zero temp for parallel sampling + "max_new_tokens": 10, + "n": parallel_num, + }, + "stream": False, + }, + ) + self.assertEqual(response.status_code, 200) + + time.sleep(0.5) + + self.assertGreater( + self.collector.count_spans(), + 0, + "No spans collected from parallel sample request", + ) + + # With parallel sampling, we expect prefill spans for each parallel sample + all_spans = self.collector.get_spans() + request_spans = [ + s for s in all_spans if s.name == RequestStage.PREFILL_FORWARD.stage_name + ] + self.assertGreaterEqual( + len(request_spans), + 1, + f"Expected at least 1 prefill_forward span, got {len(request_spans)}", + ) + + +class TestTraceEngine(CustomTestCase): + """Integration tests for tracing with Engine API - each test creates its own engine.""" + + def setUp(self): + self.collector = None + + def tearDown(self): + if self.collector: + self.collector.stop() + self.collector = None + + def _start_collector(self): + """Start the lightweight OTLP collector.""" + self.collector = LightweightOtlpCollector() + self.collector.start() + time.sleep(0.2) + + def test_trace_engine_enable(self): + """Test tracing with Engine API.""" + self._start_collector() + + prompt = "Today is a sunny day and I like" + model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + sampling_params = {"temperature": 0, "max_new_tokens": 8} + + engine = Engine( + model_path=model_path, + random_seed=42, + enable_trace=True, + otlp_traces_endpoint="localhost:4317", + ) + + try: + engine.generate(prompt, sampling_params) + time.sleep(0.5) + + self.assertGreater( + self.collector.count_spans(), + 0, + "No spans collected from Engine.generate", + ) + self.assertTrue( + self.collector.has_any_span([RequestStage.PREFILL_FORWARD.stage_name]), + f"Expected prefill_forward span, got {self.collector.get_span_names()}", + ) + finally: + engine.shutdown() + + def test_trace_engine_encode(self): + """Test tracing with Engine encode API.""" + self._start_collector() + + prompt = "Today is 
a sunny day and I like" + model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + engine = Engine( + model_path=model_path, + random_seed=42, + enable_trace=True, + otlp_traces_endpoint="localhost:4317", + is_embedding=True, + ) + + try: + engine.encode(prompt) + time.sleep(0.5) + + self.assertGreater( + self.collector.count_spans(), + 0, + "No spans collected from Engine.encode", + ) + finally: + engine.shutdown() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/observability/test_tracing_disaggregation.py b/test/registered/observability/test_tracing_disaggregation.py new file mode 100644 index 000000000000..df15ae10180c --- /dev/null +++ b/test/registered/observability/test_tracing_disaggregation.py @@ -0,0 +1,237 @@ +"""Test tracing in PD disaggregation mode.""" + +import os + +# Configure OTLP exporter for faster test execution +# Must be set before importing sglang trace module +os.environ.setdefault("SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", "50") +os.environ.setdefault("SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", "4") + +import logging +import shlex +import time +import unittest +from urllib.parse import urlparse + +import requests + +# Import the lightweight collector from the main tracing test module +from test_tracing import LightweightOtlpCollector + +from sglang.srt.observability.req_time_stats import RequestStage +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.server_fixtures.disaggregation_fixture import get_rdma_devices_args +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_pd_server, + popen_with_error_check, +) +from sglang.utils import wait_for_http_ready + +logger = logging.getLogger(__name__) + +# CI registration - PD disaggregation requires 2 GPUs +register_cuda_ci(est_time=45, suite="stage-b-test-2-gpu-large") + + +class 
TestTraceDisaggregation(CustomTestCase): + """Test tracing in PD disaggregation mode.""" + + @classmethod + def setUpClass(cls): + # Initialize collector first + cls.collector = LightweightOtlpCollector() + cls.collector.start() + time.sleep(0.2) + + # Setup PD disaggregation server addresses + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.bootstrap_port = f"{int(base_port) + 500}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + cls.process_lb = None + cls.process_decode = None + cls.process_prefill = None + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + + # Config transfer backend + cls.transfer_backend = ["--disaggregation-transfer-backend", "mooncake"] + cls.rdma_devices = ["--disaggregation-ib-device", get_rdma_devices_args()] + + # Start prefill server with trace enabled + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--disaggregation-bootstrap-port", + cls.bootstrap_port, + "--tp", + "1", + "--enable-trace", + "--otlp-traces-endpoint", + "localhost:4317", + ] + prefill_args += cls.transfer_backend + cls.rdma_devices + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + # Start decode server with trace enabled + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--disaggregation-bootstrap-port", + cls.bootstrap_port, + "--tp", + "1", + "--base-gpu-id", + "1", + "--enable-trace", + "--otlp-traces-endpoint", + "localhost:4317", + ] + decode_args += cls.transfer_backend + cls.rdma_devices + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + 
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + # Wait for servers to be ready + wait_for_http_ready( + url=cls.prefill_url + "/health", + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + process=cls.process_prefill, + ) + wait_for_http_ready( + url=cls.decode_url + "/health", + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + process=cls.process_decode, + ) + + # Start load balancer + lb_command = [ + "python3", + "-m", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", + "--prefill", + cls.prefill_url, + "--decode", + cls.decode_url, + "--host", + cls.base_host, + "--port", + cls.lb_port, + ] + print("Starting load balancer:", shlex.join(lb_command)) + cls.process_lb = popen_with_error_check(lb_command) + wait_for_http_ready(url=cls.lb_url + "/health", process=cls.process_lb) + + # Wait for warmup spans and clear + time.sleep(1) + cls.collector.clear() + + @classmethod + def tearDownClass(cls): + for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: + if process: + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process {process.pid}: {e}") + if cls.collector: + cls.collector.stop() + time.sleep(5) + + def setUp(self): + """Wait for spans to be drained before each test.""" + max_wait_seconds = 10 + check_interval = 0.2 + elapsed = 0 + consecutive_zero_count = 0 + required_consecutive_zeros = 3 + + while elapsed < max_wait_seconds: + span_count = self.collector.count_spans() + if span_count == 0: + consecutive_zero_count += 1 + if consecutive_zero_count >= required_consecutive_zeros: + break + else: + consecutive_zero_count = 0 + self.collector.clear() + time.sleep(check_interval) + elapsed += check_interval + else: + raise RuntimeError( + f"Timeout waiting for spans to drain after {max_wait_seconds}s. 
" + f"Remaining spans: {self.collector.count_spans()}" + ) + + def test_disaggregation_transfer_spans(self): + """Test that disaggregation produces PREFILL_TRANSFER_KV_CACHE and DECODE_TRANSFERRED spans.""" + # Set trace level + response = requests.get(f"{self.prefill_url}/set_trace_level?level=1") + self.assertEqual(response.status_code, 200) + response = requests.get(f"{self.decode_url}/set_trace_level?level=1") + self.assertEqual(response.status_code, 200) + self.collector.clear() + + # Send a request through load balancer + response = requests.post( + f"{self.lb_url}/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 10, + }, + "stream": False, + }, + ) + self.assertEqual(response.status_code, 200) + + # Wait for async export + time.sleep(1) + + # Verify spans were collected + self.assertGreater( + self.collector.count_spans(), + 0, + "No spans collected from disaggregation request", + ) + + # Verify disaggregation-specific spans exist + span_names = self.collector.get_span_names() + + # Check for transfer-related spans + self.assertTrue( + self.collector.has_any_span( + [ + RequestStage.PREFILL_TRANSFER_KV_CACHE.stage_name, + RequestStage.DECODE_TRANSFERRED.stage_name, + ] + ), + f"Expected disaggregation transfer spans, got {sorted(span_names)}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/openai_server/features/test_openai_server_ebnf.py b/test/registered/openai_server/features/test_openai_server_ebnf.py index 05cf9f1021e8..2ec7d2dcf9c9 100644 --- a/test/registered/openai_server/features/test_openai_server_ebnf.py +++ b/test/registered/openai_server/features/test_openai_server_ebnf.py @@ -13,7 +13,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=7, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=55, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=20, suite="stage-b-test-1-gpu-small-amd") diff --git 
a/test/registered/openai_server/function_call/test_tool_choice.py b/test/registered/openai_server/function_call/test_tool_choice.py index 9b0fad76dca0..80d4cd4a10e1 100644 --- a/test/registered/openai_server/function_call/test_tool_choice.py +++ b/test/registered/openai_server/function_call/test_tool_choice.py @@ -22,7 +22,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=120, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=250, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=258, suite="stage-b-test-1-gpu-small-amd") @@ -915,14 +915,6 @@ def setUpClass(cls): cls.base_url += "/v1" cls.tokenizer = get_tokenizer(cls.model) - @unittest.skip("maxItems:1 bug causes whitespace stall") - def test_tool_choice_required_non_streaming(self): - pass - - @unittest.skip("maxItems:1 bug causes whitespace stall") - def test_tool_choice_specific_function_non_streaming(self): - pass - if __name__ == "__main__": unittest.main() diff --git a/test/registered/openai_server/validation/test_openai_server_ignore_eos.py b/test/registered/openai_server/validation/test_openai_server_ignore_eos.py index 9014466c089b..2f27699ee14b 100644 --- a/test/registered/openai_server/validation/test_openai_server_ignore_eos.py +++ b/test/registered/openai_server/validation/test_openai_server_ignore_eos.py @@ -11,7 +11,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=6, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=55, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=47, suite="stage-b-test-1-gpu-small-amd") diff --git a/test/registered/openai_server/validation/test_request_length_validation.py b/test/registered/openai_server/validation/test_request_length_validation.py index 0e9fb0a3e645..0699d3917177 100644 --- a/test/registered/openai_server/validation/test_request_length_validation.py +++ b/test/registered/openai_server/validation/test_request_length_validation.py @@ -67,6 +67,23 @@ def 
test_input_length_longer_than_maximum_allowed_length(self): self.assertIn("is longer than the model's context length", str(cm.exception)) + def test_input_length_longer_than_context_length_streaming(self): + client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") + + long_text = "hello " * 1200 + + with self.assertRaises(openai.BadRequestError) as cm: + client.chat.completions.create( + model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + messages=[ + {"role": "user", "content": long_text}, + ], + temperature=0, + stream=True, + ) + + self.assertIn("is longer than the model's context length", str(cm.exception)) + def test_max_tokens_validation(self): client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") diff --git a/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py b/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py index aa1afba03b88..e38b59f5b86b 100644 --- a/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py +++ b/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py @@ -126,8 +126,25 @@ def test_embedding(self): engine.shutdown() self.assertGreater(len(out_without_pcg), 0) + t_out = torch.tensor(out) + t_out_without_pcg = torch.tensor(out_without_pcg) + max_abs_diff = (t_out - t_out_without_pcg).abs().max().item() + max_rel_diff = ( + ((t_out - t_out_without_pcg).abs() / (t_out_without_pcg.abs() + 1e-8)) + .max() + .item() + ) + print( + f"PCG embedding diff: max_abs={max_abs_diff:.6f}, max_rel={max_rel_diff:.6f}" + ) self.assertTrue( - torch.allclose(torch.tensor(out), torch.tensor(out_without_pcg)) + torch.allclose( + t_out, + t_out_without_pcg, + atol=1e-2, + rtol=1e-2, + ), + f"Piecewise CUDA graph embedding mismatch: max_abs_diff={max_abs_diff}, max_rel_diff={max_rel_diff}", ) diff --git a/test/registered/quant/test_deepseek_v32_fp4_4gpu.py b/test/registered/quant/test_deepseek_v32_fp4_4gpu.py index 
4bdcd1367092..19b0a1dbac02 100644 --- a/test/registered/quant/test_deepseek_v32_fp4_4gpu.py +++ b/test/registered/quant/test_deepseek_v32_fp4_4gpu.py @@ -3,7 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, @@ -56,23 +56,24 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=500, num_shots=20, - data_path=None, - num_questions=1319, - parallel=1319, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( - f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.93) def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) @@ -123,23 +124,24 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=500, num_shots=20, - data_path=None, - num_questions=1319, - parallel=1319, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if 
is_in_ci(): write_github_step_summary( - f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.93) def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) diff --git a/test/registered/quant/test_deepseek_v32_fp4_mtp_4gpu.py b/test/registered/quant/test_deepseek_v32_fp4_mtp_4gpu.py index 3a99f6ad9f12..32eeb9e9bbc1 100644 --- a/test/registered/quant/test_deepseek_v32_fp4_mtp_4gpu.py +++ b/test/registered/quant/test_deepseek_v32_fp4_mtp_4gpu.py @@ -6,7 +6,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, @@ -72,18 +72,19 @@ def test_a_gsm8k( requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=500, num_shots=20, - data_path=None, - num_questions=500, - parallel=500, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -92,10 +93,10 @@ def test_a_gsm8k( if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v32 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' 
f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(metrics["score"], 0.93) self.assertGreater(avg_spec_accept_length, 2.7) def test_bs_1_speed(self): @@ -162,18 +163,19 @@ def test_a_gsm8k( requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=500, + num_threads=500, num_shots=20, - data_path=None, - num_questions=500, - parallel=500, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -182,10 +184,10 @@ def test_a_gsm8k( if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v32 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(metrics["score"], 0.93) self.assertGreater(avg_spec_accept_length, 2.7) def test_bs_1_speed(self): diff --git a/test/registered/quant/test_deepseek_v3_fp4_4gpu.py b/test/registered/quant/test_deepseek_v3_fp4_4gpu.py index 3658eec44441..f54149952d58 100644 --- a/test/registered/quant/test_deepseek_v3_fp4_4gpu.py +++ b/test/registered/quant/test_deepseek_v3_fp4_4gpu.py @@ -4,7 +4,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, @@ -54,23 +54,24 @@ def test_a_gsm8k( self, ): # 
Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1319, + num_threads=1319, num_shots=8, - data_path=None, - num_questions=1319, - parallel=1319, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( - f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) @@ -124,23 +125,24 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1319, + num_threads=1319, num_shots=8, - data_path=None, - num_questions=1319, - parallel=1319, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v3-fp4-cutlass-moe)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) class TestDeepseekV3FP4SymmetricMemory(CustomTestCase): @@ -178,23 +180,24 @@ def test_a_gsm8k( self, ): # Append an "a" to make this test run first (alphabetically) to warm up the server args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + 
max_tokens=512, + num_examples=1319, + num_threads=1319, num_shots=8, - data_path=None, - num_questions=1319, - parallel=1319, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") if is_in_ci(): write_github_step_summary( - f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["accuracy"]=:.3f}\n' + f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["score"]=:.3f}\n' ) - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["score"], 0.93) if __name__ == "__main__": diff --git a/test/registered/quant/test_fp8_blockwise_gemm.py b/test/registered/quant/test_fp8_blockwise_gemm.py index af7600cb5380..30a04a1fc2c4 100644 --- a/test/registered/quant/test_fp8_blockwise_gemm.py +++ b/test/registered/quant/test_fp8_blockwise_gemm.py @@ -4,7 +4,7 @@ from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -15,7 +15,7 @@ register_cuda_ci(est_time=420, suite="stage-c-test-4-gpu-b200") MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507-FP8" -BF16_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507" +MXFP8_MODEL_PATH = "zianglih/Qwen3-4B-Instruct-2507-MXFP8" class FP8BlockwiseGemmBase: @@ -46,18 +46,19 @@ def tearDownClass(cls): def test_gsm8k(self): parsed_url = urlparse(self.base_url) args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1319, + num_threads=200, num_shots=8, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=200, - host=f"{parsed_url.scheme}://{parsed_url.hostname}", - port=parsed_url.port, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = 
run_eval(args) print(metrics) - self.assertGreaterEqual(metrics["accuracy"], 0.8) + self.assertGreaterEqual(metrics["score"], 0.8) class MXFP8GemmBase: @@ -67,12 +68,10 @@ class MXFP8GemmBase: def setUpClass(cls): if cls.backend is None: raise NotImplementedError("Subclass must set 'backend' attribute") - cls.model = try_cached_model(BF16_MODEL_PATH) + cls.model = try_cached_model(MXFP8_MODEL_PATH) cls.base_url = DEFAULT_URL_FOR_TEST other_args = [ "--trust-remote-code", - "--quantization", - "mxfp8", "--fp8-gemm-backend", cls.backend, ] @@ -90,18 +89,19 @@ def tearDownClass(cls): def test_gsm8k(self): parsed_url = urlparse(self.base_url) args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1319, + num_threads=200, num_shots=8, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=200, - host=f"{parsed_url.scheme}://{parsed_url.hostname}", - port=parsed_url.port, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreaterEqual(metrics["accuracy"], 0.8) + self.assertGreaterEqual(metrics["score"], 0.8) class TestFP8BlockwiseGemmTriton(FP8BlockwiseGemmBase, unittest.TestCase): @@ -122,6 +122,7 @@ class TestFP8BlockwiseGemmFlashinferDeepGemm(FP8BlockwiseGemmBase, unittest.Test backend = "flashinfer_deepgemm" +@unittest.skip("Currently PCG capture takes too long to complete, disable until fixed") @unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") class TestMXFP8GemmTriton(MXFP8GemmBase, unittest.TestCase): backend = "triton" @@ -132,5 +133,10 @@ class TestMXFP8GemmFlashinferTrtllm(MXFP8GemmBase, unittest.TestCase): backend = "flashinfer_trtllm" +@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") +class TestMXFP8GemmFlashinferCutlass(MXFP8GemmBase, unittest.TestCase): + backend = "flashinfer_cutlass" + + if __name__ == "__main__": unittest.main() diff --git 
a/test/registered/quant/test_fp8_gemm_sm120.py b/test/registered/quant/test_fp8_gemm_sm120.py new file mode 100644 index 000000000000..f695851de663 --- /dev/null +++ b/test/registered/quant/test_fp8_gemm_sm120.py @@ -0,0 +1,86 @@ +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import get_device_sm, kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, + try_cached_model, +) + +register_cuda_ci(est_time=120, suite="stage-b-test-small-1-gpu") + +PERTENSOR_MODEL_PATH = "nvidia/Llama-3.1-8B-Instruct-FP8" +BLOCKWISE_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507-FP8" + + +class FP8GemmSM120Base: + model_path = None + backend = None + quantization = None + + @classmethod + def setUpClass(cls): + if cls.backend is None: + raise NotImplementedError("Subclass must set 'backend' attribute") + cls.model = try_cached_model(cls.model_path) + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--fp8-gemm-backend", + cls.backend, + "--disable-piecewise-cuda-graph", + ] + if cls.quantization: + other_args += ["--quantization", cls.quantization] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process"): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + parsed_url = urlparse(self.base_url) + args = SimpleNamespace( + num_shots=self.num_shots, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=200, + host=parsed_url.hostname, + port=parsed_url.port, + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + self.assertGreaterEqual(metrics["accuracy"], self.accuracy_threshold) + + 
+@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") +class TestFP8PerTensorGemmSM120Auto(FP8GemmSM120Base, unittest.TestCase): + model_path = PERTENSOR_MODEL_PATH + backend = "auto" + quantization = "modelopt_fp8" + num_shots = 5 + accuracy_threshold = 0.73 + + +@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") +class TestFP8BlockwiseGemmSM120Auto(FP8GemmSM120Base, unittest.TestCase): + model_path = BLOCKWISE_MODEL_PATH + backend = "auto" + num_shots = 8 + accuracy_threshold = 0.87 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/quant/test_fp8_kernel.py b/test/registered/quant/test_fp8_kernel.py index 57a6362358f2..ea5c60689105 100644 --- a/test/registered/quant/test_fp8_kernel.py +++ b/test/registered/quant/test_fp8_kernel.py @@ -9,7 +9,7 @@ from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import CustomTestCase -register_cuda_ci(est_time=132, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=15, suite="stage-b-test-1-gpu-large") class TestFP8Base(CustomTestCase): diff --git a/test/registered/quant/test_fp8kv_triton.py b/test/registered/quant/test_fp8kv_triton.py index 3158e8e24513..df9573e22891 100644 --- a/test/registered/quant/test_fp8kv_triton.py +++ b/test/registered/quant/test_fp8kv_triton.py @@ -4,7 +4,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -12,7 +12,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=520, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=80, suite="stage-b-test-1-gpu-large") class TestFP8KVCacheTritonBackend(CustomTestCase): @@ -41,17 +41,17 @@ def tearDownClass(cls): def test_gsm8k(self): parsed_url = urlparse(self.base_url) args = 
SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=200, - host=f"{parsed_url.scheme}://{parsed_url.hostname}", - port=parsed_url.port, + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=200, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.70) + self.assertGreater(metrics["score"], 0.70) if __name__ == "__main__": diff --git a/test/registered/quant/test_int4fp8_moe.py b/test/registered/quant/test_int4fp8_moe.py index f5a8f6dca26d..c46c50447d4e 100644 --- a/test/registered/quant/test_int4fp8_moe.py +++ b/test/registered/quant/test_int4fp8_moe.py @@ -2,7 +2,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, CustomTestCase, @@ -45,14 +45,15 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1400, + num_threads=128, num_shots=8, - data_path=None, - num_questions=1400, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.56) + self.assertGreater(metrics["score"], 0.56) diff --git a/test/registered/quant/test_int8_kernel.py b/test/registered/quant/test_int8_kernel.py index c15de1d4a2f8..0a5d3001a826 100644 --- a/test/registered/quant/test_int8_kernel.py +++ b/test/registered/quant/test_int8_kernel.py @@ -11,7 +11,7 @@ from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import CustomTestCase -register_cuda_ci(est_time=8, suite="stage-b-test-1-gpu-small") 
+register_cuda_ci(est_time=16, suite="stage-b-test-1-gpu-small") def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): diff --git a/test/registered/quant/test_modelopt_fp8.py b/test/registered/quant/test_modelopt_fp8.py index b13adbe264ed..a65e2e20c909 100644 --- a/test/registered/quant/test_modelopt_fp8.py +++ b/test/registered/quant/test_modelopt_fp8.py @@ -4,7 +4,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -35,17 +35,17 @@ def tearDownClass(cls): def test_gsm8k(self): parsed_url = urlparse(self.base_url) args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=200, - host=f"{parsed_url.scheme}://{parsed_url.hostname}", - port=parsed_url.port, + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=200, ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.70) + self.assertGreater(metrics["score"], 0.70) if __name__ == "__main__": diff --git a/test/registered/quant/test_nvfp4_gemm.py b/test/registered/quant/test_nvfp4_gemm.py index 1a94b6b48e7e..f784973c9203 100644 --- a/test/registered/quant/test_nvfp4_gemm.py +++ b/test/registered/quant/test_nvfp4_gemm.py @@ -4,7 +4,7 @@ from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -47,18 +47,18 @@ def tearDownClass(cls): def test_gsm8k(self): parsed_url = urlparse(self.base_url) args = SimpleNamespace( 
- num_shots=5, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=200, - host=f"{parsed_url.scheme}://{parsed_url.hostname}", - port=parsed_url.port, + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1319, + num_threads=200, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.64) + self.assertGreater(metrics["score"], 0.64) @unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") diff --git a/test/registered/quant/test_nvfp4_gemm_sm120.py b/test/registered/quant/test_nvfp4_gemm_sm120.py new file mode 100644 index 000000000000..95f32942e453 --- /dev/null +++ b/test/registered/quant/test_nvfp4_gemm_sm120.py @@ -0,0 +1,71 @@ +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import get_device_sm, kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, + try_cached_model, +) + +register_cuda_ci(est_time=90, suite="stage-b-test-small-1-gpu") + +MODEL_PATH = "nvidia/Llama-3.1-8B-Instruct-NVFP4" + + +class FP4GemmSM120Base: + backend = None + + @classmethod + def setUpClass(cls): + if cls.backend is None: + raise NotImplementedError("Subclass must set 'backend' attribute") + cls.model = try_cached_model(MODEL_PATH) + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--quantization", + "modelopt_fp4", + "--fp4-gemm-backend", + cls.backend, + "--disable-piecewise-cuda-graph", + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, 
"process"): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + parsed_url = urlparse(self.base_url) + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=200, + host=parsed_url.hostname, + port=parsed_url.port, + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.64) + + +@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") +class TestFP4GemmSM120Auto(FP4GemmSM120Base, unittest.TestCase): + backend = "auto" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/quant/test_nvfp4_marlin_fallback.py b/test/registered/quant/test_nvfp4_marlin_fallback.py new file mode 100644 index 000000000000..348294d5e565 --- /dev/null +++ b/test/registered/quant/test_nvfp4_marlin_fallback.py @@ -0,0 +1,788 @@ +"""Tests for NVFP4 Marlin fallback on non-Blackwell GPUs (SM75+).""" + +import unittest + +import torch + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=480, suite="stage-b-test-1-gpu-large") + +_FP4_MARLIN_GROUP_SIZE = 16 + +_FP4_E2M1_LUT_VALUES = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + 0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def _check_requirements(): + from sglang.srt.utils import is_cuda + + if not is_cuda(): + return False + from sglang.srt.layers.quantization.marlin_utils_fp4 import is_fp4_marlin_supported + + if not is_fp4_marlin_supported(): + return False + return True + + +def _dequant_fp4_weights( + raw_weight: torch.Tensor, device: torch.device +) -> torch.Tensor: + """Dequantize uint8-packed FP4 E2M1 weights to float32 via lookup table.""" + lut = torch.tensor(_FP4_E2M1_LUT_VALUES, dtype=torch.float32, device=device) + lo = (raw_weight.int() & 0x0F).long() + hi = ((raw_weight.int() >> 4) & 0x0F).long() + return torch.stack([lut[lo], lut[hi]], 
dim=-1).reshape( + raw_weight.shape[0], raw_weight.shape[1] * 2 + ) + + +class _FakeLayer(torch.nn.Module): + """Minimal stand-in for a quantized layer in unit tests.""" + + pass + + +# --------------------------------------------------------------------------- +# Linear (non-MoE) tests +# --------------------------------------------------------------------------- +class TestNvfp4MarlinLinear(CustomTestCase): + """Test the FP4 Marlin linear layer fallback (non-MoE).""" + + def setUp(self): + if not _check_requirements(): + self.skipTest("Requirements not met (CUDA unavailable or SM < 75)") + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + + # -- helpers ------------------------------------------------------------- + + def _make_fake_fp4_layer(self, N, K): + layer = _FakeLayer() + layer.params_dtype = self.dtype + layer.input_size_per_partition = K + layer.output_size_per_partition = N + + layer.weight = torch.nn.Parameter( + torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=self.device), + requires_grad=False, + ) + layer.weight_scale = torch.nn.Parameter( + torch.ones( + N, + K // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.weight_scale_2_marlin = torch.nn.Parameter( + torch.tensor(1.0, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + return layer + + def _run_fp4_marlin_vs_reference(self, M, N, K): + """Prepare a layer, run the Marlin kernel, return (kernel_out, ref_out).""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + raw_weight = torch.randint( + 0, 256, (N, K // 2), dtype=torch.uint8, device=self.device + ) + dq_weight = _dequant_fp4_weights(raw_weight, self.device) + + raw_scale = torch.full( + (N, K // _FP4_MARLIN_GROUP_SIZE), + 1.0, + dtype=torch.float8_e4m3fn, + device=self.device, + ) + global_scale_val = torch.tensor(1.0, dtype=torch.float32, 
device=self.device) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + ref_output = (x.float() @ dq_weight.T).to(self.dtype) + + layer = self._make_fake_fp4_layer(N, K) + layer.weight = torch.nn.Parameter(raw_weight, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(raw_scale, requires_grad=False) + layer.weight_scale_2_marlin = torch.nn.Parameter( + global_scale_val.to(self.dtype), requires_grad=False + ) + + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + marlin_output = apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + return marlin_output, ref_output + + # -- tests --------------------------------------------------------------- + + def test_prepare_and_apply_fp4_marlin_linear(self): + """Smoke test: shape and dtype are correct after prepare + apply.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + self.assertTrue(hasattr(layer, "marlin_workspace")) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + output = apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + self.assertEqual(output.shape, (M, N)) + self.assertEqual(output.dtype, self.dtype) + + def test_fp4_marlin_numerical_correctness(self): + """Kernel output vs BF16 dequant reference (cosine sim, MAE, assert_close).""" + N, K, M = 256, 
256, 32 + marlin_output, ref_output = self._run_fp4_marlin_vs_reference(M, N, K) + + self.assertEqual(marlin_output.shape, ref_output.shape) + self.assertEqual(marlin_output.dtype, ref_output.dtype) + + cos_sim = torch.nn.functional.cosine_similarity( + marlin_output.float().flatten(), ref_output.float().flatten(), dim=0 + ) + self.assertGreater( + cos_sim.item(), + 0.99, + f"Cosine similarity {cos_sim.item():.6f} too low", + ) + + rel_mae = torch.mean( + torch.abs(marlin_output.float() - ref_output.float()) + ) / torch.mean(torch.abs(ref_output.float())) + self.assertLess( + rel_mae.item(), + 0.04, + f"Relative MAE {rel_mae.item():.6f} >= 0.04", + ) + + torch.testing.assert_close(marlin_output, ref_output, atol=1e-1, rtol=1e-1) + + def test_fp4_marlin_multiple_shapes(self): + """Numerical correctness across various (M, N, K) dimensions.""" + shapes = [ + (1, 256, 256), + (16, 512, 128), + (64, 128, 512), + (32, 256, 256), + ] + for M, N, K in shapes: + with self.subTest(M=M, N=N, K=K): + marlin_out, ref_out = self._run_fp4_marlin_vs_reference(M, N, K) + rel_mae = torch.mean( + torch.abs(marlin_out.float() - ref_out.float()) + ) / torch.mean(torch.abs(ref_out.float())) + self.assertLess( + rel_mae.item(), + 0.04, + f"Shape ({M},{N},{K}): relative MAE {rel_mae.item():.6f} >= 0.04", + ) + + def test_fp4_marlin_linear_with_bias(self): + """Verify output_with_bias == output_no_bias + bias.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + bias = torch.randn(N, dtype=self.dtype, device=self.device) + + common = dict( + weight=layer.weight, + weight_scale=layer.weight_scale, + 
weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + output_no_bias = apply_fp4_marlin_linear(input=x, **common) + output_with_bias = apply_fp4_marlin_linear(input=x, bias=bias, **common) + + torch.testing.assert_close( + output_with_bias, output_no_bias + bias, atol=1e-5, rtol=1e-5 + ) + + def test_fp4_marlin_registered_op_numerical(self): + """torch.ops.sglang.apply_fp4_marlin_linear matches the direct Python call.""" + import sglang.srt.layers.quantization.marlin_utils_fp4 # noqa: F401 + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + + common = dict( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + + direct_out = apply_fp4_marlin_linear(**common) + op_out = torch.ops.sglang.apply_fp4_marlin_linear(**common) + + self.assertEqual(op_out.shape, direct_out.shape) + self.assertEqual(op_out.dtype, direct_out.dtype) + torch.testing.assert_close(op_out, direct_out, atol=0, rtol=0) + + def test_fp4_marlin_3d_input(self): + """Verify correct reshape for 3-D input (batch, seq_len, K).""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + batch, seq_len = 2, 8 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x_3d = torch.randn(batch, seq_len, K, dtype=self.dtype, 
device=self.device) + x_2d = x_3d.reshape(-1, K) + + common = dict( + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + + out_3d = apply_fp4_marlin_linear(input=x_3d, **common) + out_2d = apply_fp4_marlin_linear(input=x_2d, **common) + + self.assertEqual(out_3d.shape, (batch, seq_len, N)) + self.assertEqual(out_3d.dtype, self.dtype) + torch.testing.assert_close(out_3d.reshape(-1, N), out_2d, atol=0, rtol=0) + + def test_fake_apply_fp4_marlin_linear(self): + """Fake impl for PCG tracing must return the correct shape and dtype.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + fake_apply_fp4_marlin_linear, + ) + + N, K = 256, 128 + + for input_shape in [(16, K), (2, 8, K)]: + with self.subTest(input_shape=input_shape): + x = torch.randn(*input_shape, dtype=self.dtype, device=self.device) + out = fake_apply_fp4_marlin_linear( + input=x, + weight=torch.empty(0, device=self.device), + weight_scale=torch.empty(0, device=self.device), + weight_global_scale=torch.empty(0, device=self.device), + workspace=torch.empty(0, device=self.device), + size_n=N, + size_k=K, + ) + expected_shape = input_shape[:-1] + (N,) + self.assertEqual(out.shape, expected_shape) + self.assertEqual(out.dtype, self.dtype) + + def test_prepare_rejects_bad_weight_shape(self): + """prepare_fp4_layer_for_marlin must raise on mismatched weight shape.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + layer = _FakeLayer() + layer.params_dtype = self.dtype + layer.input_size_per_partition = K + layer.output_size_per_partition = N + + layer.weight = torch.nn.Parameter( + torch.randint( + 0, 256, (N + 1, K // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.weight_scale = torch.nn.Parameter( + torch.ones( + N, + K // _FP4_MARLIN_GROUP_SIZE, + 
dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.weight_scale_2_marlin = torch.nn.Parameter( + torch.tensor(1.0, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + + with self.assertRaises(AssertionError): + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + def test_prepare_fp4_layer_permutes_bias(self): + """prepare_fp4_layer_for_marlin must permute layer.bias when present.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + layer = self._make_fake_fp4_layer(N, K) + original_bias = torch.randn(N, dtype=self.dtype, device=self.device) + layer.bias = torch.nn.Parameter(original_bias.clone(), requires_grad=False) + + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + self.assertEqual(layer.bias.shape, (N,)) + self.assertEqual(layer.bias.dtype, self.dtype) + self.assertFalse( + torch.equal(layer.bias.data, original_bias), + "Bias should be permuted by prepare_fp4_layer_for_marlin", + ) + + def test_fp4_marlin_custom_op_registration(self): + """apply_fp4_marlin_linear must be registered as torch.ops.sglang for PCG.""" + import sglang.srt.layers.quantization.marlin_utils_fp4 # noqa: F401 + + self.assertTrue( + hasattr(torch.ops.sglang, "apply_fp4_marlin_linear"), + "apply_fp4_marlin_linear not registered as a custom op", + ) + + def test_nvfp4_marlin_scale_values_correctness(self): + """Verify scale conversion produces analytically correct values.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + + # -- global scale: BF16 -- + # fp4_exp=2, target_exp=8 => bias = 2^7 - 2^1 = 126 + # result = 1.0 * 2^(126-7) = 2^119 + gs_bf16 = 
torch.tensor(1.0, dtype=torch.bfloat16, device=self.device) + result_bf16 = nvfp4_marlin_process_global_scale(gs_bf16) + expected_bf16 = torch.tensor(2.0**119, dtype=torch.bfloat16, device=self.device) + self.assertEqual( + result_bf16.item(), + expected_bf16.item(), + f"BF16 global_scale(1.0): expected 2^119, got {result_bf16.item()}", + ) + self.assertEqual(result_bf16.dtype, torch.bfloat16) + + # -- global scale: FP16 -- + # fp4_exp=2, target_exp=5 => bias = 2^4 - 2^1 = 14 + # result = 1.0 * 2^(14-7) = 128 + gs_fp16 = torch.tensor(1.0, dtype=torch.float16, device=self.device) + result_fp16 = nvfp4_marlin_process_global_scale(gs_fp16) + self.assertEqual( + result_fp16.item(), + 128.0, + f"FP16 global_scale(1.0): expected 128.0, got {result_fp16.item()}", + ) + self.assertEqual(result_fp16.dtype, torch.float16) + + # -- global scale: linearity -- + gs_2 = torch.tensor(2.0, dtype=torch.bfloat16, device=self.device) + result_2 = nvfp4_marlin_process_global_scale(gs_2) + self.assertAlmostEqual( + result_2.item(), + 2.0 * result_bf16.item(), + places=0, + msg="Global scale processing should be linear", + ) + + # -- per-group scales: structural properties -- + N, K_div_group = 64, 16 + raw_scale = torch.ones( + N, K_div_group, dtype=torch.float8_e4m3fn, device=self.device + ).to(self.dtype) + processed = nvfp4_marlin_process_scales(raw_scale) + + self.assertEqual(processed.dtype, torch.float8_e4m3fn) + self.assertEqual(processed.shape, (N, K_div_group)) + self.assertFalse(torch.isnan(processed.to(self.dtype)).any()) + + # Deterministic + self.assertTrue(torch.equal(processed, nvfp4_marlin_process_scales(raw_scale))) + + # Large scales (448 = FP8 E4M3 max) must not produce NaN + large_scale = torch.full( + (N, K_div_group), 448.0, dtype=self.dtype, device=self.device + ) + proc_large = nvfp4_marlin_process_scales(large_scale) + self.assertFalse(torch.isnan(proc_large.to(self.dtype)).any()) + self.assertEqual(proc_large.shape, (N, K_div_group)) + + +# 
--------------------------------------------------------------------------- +# MoE tests +# --------------------------------------------------------------------------- +class TestNvfp4MarlinMoe(CustomTestCase): + """Test the FP4 Marlin MoE fallback.""" + + def setUp(self): + if not _check_requirements(): + self.skipTest("Requirements not met (CUDA unavailable or SM < 75)") + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + try: + from sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack + + self._gptq_marlin_repack = gptq_marlin_repack + except ImportError: + self.skipTest("gptq_marlin_repack JIT compilation not available") + self._perm = torch.empty(0, dtype=torch.int, device=self.device) + + # -- helpers ------------------------------------------------------------- + + def _repack_fp4_weight(self, raw_fp4, size_k, size_n): + """Repack raw uint8 FP4 weights into Marlin tile layout.""" + qw = raw_fp4.view(torch.int32).T.contiguous() + return self._gptq_marlin_repack(qw, self._perm, size_k, size_n, num_bits=4) + + def _make_marlin_scale(self, size_k, size_n): + from sglang.srt.layers.quantization.marlin_utils import marlin_permute_scales + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_scales, + ) + + raw = torch.ones( + size_k // _FP4_MARLIN_GROUP_SIZE, + size_n, + dtype=self.dtype, + device=self.device, + ) + permuted = marlin_permute_scales(raw, size_k, size_n, _FP4_MARLIN_GROUP_SIZE) + return nvfp4_marlin_process_scales(permuted) + + def _make_processed_global_scale(self): + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + ) + + return nvfp4_marlin_process_global_scale( + torch.tensor(1.0, dtype=self.dtype, device=self.device) + ) + + # -- tests --------------------------------------------------------------- + + def test_fused_marlin_moe_fp4(self): + """Smoke test: shape, dtype, no NaN for multi-expert MoE.""" + from 
sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( + fused_marlin_moe, + ) + + E, K, N, topk, M = 4, 128, 64, 2, 8 + + def _rand_weight(size_k, size_n): + raw = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=self.device + ) + return self._repack_fp4_weight(raw, size_k, size_n) + + w1 = torch.stack([_rand_weight(K, 2 * N) for _ in range(E)]) + w2 = torch.stack([_rand_weight(N, K) for _ in range(E)]) + w1_scale = torch.stack([self._make_marlin_scale(K, 2 * N) for _ in range(E)]) + w2_scale = torch.stack([self._make_marlin_scale(N, K) for _ in range(E)]) + + gs = self._make_processed_global_scale() + w1_gs = gs.expand(E) + w2_gs = gs.expand(E) + + hidden = torch.randn(M, K, dtype=self.dtype, device=self.device) + gating = torch.randn(M, E, dtype=self.dtype, device=self.device) + topk_weights, topk_ids = torch.topk(torch.softmax(gating, dim=-1), topk, dim=-1) + + output = fused_marlin_moe( + hidden_states=hidden, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + gating_output=gating, + topk_weights=topk_weights, + topk_ids=topk_ids, + num_bits=4, + w1_global_scale=w1_gs, + w2_global_scale=w2_gs, + ) + + self.assertEqual(output.shape, (M, K)) + self.assertEqual(output.dtype, self.dtype) + self.assertFalse(torch.isnan(output).any(), "Output contains NaN!") + + def test_fused_marlin_moe_fp4_numerical(self): + """E=1, topk=1 MoE output vs dequant reference (SiLU-gated).""" + from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( + fused_marlin_moe, + ) + + E, K, N, topk, M = 1, 128, 64, 1, 8 + + raw_w1 = torch.randint( + 0, 256, (2 * N, K // 2), dtype=torch.uint8, device=self.device + ) + raw_w2 = torch.randint( + 0, 256, (K, N // 2), dtype=torch.uint8, device=self.device + ) + dq_w1 = _dequant_fp4_weights(raw_w1, self.device) + dq_w2 = _dequant_fp4_weights(raw_w2, self.device) + + w1 = self._repack_fp4_weight(raw_w1, K, 2 * N).unsqueeze(0) + w2 = self._repack_fp4_weight(raw_w2, N, K).unsqueeze(0) + 
w1_scale = self._make_marlin_scale(K, 2 * N).unsqueeze(0) + w2_scale = self._make_marlin_scale(N, K).unsqueeze(0) + + gs = self._make_processed_global_scale() + w1_gs = gs.unsqueeze(0) + w2_gs = gs.unsqueeze(0) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) * 0.1 + gating = torch.ones(M, E, dtype=self.dtype, device=self.device) + topk_weights = torch.ones(M, topk, dtype=self.dtype, device=self.device) + topk_ids = torch.zeros(M, topk, dtype=torch.int64, device=self.device) + + output = fused_marlin_moe( + hidden_states=x, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + gating_output=gating, + topk_weights=topk_weights, + topk_ids=topk_ids, + num_bits=4, + w1_global_scale=w1_gs, + w2_global_scale=w2_gs, + ) + + gate_up = x.float() @ dq_w1.T + gate, up = gate_up[:, :N], gate_up[:, N:] + ref_output = ((torch.nn.functional.silu(gate) * up) @ dq_w2.T).to(self.dtype) + + self.assertEqual(output.shape, ref_output.shape) + self.assertFalse(torch.isinf(output).any(), "MoE output contains Inf") + self.assertFalse(torch.isnan(output).any(), "MoE output contains NaN") + + finite = torch.isfinite(ref_output) & torch.isfinite(output) + if finite.any(): + cos_sim = torch.nn.functional.cosine_similarity( + output[finite].float().flatten(), + ref_output[finite].float().flatten(), + dim=0, + ) + self.assertGreater( + cos_sim.item(), + 0.90, + f"MoE cosine similarity {cos_sim.item():.4f} too low", + ) + + def test_prepare_moe_fp4_layer_for_marlin(self): + """Weight repacking produces correct shapes for all expert tensors.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin, + ) + + E, K, N = 4, 128, 64 + + class _FakeMoeRunnerConfig: + is_gated = True + + layer = _FakeLayer() + layer.num_local_experts = E + layer.intermediate_size_per_partition = N + layer.params_dtype = self.dtype + layer.moe_runner_config = _FakeMoeRunnerConfig() + + layer.w13_weight = torch.nn.Parameter( + torch.randint( + 0, 256, (E, 
2 * N, K // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.w2_weight = torch.nn.Parameter( + torch.randint( + 0, 256, (E, K, N // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + E, + 2 * N, + K // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.w2_weight_scale = torch.nn.Parameter( + torch.ones( + E, + K, + N // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.w13_weight_scale_2 = torch.nn.Parameter( + torch.ones(E, 2, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + layer.w2_weight_scale_2 = torch.nn.Parameter( + torch.ones(E, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + + prepare_moe_fp4_layer_for_marlin(layer) + + self.assertEqual(layer.w13_weight.shape[0], E) + self.assertEqual(layer.w2_weight.shape[0], E) + self.assertEqual(layer.w13_weight_scale_2.shape, (E,)) + self.assertEqual(layer.w2_weight_scale_2.shape, (E,)) + + +# --------------------------------------------------------------------------- +# Support / capability tests +# --------------------------------------------------------------------------- +class TestFp4MarlinSupport(CustomTestCase): + """Test the capability detection functions.""" + + def test_is_fp4_marlin_supported(self): + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + is_fp4_marlin_supported, + ) + + result = is_fp4_marlin_supported() + if torch.cuda.is_available() and torch.version.hip is None: + cap = torch.cuda.get_device_capability() + sm = cap[0] * 10 + cap[1] + expected = sm >= 75 + self.assertEqual(result, expected) + elif torch.version.hip is not None: + self.assertFalse(result, "FP4 Marlin should not be supported on ROCm/HIP") + + def test_min_capability_changed(self): + """get_min_capability() must return 75 (not 
100).""" + from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config + + cap = ModelOptFp4Config.get_min_capability() + self.assertEqual(cap, 75, f"Expected 75, got {cap}") + + def test_should_use_fp4_marlin_fallback(self): + """should_use_fp4_marlin_fallback returns True on non-Blackwell SM>=75.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + should_use_fp4_marlin_fallback, + ) + + result = should_use_fp4_marlin_fallback() + self.assertIsInstance(result, bool) + + if torch.cuda.is_available() and torch.version.hip is None: + cap = torch.cuda.get_device_capability() + sm = cap[0] * 10 + cap[1] + is_blackwell = sm >= 100 + if is_blackwell: + self.assertFalse( + result, + "Blackwell GPUs should NOT use Marlin fallback (native FP4)", + ) + elif sm >= 75: + self.assertTrue( + result, + f"SM{sm} should use Marlin fallback, but got False", + ) + else: + self.assertFalse( + result, + f"SM{sm} should not support FP4 Marlin at all", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=3) diff --git a/test/registered/quant/test_quant_config_parsing.py b/test/registered/quant/test_quant_config_parsing.py index f2aeba7122d4..5b9d12817bd0 100644 --- a/test/registered/quant/test_quant_config_parsing.py +++ b/test/registered/quant/test_quant_config_parsing.py @@ -5,7 +5,7 @@ from sglang.test.ci.ci_register import register_cpu_ci from sglang.test.test_utils import CustomTestCase -register_cpu_ci(est_time=5, suite="stage-a-test-cpu") +register_cpu_ci(est_time=20, suite="stage-a-test-cpu") class TestQuantLogString(CustomTestCase): diff --git a/test/registered/quant/test_quantization.py b/test/registered/quant/test_quantization.py index 4ab48314482b..cdb1f0970ef4 100644 --- a/test/registered/quant/test_quantization.py +++ b/test/registered/quant/test_quantization.py @@ -16,7 +16,7 @@ write_results_to_json, ) -register_cuda_ci(est_time=185, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=370, 
suite="stage-b-test-1-gpu-large") MODEL_SCORE_THRESHOLDS = { "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.825, diff --git a/test/registered/quant/test_torchao.py b/test/registered/quant/test_torchao.py index c5f6ad5991bd..a53b929a6dd6 100644 --- a/test/registered/quant/test_torchao.py +++ b/test/registered/quant/test_torchao.py @@ -5,7 +5,7 @@ from sglang import Engine from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -register_cuda_ci(est_time=103, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=200, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=230, suite="stage-b-test-1-gpu-small-amd") from sglang.lang.chat_template import get_chat_template_by_model_path from sglang.srt.utils import kill_process_tree diff --git a/test/registered/quant/test_w4a8_deepseek_v3.py b/test/registered/quant/test_w4a8_deepseek_v3.py index a6c33bea33de..f30e16d2ab52 100644 --- a/test/registered/quant/test_w4a8_deepseek_v3.py +++ b/test/registered/quant/test_w4a8_deepseek_v3.py @@ -6,7 +6,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -41,18 +41,18 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1200, - parallel=1200, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1200, + num_threads=1200, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.92) + self.assertGreater(metrics["score"], 0.92) class 
TestDeepseekV3W4Afp8Mtp(CustomTestCase): @@ -95,18 +95,18 @@ def test_gsm8k( self, ): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -115,10 +115,10 @@ def test_gsm8k( if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v3 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(metrics["score"], 0.935) self.assertGreater(avg_spec_accept_length, 2.9) @@ -163,18 +163,18 @@ def test_gsm8k( self, ): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.92) + self.assertGreater(metrics["score"], 0.92) class TestDeepseekV3W4Afp8DeepepAutoMtp(CustomTestCase): @@ -231,18 +231,18 @@ def test_gsm8k( self, ): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + 
model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.92) + self.assertGreater(metrics["score"], 0.92) if __name__ == "__main__": diff --git a/test/registered/quant/test_w8a8_quantization.py b/test/registered/quant/test_w8a8_quantization.py index 88e344831547..a2a2a1cb4920 100644 --- a/test/registered/quant/test_w8a8_quantization.py +++ b/test/registered/quant/test_w8a8_quantization.py @@ -6,7 +6,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -51,17 +51,17 @@ def test_gsm8k(self): self.skipTest("gsm8k_accuracy_threshold not set for this test") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) metrics = run_eval(args) print(metrics) - self.assertGreater(metrics["accuracy"], self.gsm8k_accuracy_threshold) + self.assertGreater(metrics["score"], self.gsm8k_accuracy_threshold) def run_decode(self, max_new_tokens): response = requests.post( diff --git a/test/registered/rl/test_pause_generation_tensor_consistency.py b/test/registered/rl/test_pause_generation_tensor_consistency.py new file mode 100644 index 000000000000..761a963e27b3 --- /dev/null +++ b/test/registered/rl/test_pause_generation_tensor_consistency.py @@ -0,0 +1,212 @@ +""" +Unit test for the pause_generation. 
+""" + +import unittest + +import torch + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=5, suite="stage-a-test-cpu") + + +# --------------------------------------------------------------------------- +# Minimal stand-alone simulation of the relevant ScheduleBatch logic. +# We do NOT import ScheduleBatch directly because that pulls in heavy +# GPU-extension dependencies (deep_gemm, etc.). Instead we replicate the +# exact behaviour of filter_batch / merge_batch / is_empty that matters for +# this bug. +# --------------------------------------------------------------------------- + + +class _FakeReq: + def __init__(self, finished: bool = False): + self._finished = finished + + def finished(self) -> bool: + return self._finished + + +class _FakeBatch: + """Minimal simulation of the scheduler-side fields touched by this bug.""" + + def __init__(self, n: int, all_finished: bool = False): + self.reqs = [_FakeReq(finished=all_finished) for _ in range(n)] + self.seq_lens = torch.ones(n, dtype=torch.int32) + self.seq_lens_cpu = torch.ones(n, dtype=torch.int32) + self.orig_seq_lens = torch.ones(n, dtype=torch.int32) + self.req_pool_indices = torch.zeros(n, dtype=torch.int64) + self.output_ids = torch.zeros(n, dtype=torch.int64) + self.seq_lens_sum = n + + def is_empty(self) -> bool: + return len(self.reqs) == 0 + + def filter_batch(self): + """Simplified filter_batch: identical early-return logic to ScheduleBatch.""" + keep_indices = [i for i in range(len(self.reqs)) if not self.reqs[i].finished()] + + # Early-return paths — tensors are NOT updated. + if len(keep_indices) == 0: + self.reqs = [] + return + if len(keep_indices) == len(self.reqs): + return + + # Full filter path (not needed for this test but included for completeness). 
+ self.reqs = [self.reqs[i] for i in keep_indices] + idx = torch.tensor(keep_indices, dtype=torch.int64) + self.seq_lens = self.seq_lens[idx] + self.seq_lens_cpu = self.seq_lens_cpu[idx] + self.orig_seq_lens = self.orig_seq_lens[idx] + self.req_pool_indices = self.req_pool_indices[idx] + if self.output_ids is not None: + self.output_ids = self.output_ids[idx] + self.seq_lens_sum = int(self.seq_lens.sum().item()) + + def merge_batch(self, other: "_FakeBatch"): + """Simplified merge_batch: replicates the tensor-cat logic.""" + self.seq_lens = torch.cat([self.seq_lens, other.seq_lens]) + self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) + self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) + self.req_pool_indices = torch.cat( + [self.req_pool_indices, other.req_pool_indices] + ) + if self.output_ids is not None and other.output_ids is not None: + self.output_ids = torch.cat([self.output_ids, other.output_ids]) + self.seq_lens_sum += other.seq_lens_sum + self.reqs.extend(other.reqs) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestPauseGenerationTensorConsistency(CustomTestCase): + """Verify pause_generation does not corrupt the running_batch tensors.""" + + # ------------------------------------------------------------------ + # Bug reproduction + # ------------------------------------------------------------------ + + def test_buggy_merge_violates_invariant(self): + """Without the fix, merging an all-finished extend batch breaks the + invariant ``len(reqs) == seq_lens.shape[0]``.""" + N = 651 + running_batch = _FakeBatch(N) + last_batch = _FakeBatch(1, all_finished=True) + + # Pre-fix pause_generation path: + # filter_batch -> reqs=[], tensors unchanged (early return) + last_batch.filter_batch() + self.assertTrue(last_batch.is_empty()) + # Tensors still have M=1 element each despite reqs 
being empty. + self.assertEqual(last_batch.seq_lens.shape[0], 1) + + # BUG: unconditional merge + running_batch.merge_batch(last_batch) + + # Invariant is now violated. + self.assertEqual(len(running_batch.reqs), N) + self.assertEqual(running_batch.seq_lens.shape[0], N + 1) + self.assertNotEqual( + len(running_batch.reqs), + running_batch.seq_lens.shape[0], + "len(reqs) != seq_lens.shape[0] — invariant broken", + ) + + # ------------------------------------------------------------------ + # Fix verification + # ------------------------------------------------------------------ + + def test_fix_preserves_invariant_when_all_reqs_finished(self): + """With the is_empty() guard the merge is skipped and invariant holds.""" + N = 651 + running_batch = _FakeBatch(N) + last_batch = _FakeBatch(1, all_finished=True) + + last_batch.filter_batch() # reqs=[], tensors untouched + + # FIX: mirror get_next_batch_to_run's is_empty() guard + if not last_batch.is_empty(): + if running_batch.is_empty(): + running_batch = last_batch + else: + running_batch.merge_batch(last_batch) + + self.assertEqual( + len(running_batch.reqs), + running_batch.seq_lens.shape[0], + "Invariant preserved: len(reqs) == seq_lens.shape[0]", + ) + self.assertEqual(len(running_batch.reqs), N) + self.assertEqual(running_batch.seq_lens.shape[0], N) + + def test_fix_still_merges_partial_extend_batch(self): + """The fix must not skip a merge when some extend requests survive.""" + N = 651 + running_batch = _FakeBatch(N) + + # 3-req extend batch: 1 finished, 2 still running + last_batch = _FakeBatch(3, all_finished=False) + last_batch.reqs[0] = _FakeReq(finished=True) + + last_batch.filter_batch() # keeps 2 running reqs + + self.assertEqual(len(last_batch.reqs), 2) + self.assertFalse(last_batch.is_empty()) + + if not last_batch.is_empty(): + if running_batch.is_empty(): + running_batch = last_batch + else: + running_batch.merge_batch(last_batch) + + self.assertEqual(len(running_batch.reqs), N + 2) + 
self.assertEqual(running_batch.seq_lens.shape[0], N + 2) + + def test_fix_handles_empty_running_batch(self): + """When running_batch is empty and last_batch has live reqs, the fix + replaces running_batch (matches get_next_batch_to_run semantics).""" + running_batch = _FakeBatch(0) + last_batch = _FakeBatch(3, all_finished=False) + + last_batch.filter_batch() # all 3 alive -> no-op + + if not last_batch.is_empty(): + if running_batch.is_empty(): + running_batch = last_batch + else: + running_batch.merge_batch(last_batch) + + self.assertEqual(len(running_batch.reqs), 3) + self.assertEqual(running_batch.seq_lens.shape[0], 3) + + def test_next_filter_batch_early_return_preserves_inconsistency(self): + """After the buggy merge, the next filter_batch call returns early + (because keep_indices covers all N reqs), leaving N+1 tensors behind.""" + N = 651 + running_batch = _FakeBatch(N) + last_batch = _FakeBatch(1, all_finished=True) + + last_batch.filter_batch() + running_batch.merge_batch(last_batch) # BUG path + + # Simulate update_running_batch -> filter_batch: all N reqs still alive + running_batch.filter_batch() + + # Early return: tensors NOT trimmed + self.assertEqual(len(running_batch.reqs), N) + self.assertEqual( + running_batch.seq_lens.shape[0], + N + 1, + "seq_lens is still N+1 after the second filter_batch early-return", + ) + self.assertNotEqual(len(running_batch.reqs), running_batch.seq_lens.shape[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/rl/test_return_routed_experts.py b/test/registered/rl/test_return_routed_experts.py index fec50b965c0c..7e3e3b517883 100644 --- a/test/registered/rl/test_return_routed_experts.py +++ b/test/registered/rl/test_return_routed_experts.py @@ -1,13 +1,14 @@ import asyncio +import json import logging import unittest from typing import List import aiohttp -import requests import torch from torch.nn.utils.rnn import pad_sequence +from sglang.benchmark.utils import download_and_cache_hf_file 
from sglang.srt.layers.moe.routed_experts_capturer import ( extract_routed_experts_from_meta_info, ) @@ -21,23 +22,18 @@ popen_launch_server, ) -register_cuda_ci(est_time=360, suite="stage-c-test-4-gpu-h100") +register_cuda_ci(est_time=200, suite="stage-b-test-2-gpu-large") register_amd_ci( - est_time=360, - suite="stage-c-test-4-gpu-amd", - disabled="TP=4 DP=4 routed expert mismatch >15% on AMD; needs TP/DP tuning + concurrency reduction", + est_time=200, + suite="stage-b-test-2-gpu-large-amd", + disabled="TP=2 DP=2 routed expert mismatch >15% on AMD; needs TP/DP tuning + concurrency reduction", ) -SHAREGPT_URL = ( - "https://huggingface.co/datasets/anon8231489123/" - "ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" -) +SHAREGPT_REPO_ID = "anon8231489123/ShareGPT_Vicuna_unfiltered" +SHAREGPT_FILENAME = "ShareGPT_V3_unfiltered_cleaned_split.json" logger = logging.getLogger(__name__) -@unittest.skip( - "Flaky in CI, need to be fixed and re-enabled. See https://github.com/sgl-project/sglang/issues/21266" -) class TestReturnRoutedExperts(CustomTestCase): # modified from test_hicache.py @classmethod @@ -50,31 +46,28 @@ def setUpClass(cls): "--disable-cuda-graph", "--disable-radix-cache", "--tp", - 4, + 2, "--dp", - 4, + 2, "--enable-dp-attention", ] cls.reference_args = [ "--enable-return-routed-experts", "--enable-deterministic-inference", "--tp", - 4, + 2, "--dp", - 4, + 2, "--enable-dp-attention", ] cls.sampling_args = { "temperature": 0, } # prepare ShareGPT dataset - try: - response = requests.get(SHAREGPT_URL, timeout=60) - response.raise_for_status() - data = response.json() - print(f"Dataset size: {len(data)}") - except requests.exceptions.RequestException as e: - raise Exception(f"Failed to download ShareGPT dataset: {e}") from e + dataset_path = download_and_cache_hf_file(SHAREGPT_REPO_ID, SHAREGPT_FILENAME) + with open(dataset_path) as f: + data = json.load(f) + print(f"Dataset size: {len(data)}") cls.texts = [] for s in 
data: if "conversations" in s and len(s["conversations"]) > 0: @@ -145,7 +138,7 @@ def _run_endpoint_test(cls, endpoint): f"Total mismatches report: {num_mismatches} out of {num_baseline_topks} ({num_mismatches/num_baseline_topks:.4%})" ) assert ( - num_mismatches / num_baseline_topks < 0.05 + num_mismatches / num_baseline_topks < 0.10 ), f"Too many mismatches: {num_mismatches} out of {num_baseline_topks} ({num_mismatches/num_baseline_topks:.4%})" @classmethod diff --git a/test/registered/sampling/test_fused_temperature_softmax.py b/test/registered/sampling/test_fused_temperature_softmax.py new file mode 100644 index 000000000000..956d67c2cefe --- /dev/null +++ b/test/registered/sampling/test_fused_temperature_softmax.py @@ -0,0 +1,268 @@ +"""Correctness tests for fused_temperature_softmax Triton kernel.""" + +import unittest + +import torch +from flashinfer.sampling import softmax as flashinfer_softmax + +from sglang.srt.layers.fused_sampling import ( + fused_temperature_softmax, + fused_temperature_softmax_inplace, +) +from sglang.srt.utils import get_device +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=15, suite="stage-b-test-small-1-gpu") + + +def reference_temperature_softmax(logits, temperatures): + """Reference implementation: div + softmax (separate kernels).""" + logits = logits.clone() + logits.div_(temperatures) + return torch.softmax(logits, dim=-1).float() + + +class TestFusedTemperatureSoftmax(CustomTestCase): + @classmethod + def setUpClass(cls): + torch.set_default_device(get_device()) + torch.manual_seed(42) + + def _check_close(self, fused, ref, atol=1e-5, rtol=1e-5): + """Assert outputs are close and both are valid probability distributions.""" + self.assertEqual(fused.shape, ref.shape) + # Valid probabilities: non-negative, sum to ~1 + self.assertTrue((fused >= 0).all(), f"Negative probabilities in fused output") + row_sums = fused.sum(dim=-1) + 
torch.testing.assert_close( + row_sums, + torch.ones_like(row_sums), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close(fused, ref, atol=atol, rtol=rtol) + + # --- out-of-place kernel --- + + def test_basic(self): + logits = torch.randn(4, 1024, dtype=torch.bfloat16) + temps = torch.tensor([0.7, 1.0, 1.5, 2.0], dtype=torch.float32).view(-1, 1) + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_large_vocab(self): + logits = torch.randn(8, 128256, dtype=torch.bfloat16) + temps = torch.full((8, 1), 0.6, dtype=torch.float32) + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_batch_sizes(self): + for bs in [1, 2, 16, 64, 128, 512]: + logits = torch.randn(bs, 32000, dtype=torch.bfloat16) + temps = torch.rand(bs, 1, dtype=torch.float32) * 1.5 + 0.1 + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_temperature_one(self): + """Temperature=1.0 should be equivalent to plain softmax.""" + logits = torch.randn(16, 32000, dtype=torch.bfloat16) + temps = torch.ones(16, 1, dtype=torch.float32) + ref = torch.softmax(logits.float(), dim=-1) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_very_low_temperature(self): + """Very low temperature should produce near-one-hot distribution.""" + logits = torch.randn(4, 1024, dtype=torch.bfloat16) + temps = torch.full((4, 1), 0.01, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + # Max probability should be very close to 1.0 + max_probs = fused.max(dim=-1).values + self.assertTrue((max_probs > 0.99).all()) + + def test_very_high_temperature(self): + """Very high temperature should 
produce near-uniform distribution.""" + logits = torch.randn(4, 1024, dtype=torch.bfloat16) + temps = torch.full((4, 1), 100.0, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + uniform = 1.0 / 1024 + self.assertTrue( + (fused - uniform).abs().max() < 0.01, + "High temperature should produce near-uniform distribution", + ) + + def test_fp16_input(self): + logits = torch.randn(8, 32000, dtype=torch.float16) + temps = torch.rand(8, 1, dtype=torch.float32) * 1.5 + 0.1 + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-3, rtol=1e-2) + + def test_fp32_input(self): + logits = torch.randn(8, 32000, dtype=torch.float32) + temps = torch.rand(8, 1, dtype=torch.float32) + 0.5 + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-5, rtol=1e-5) + + def test_mixed_temperatures(self): + """Each row has a different temperature.""" + logits = torch.randn(8, 32000, dtype=torch.bfloat16) + temps = torch.tensor( + [0.1, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 5.0], dtype=torch.float32 + ).view(-1, 1) + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_empty_batch(self): + logits = torch.randn(0, 32000, dtype=torch.bfloat16) + temps = torch.ones(0, 1, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + self.assertEqual(fused.shape, (0, 32000)) + + # --- in-place kernel --- + + def test_inplace_basic(self): + logits = torch.randn(8, 32000, dtype=torch.float32) + temps = torch.rand(8, 1, dtype=torch.float32) * 1.5 + 0.1 + ref = reference_temperature_softmax(logits, temps) + fused_temperature_softmax_inplace(logits, temps) + # In-place writes back to logits in the original dtype + self._check_close(logits.float(), ref, atol=1e-5, rtol=1e-5) + + def 
test_inplace_bf16(self): + logits = torch.randn(8, 32000, dtype=torch.bfloat16) + temps = torch.rand(8, 1, dtype=torch.float32) + 0.5 + ref = reference_temperature_softmax(logits, temps) + fused_temperature_softmax_inplace(logits, temps) + self._check_close(logits.float(), ref, atol=2e-3, rtol=2e-3) + + def test_inplace_large_vocab(self): + logits = torch.randn(4, 128256, dtype=torch.bfloat16) + temps = torch.full((4, 1), 0.8, dtype=torch.float32) + ref = reference_temperature_softmax(logits, temps) + fused_temperature_softmax_inplace(logits, temps) + self._check_close(logits.float(), ref, atol=2e-3, rtol=2e-3) + + # --- exact known-value correctness --- + + def test_known_uniform_logits(self): + """Identical logits must produce uniform distribution regardless of temperature.""" + logits = torch.zeros(2, 5, dtype=torch.float32) + temps = torch.tensor([0.5, 2.0], dtype=torch.float32).view(-1, 1) + fused = fused_temperature_softmax(logits, temps) + expected = torch.full((2, 5), 0.2, dtype=torch.float32, device=fused.device) + torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6) + + def test_known_softmax_values(self): + """Verify against hand-computed softmax(logits / T).""" + logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32) + temps = torch.tensor([[1.0]], dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + # softmax([1,2,3]) = exp([1,2,3]) / sum(exp([1,2,3])) + e = torch.exp(logits) + expected = (e / e.sum(dim=-1, keepdim=True)).to(fused.device) + torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6) + + def test_known_softmax_with_temperature(self): + """Verify softmax([1,2,3] / 0.5) against hand computation.""" + logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32) + temps = torch.tensor([[0.5]], dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + scaled = logits / 0.5 + e = torch.exp(scaled) + expected = (e / e.sum(dim=-1, keepdim=True)).to(fused.device) + 
torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6) + + # --- argmax preservation --- + + def test_argmax_preserved(self): + """argmax must be invariant to temperature for finite T > 0.""" + logits = torch.randn(64, 32000, dtype=torch.bfloat16) + original_argmax = logits.float().argmax(dim=-1) + for t_val in [0.1, 0.5, 1.0, 2.0, 10.0]: + temps = torch.full((64, 1), t_val, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + fused_argmax = fused.argmax(dim=-1) + self.assertTrue( + (original_argmax == fused_argmax).all(), + f"argmax changed at temperature={t_val}", + ) + + # --- numerical stability --- + + def test_large_logits_no_nan(self): + """Extreme logit magnitudes must not produce NaN or Inf.""" + logits = torch.tensor( + [[1e6, -1e6, 0.0], [1e4, 1e4 + 1, 1e4 - 1]], dtype=torch.float32 + ) + temps = torch.tensor([[1.0], [0.01]], dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + self.assertFalse(torch.isnan(fused).any(), "NaN in output") + self.assertFalse(torch.isinf(fused).any(), "Inf in output") + row_sums = fused.sum(dim=-1) + torch.testing.assert_close( + row_sums, + torch.ones_like(row_sums), + atol=1e-4, + rtol=1e-4, + ) + + def test_large_logits_inplace_no_nan(self): + """In-place variant: extreme logits must not produce NaN or Inf.""" + logits = torch.tensor( + [[1e6, -1e6, 0.0], [1e4, 1e4 + 1, 1e4 - 1]], dtype=torch.float32 + ) + temps = torch.tensor([[1.0], [0.01]], dtype=torch.float32) + fused_temperature_softmax_inplace(logits, temps) + self.assertFalse(torch.isnan(logits).any(), "NaN in output") + self.assertFalse(torch.isinf(logits).any(), "Inf in output") + + # --- comparison with flashinfer.sampling.softmax --- + + def test_vs_flashinfer_basic(self): + logits = torch.randn(4, 1024, dtype=torch.bfloat16) + temps = torch.tensor([0.7, 1.0, 1.5, 2.0], dtype=torch.float32).view(-1, 1) + fused = fused_temperature_softmax(logits, temps) + fi = flashinfer_softmax(logits, 
temperature=temps.view(-1)) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + def test_vs_flashinfer_large_vocab(self): + logits = torch.randn(8, 128256, dtype=torch.bfloat16) + temps = torch.full((8, 1), 0.6, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + fi = flashinfer_softmax(logits, temperature=temps.view(-1)) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + def test_vs_flashinfer_batch_sizes(self): + for bs in [1, 16, 64, 128, 512]: + logits = torch.randn(bs, 32000, dtype=torch.bfloat16) + temps = torch.rand(bs, 1, dtype=torch.float32) * 1.5 + 0.1 + fused = fused_temperature_softmax(logits, temps) + fi = flashinfer_softmax(logits, temperature=temps.view(-1)) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + def test_vs_flashinfer_scalar_temperature(self): + logits = torch.randn(16, 32000, dtype=torch.bfloat16) + temps_2d = torch.full((16, 1), 0.8, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps_2d) + fi = flashinfer_softmax(logits, temperature=0.8) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + def test_vs_flashinfer_mixed_temperatures(self): + logits = torch.randn(8, 32000, dtype=torch.bfloat16) + temps = torch.tensor( + [0.1, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 5.0], dtype=torch.float32 + ).view(-1, 1) + fused = fused_temperature_softmax(logits, temps) + fi = flashinfer_softmax(logits, temperature=temps.view(-1)) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/sampling/test_pytorch_sampling_backend.py b/test/registered/sampling/test_pytorch_sampling_backend.py index 501d706787ef..6abea3c3e1ea 100644 --- a/test/registered/sampling/test_pytorch_sampling_backend.py +++ b/test/registered/sampling/test_pytorch_sampling_backend.py @@ -14,7 +14,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=66, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=150, suite="stage-b-test-1-gpu-small") 
register_amd_ci(est_time=66, suite="stage-b-test-1-gpu-small-amd") diff --git a/test/registered/scheduler/test_abort.py b/test/registered/scheduler/test_abort.py index ce53425d9ae9..47399bc4e07c 100644 --- a/test/registered/scheduler/test_abort.py +++ b/test/registered/scheduler/test_abort.py @@ -19,7 +19,7 @@ run_and_check_memory_leak, ) -register_cuda_ci(est_time=131, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=350, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=300, suite="stage-b-test-1-gpu-small-amd") @@ -317,6 +317,8 @@ def setUpClass(cls): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--max-running-requests=1", + # Disable PCG to avoid padding in flashinfer backend. Ref: https://github.com/sgl-project/sglang/pull/21452 + "--disable-piecewise-cuda-graph", ], ) diff --git a/test/registered/scheduler/test_chunked_prefill.py b/test/registered/scheduler/test_chunked_prefill.py index 9b72ce21f0ca..bc22976c344c 100644 --- a/test/registered/scheduler/test_chunked_prefill.py +++ b/test/registered/scheduler/test_chunked_prefill.py @@ -7,7 +7,7 @@ from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.test_utils import CustomTestCase, run_mmlu_test, run_mulit_request_test -register_cuda_ci(est_time=312, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=550, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=312, suite="stage-b-test-1-gpu-small-amd") diff --git a/test/registered/scheduler/test_priority_scheduling.py b/test/registered/scheduler/test_priority_scheduling.py index 339db350ad5d..72321b694a66 100644 --- a/test/registered/scheduler/test_priority_scheduling.py +++ b/test/registered/scheduler/test_priority_scheduling.py @@ -41,6 +41,8 @@ def setUpClass(cls): "--max-queued-requests", # Enforce max queued request number is 3 "3", "--enable-priority-scheduling", # Enable priority scheduling + # Disable PCG to avoid padding in flashinfer backend. 
Ref: https://github.com/sgl-project/sglang/pull/21452 + "--disable-piecewise-cuda-graph", ), return_stdout_stderr=(cls.stdout, cls.stderr), ) @@ -247,6 +249,7 @@ def setUpClass(cls): "--max-queued-requests", # Enforce max queued request number is 3 "3", "--enable-priority-scheduling", # Enable priority scheduling + "--disable-piecewise-cuda-graph", ), return_stdout_stderr=(cls.stdout, cls.stderr), ) diff --git a/test/registered/scheduler/test_retract_decode.py b/test/registered/scheduler/test_retract_decode.py index 15f333104460..0628c97e425b 100644 --- a/test/registered/scheduler/test_retract_decode.py +++ b/test/registered/scheduler/test_retract_decode.py @@ -17,7 +17,7 @@ ) from sglang.utils import is_in_ci -register_cuda_ci(est_time=311, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=550, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=600, suite="stage-b-test-1-gpu-small-amd") diff --git a/test/registered/sessions/test_session_control.py b/test/registered/sessions/test_session_control.py index 48dd5f434656..0e6b2a7f22eb 100644 --- a/test/registered/sessions/test_session_control.py +++ b/test/registered/sessions/test_session_control.py @@ -31,7 +31,7 @@ def remove_prefix(text: str, prefix: str) -> str: return text[len(prefix) :] if text.startswith(prefix) else text -class TestSessionControl(unittest.TestCase): +class TestSessionControl(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST @@ -43,6 +43,8 @@ def setUpClass(cls): other_args=[ "--attention-backend", "triton", + "--disable-cuda-graph", + "--disable-piecewise-cuda-graph", ], ) diff --git a/test/registered/spec/eagle/test_deepseek_v3_fp4_mtp_small.py b/test/registered/spec/eagle/test_deepseek_v3_fp4_mtp_small.py index 58cba7abb393..d1bce0c96663 100644 --- a/test/registered/spec/eagle/test_deepseek_v3_fp4_mtp_small.py +++ b/test/registered/spec/eagle/test_deepseek_v3_fp4_mtp_small.py @@ -6,7 +6,7 @@ from sglang.srt.environ import envs 
from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, @@ -18,7 +18,6 @@ register_cuda_ci(est_time=900, suite="stage-b-test-4-gpu-b200") - FULL_DEEPSEEK_V3_FP4_MODEL_PATH = "nvidia/DeepSeek-V3-0324-FP4" SERVER_LAUNCH_TIMEOUT = 1200 @@ -74,15 +73,15 @@ def test_a_gsm8k( requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") server_info = requests.get(self.base_url + "/server_info").json() @@ -94,11 +93,11 @@ def test_a_gsm8k( if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v3-fp4 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) - self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(metrics["score"], 0.94) self.assertGreater(avg_spec_accept_length, 2.7) def test_bs_1_speed(self): diff --git a/test/registered/spec/eagle/test_eagle3_basic.py b/test/registered/spec/eagle/test_eagle3_basic.py index 442eefcc6f82..4cf48afc1468 100644 --- a/test/registered/spec/eagle/test_eagle3_basic.py +++ b/test/registered/spec/eagle/test_eagle3_basic.py @@ -12,7 +12,7 @@ DEFAULT_TARGET_MODEL_EAGLE3, ) -register_cuda_ci(est_time=50, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=200, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=50, suite="stage-b-test-1-gpu-small") _is_hip = 
is_hip() diff --git a/test/registered/spec/eagle/test_eagle_dp_attention.py b/test/registered/spec/eagle/test_eagle_dp_attention.py index a25edf588f7f..dc20b39819a3 100644 --- a/test/registered/spec/eagle/test_eagle_dp_attention.py +++ b/test/registered/spec/eagle/test_eagle_dp_attention.py @@ -5,7 +5,7 @@ from sglang.srt.environ import envs from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_DRAFT_MODEL_EAGLE_DP_ATTN, @@ -76,18 +76,18 @@ def test_a_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") server_data = server_info.json() # Try to get avg_spec_accept_length @@ -104,14 +104,14 @@ def test_a_gsm8k(self): if is_in_ci(): write_github_step_summary( f"### test_gsm8k (EAGLE3 DP Attention)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) if is_in_amd_ci(): # AMD triton backend produces slightly lower accuracy than FA3 on NVIDIA - self.assertGreater(metrics["accuracy"], 0.88) + self.assertGreater(metrics["score"], 0.88) else: - self.assertGreater(metrics["accuracy"], 0.91) + self.assertGreater(metrics["score"], 0.91) if avg_spec_accept_length is not None: if is_in_amd_ci(): # AMD triton backend produces slightly lower accept 
length than FA3 on NVIDIA diff --git a/test/registered/spec/eagle/test_eagle_infer_a.py b/test/registered/spec/eagle/test_eagle_infer_a.py index 1845fc1b29ac..cecc2eb531a7 100644 --- a/test/registered/spec/eagle/test_eagle_infer_a.py +++ b/test/registered/spec/eagle/test_eagle_infer_a.py @@ -1,32 +1,18 @@ -import os import random import unittest -import requests -import torch - import sglang as sgl -from sglang.srt.utils import kill_process_tree from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import ( DEFAULT_DRAFT_MODEL_EAGLE, DEFAULT_DRAFT_MODEL_EAGLE3, - DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TARGET_MODEL_EAGLE, DEFAULT_TARGET_MODEL_EAGLE3, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, CustomTestCase, - is_in_ci, - popen_launch_server, ) -register_cuda_ci(est_time=561, suite="stage-b-test-1-gpu-large") - -torch_dtype = torch.float16 -prefill_tolerance = 5e-2 -decode_tolerance: float = 5e-2 +register_cuda_ci(est_time=450, suite="stage-b-test-1-gpu-large") class TestEAGLEEngine(CustomTestCase): @@ -204,255 +190,5 @@ class TestEAGLE3Engine(TestEAGLEEngine): } -class TestEAGLERadixCache(CustomTestCase): - BASE_CONFIG = { - "model_path": DEFAULT_TARGET_MODEL_EAGLE3, - "speculative_draft_model_path": DEFAULT_DRAFT_MODEL_EAGLE3, - "speculative_algorithm": "EAGLE3", - "speculative_num_steps": 2, - "speculative_eagle_topk": 2, - "speculative_num_draft_tokens": 5, - "mem_fraction_static": 0.7, - "dtype": "float16", - "trust_remote_code": True, - "attention_backend": "fa3", - "skip_server_warmup": True, - "cuda_graph_max_bs": 5, - } - - def test_correctness(self): - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" - configs = [ - # Basic config - self.BASE_CONFIG, - # Chunked prefill & Page Size > 1 - {**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4}, - {**self.BASE_CONFIG, "page_size": 4}, - # Large page size tend to expose IMA 
bugs. - {**self.BASE_CONFIG, "page_size": 256}, - {**self.BASE_CONFIG, "cuda_graph_bs": [5], "page_size": 4}, - # Disable CUDA Graph - { - **self.BASE_CONFIG, - "disable_cuda_graph": True, - "page_size": 4, - }, - ] - - for i, config in enumerate(configs): - with self.subTest(i=i): - print(f"{config=}") - engine = sgl.Engine(**config, log_level="info", decode_log_interval=10) - try: - self._test_acc_length(engine) - self._test_batch_generation(engine) - finally: - engine.shutdown() - print("=" * 100) - del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] - - def _test_acc_length(self, engine): - warmup_prompt = [ - "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:", - ] - sampling_params = {"temperature": 0, "max_new_tokens": 512} - output = engine.generate(warmup_prompt, sampling_params) - test_prompt = [ - "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. 
Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" - ] - output = engine.generate(test_prompt, sampling_params) - output = output[0] - - if "spec_verify_ct" in output["meta_info"]: - acc_length = ( - output["meta_info"]["completion_tokens"] - / output["meta_info"]["spec_verify_ct"] - ) - else: - acc_length = 1.0 - - speed = ( - output["meta_info"]["completion_tokens"] - / output["meta_info"]["e2e_latency"] - ) - print(f"{acc_length=:.4f}, {speed=}") - - self.assertGreater(acc_length, 2.5) - - def _test_batch_generation(self, engine): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - params = {"temperature": 0, "max_new_tokens": 50} - - outputs = engine.generate(prompts, params) - for prompt, output in zip(prompts, outputs): - print(f"Prompt: {prompt}") - print(f"Generated: {output['text']}") - print("-" * 40) - - print(f"{engine.get_server_info()=}") - - avg_spec_accept_length = engine.get_server_info()["internal_states"][0][ - "avg_spec_accept_length" - ] - print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 2.0) - - -@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") -class TestEAGLEDraftExtend(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - DEFAULT_TARGET_MODEL_EAGLE, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model-path", - DEFAULT_DRAFT_MODEL_EAGLE, - "--speculative-num-steps", - 1, - "--speculative-eagle-topk", - 1, - "--speculative-num-draft-tokens", - 2, - "--max-running-requests", - 4, - "--attention-backend", - "fa3", - ], - ) - cls.accept_len_threshold = 1.50 - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_one_batch_accept_length(self): - resp = requests.get(self.base_url 
+ "/flush_cache") - self.assertEqual(resp.status_code, 200) - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - url = self.base_url + "/generate" - data = { - "text": prompts, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 512, - }, - } - response = requests.post(url, json=data) - self.assertEqual(response.status_code, 200) - outputs = response.json() - for i in range(len(prompts)): - output = outputs[i] - if "spec_verify_ct" in output["meta_info"]: - acc_length = ( - output["meta_info"]["completion_tokens"] - / output["meta_info"]["spec_verify_ct"] - ) - else: - acc_length = 1.0 - - print(f"{acc_length=}") - self.assertGreater(acc_length, self.accept_len_threshold) - - -class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend): - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - DEFAULT_TARGET_MODEL_EAGLE, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model-path", - DEFAULT_DRAFT_MODEL_EAGLE, - "--speculative-num-steps", - 1, - "--speculative-eagle-topk", - 1, - "--speculative-num-draft-tokens", - 2, - "--max-running-requests", - 4, - "--attention-backend", - "flashinfer", - ], - ) - cls.accept_len_threshold = 1.50 - - -@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") -class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend): - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - DEFAULT_TARGET_MODEL_EAGLE, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model-path", - DEFAULT_DRAFT_MODEL_EAGLE, - "--speculative-num-steps", - 1, - "--speculative-eagle-topk", - 1, - "--speculative-num-draft-tokens", - 2, - "--max-running-requests", - 4, - 
"--attention-backend", - "triton", - ], - ) - cls.accept_len_threshold = 1.50 - - -@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") -class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend): - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - DEFAULT_MODEL_NAME_FOR_TEST_MLA, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--speculative-algorithm", - "EAGLE", - "--speculative-num-steps", - 1, - "--speculative-eagle-topk", - 1, - "--speculative-num-draft-tokens", - 2, - "--max-running-requests", - 4, - "--attention-backend", - "flashinfer", - ], - ) - cls.accept_len_threshold = 1.85 - - if __name__ == "__main__": unittest.main() diff --git a/test/registered/spec/eagle/test_eagle_infer_b.py b/test/registered/spec/eagle/test_eagle_infer_b.py index 014cc5e27b89..7c726acfb780 100644 --- a/test/registered/spec/eagle/test_eagle_infer_b.py +++ b/test/registered/spec/eagle/test_eagle_infer_b.py @@ -12,20 +12,22 @@ from sglang.srt.environ import envs from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_gsm8k_eval from sglang.test.kits.abort_timeout_kit import ( AbortAllMixin, RunningTimeoutTwoWaveMixin, WaitingTimeoutMixin, ) from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test +from sglang.test.run_eval import run_eval from sglang.test.server_fixtures.eagle_fixture import EagleServerBase from sglang.test.test_utils import DEFAULT_TARGET_MODEL_EAGLE, run_logprob_check -register_cuda_ci(est_time=1100, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=600, suite="stage-b-test-1-gpu-large") class TestEAGLEServerBasic(EagleServerBase): + """Core tests that run on every server config variant.""" + extra_args = ["--chunked-prefill-size", 128, "--max-running-requests", 8] # FIXME(lsyin): move the test methods to kits @@ -42,43 +44,22 @@ def test_request_abort(self): for p in 
threads: p.join() - def test_radix_attention(self): - run_radix_attention_test(self.base_url) - self.assertIsNone(self.process.poll()) - - def test_max_token_one(self): - requests.get(self.base_url + "/flush_cache") - - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=1, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - - # Just run and check it does not hang - metrics = run_gsm8k_eval(args) - self.assertGreater(metrics["output_throughput"], 50) - def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.target_model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_gsm8k_eval(args) + metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.20) + self.assertGreater(metrics["score"], 0.20) server_info = requests.get(self.base_url + "/server_info").json() avg_spec_accept_length = server_info["internal_states"][0][ @@ -91,11 +72,49 @@ def test_gsm8k(self): if speculative_eagle_topk == 1: self.assertGreater(avg_spec_accept_length, 2.5) else: - self.assertGreater(avg_spec_accept_length, 3.49) + self.assertGreater(avg_spec_accept_length, 3.47) # Wait a little bit so that the memory check happens. 
time.sleep(4) + +class TestEAGLEServerAdditional(TestEAGLEServerBasic): + spec_topk = 5 + spec_steps = 8 + spec_tokens = 64 + extra_args = [ + "--max-running-requests", + 8, + "--cuda-graph-max-bs", + 5, + "--attention-backend", + "fa3", + "--page-size", + 256, + "--dtype", + "float16", + ] + + def test_radix_attention(self): + run_radix_attention_test(self.base_url) + self.assertIsNone(self.process.poll()) + + def test_max_token_one(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + base_url=self.base_url, + model=self.target_model, + eval_name="gsm8k", + api="completion", + max_tokens=1, + num_examples=200, + num_threads=128, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["output_throughput"], 50) + def test_logprob_start_len(self): logprob_start_len = 4 new_tokens = 4 @@ -337,21 +356,6 @@ class TestEAGLEServerPageSizeTopk(TestEAGLEServerBasic): ] -class TestEAGLEServerPageSizeTopkFA3(TestEAGLEServerBasic): - # default topk=8 and tokens=64 - spec_topk = 5 - spec_steps = 8 - spec_tokens = 64 - - extra_args = [ - "--page-size=256", - "--attention-backend=fa3", - "--cuda-graph-max-bs=5", - "--dtype=float16", - "--max-running-requests=8", - ] - - class TestEAGLEAbortAll(AbortAllMixin, EagleServerBase): abort_all_max_new_tokens = 4000 extra_args = ["--max-running-requests=8"] diff --git a/test/registered/spec/eagle/test_eagle_infer_beta.py b/test/registered/spec/eagle/test_eagle_infer_beta.py index 252062611f2f..faee2ae48505 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta.py @@ -7,9 +7,9 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval from sglang.test.kits.matched_stop_kit import MatchedStopMixin from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test +from sglang.test.run_eval import 
run_eval from sglang.test.test_utils import ( DEFAULT_DRAFT_MODEL_EAGLE, DEFAULT_TARGET_MODEL_EAGLE, @@ -86,19 +86,19 @@ def test_radix_attention(self): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1000, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=1000, + num_threads=128, ) metrics = run_eval(args) print(f"TestEagleLargeBS -- {metrics=}") self.assertGreater( - metrics["accuracy"], 0.23 - ) # 0.3333 for 60 questions; 0.234 for 1319 questions + metrics["score"], 0.22 + ) # ~0.227 for 1000 questions via /v1/completions assert self.process.poll() is None def test_logprob_spec_v2_match(self): diff --git a/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention.py b/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention.py index 8a6e9779fb51..407004dc01e6 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention.py @@ -6,7 +6,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, @@ -17,23 +17,23 @@ ) # EAGLE with DP attention on B200 (tp=2, dp=2, requires 4 B200 GPUs) -register_cuda_ci(est_time=300, suite="stage-c-test-4-gpu-b200") +register_cuda_ci(est_time=100, suite="stage-c-test-4-gpu-b200") -def test_gsm8k(base_url: str): +def test_gsm8k(base_url: str, model: str): requests.get(base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - 
host="http://127.0.0.1", - port=int(base_url.split(":")[-1]), + base_url=base_url, + model=model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) - server_info = requests.get(base_url + "/get_server_info") + metrics = run_eval(args) + server_info = requests.get(base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -84,8 +84,8 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_a_gsm8k(self): - metrics, avg_spec_accept_length = test_gsm8k(self.base_url) - self.assertGreater(metrics["accuracy"], 0.62) + metrics, avg_spec_accept_length = test_gsm8k(self.base_url, self.model) + self.assertGreater(metrics["score"], 0.62) self.assertGreater(avg_spec_accept_length, 2.7) diff --git a/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention_large.py b/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention_large.py index 8a7fcd00ddc8..c875e995c167 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention_large.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention_large.py @@ -6,7 +6,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -21,19 +21,19 @@ register_cuda_ci(est_time=600, suite="nightly-8-gpu-b200", nightly=True) -def test_gsm8k(base_url: str): +def test_gsm8k(base_url: str, model: str): requests.get(base_url + "/flush_cache") args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(base_url.split(":")[-1]), + 
base_url=base_url, + model=model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) server_info = requests.get(base_url + "/server_info").json() avg_spec_accept_length = server_info["internal_states"][0]["avg_spec_accept_length"] @@ -92,14 +92,14 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_a_gsm8k(self): - metrics, avg_spec_accept_length = test_gsm8k(self.base_url) + metrics, avg_spec_accept_length = test_gsm8k(self.base_url, self.model) - self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(metrics["score"], 0.94) self.assertGreater(avg_spec_accept_length, 2.7) if is_in_ci(): write_github_step_summary( f"### test_gsm8k (deepseek-v3-fp4 mtp)\n" - f'{metrics["accuracy"]=:.3f}\n' + f'{metrics["score"]=:.3f}\n' f"{avg_spec_accept_length=:.2f}\n" ) diff --git a/test/registered/spec/test_standalone_speculative_decoding.py b/test/registered/spec/test_standalone_speculative_decoding.py index 1a3cc0647e6a..f74bc31e4843 100644 --- a/test/registered/spec/test_standalone_speculative_decoding.py +++ b/test/registered/spec/test_standalone_speculative_decoding.py @@ -7,7 +7,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_DRAFT_MODEL_STANDALONE, DEFAULT_TARGET_MODEL_STANDALONE, @@ -22,7 +22,6 @@ GSM_DATASET_PATH = None - # Default server arguments shared across all tests DEFAULT_SERVER_ARGS = [ "--trust-remote-code", @@ -97,22 +96,24 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=100, + num_threads=128, 
num_shots=4, - num_questions=100, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - data_path=GSM_DATASET_PATH, + gsm8k_data_path=GSM_DATASET_PATH, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") # Use the appropriate metric key based on the test class - metric_key = "accuracy" + metric_key = "score" self.assertGreater(metrics[metric_key], self.accuracy_threshold) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] @@ -158,22 +159,24 @@ def test_gsm8k(self): requests.get(self.base_url + "/flush_cache") args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=100, + num_threads=128, num_shots=4, - num_questions=100, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - data_path=GSM_DATASET_PATH, + gsm8k_data_path=GSM_DATASET_PATH, ) - metrics = run_eval_few_shot_gsm8k(args) + metrics = run_eval(args) print(f"{metrics=}") # Use the appropriate metric key based on the test class - metric_key = "accuracy" + metric_key = "score" self.assertGreater(metrics[metric_key], self.accuracy_threshold) - server_info = requests.get(self.base_url + "/get_server_info") + server_info = requests.get(self.base_url + "/server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ "avg_spec_accept_length" ] diff --git a/test/registered/spec/utils/test_build_eagle_tree.py b/test/registered/spec/utils/test_build_eagle_tree.py index f13ede2bf006..103349fcca0d 100644 --- a/test/registered/spec/utils/test_build_eagle_tree.py +++ b/test/registered/spec/utils/test_build_eagle_tree.py @@ -9,7 +9,7 @@ from sglang.srt.utils import get_device from sglang.test.ci.ci_register 
import register_amd_ci, register_cuda_ci -register_cuda_ci(est_time=3, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=6, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=3, suite="stage-b-test-1-gpu-small-amd") diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py index 6f2427a40966..e8d9fc026beb 100644 --- a/test/registered/spec/utils/test_ngram_corpus.py +++ b/test/registered/spec/utils/test_ngram_corpus.py @@ -6,14 +6,12 @@ from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import CustomTestCase -register_cuda_ci(est_time=30, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=8, suite="stage-b-test-1-gpu-small") def _make_corpus(match_type="BFS", **kwargs): defaults = dict( max_trie_depth=12, - min_match_window_size=1, - max_match_window_size=10, min_bfs_breadth=1, max_bfs_breadth=8, draft_token_num=8, @@ -239,9 +237,7 @@ def test_small_capacity_does_not_crash(self): self.assertEqual(len(ids), 8, "Should still produce draft_token_num outputs") def test_eviction_preserves_recent(self): - corpus = _make_corpus( - "BFS", capacity=500, max_trie_depth=6, max_match_window_size=5 - ) + corpus = _make_corpus("BFS", capacity=500, max_trie_depth=6) old_seq = list(range(1000, 1050)) corpus.batch_put([old_seq]) @@ -357,7 +353,6 @@ def test_repeated_insert_promotes_token(self): draft_token_num=2, max_bfs_breadth=1, min_bfs_breadth=1, - max_match_window_size=3, max_trie_depth=5, ) corpus.batch_put([[1, 2, 3, 10, 11]]) @@ -386,7 +381,6 @@ def test_most_recent_insert_selected(self): draft_token_num=2, max_bfs_breadth=1, min_bfs_breadth=1, - max_match_window_size=3, max_trie_depth=5, ) corpus.batch_put([[1, 2, 3, 10, 11]]) @@ -422,7 +416,7 @@ class TestSingleTokenContext(CustomTestCase): """Verify behavior with minimum-length context.""" def test_single_token_query(self): - corpus = _make_corpus("BFS", min_match_window_size=1) + corpus = _make_corpus("BFS") 
corpus.batch_put([[5, 10, 20, 30]]) corpus.synchronize() @@ -436,7 +430,7 @@ class TestLongContext(CustomTestCase): """Verify behavior when query context exceeds max_trie_depth.""" def test_context_longer_than_max_trie_depth(self): - corpus = _make_corpus("BFS", max_trie_depth=6, max_match_window_size=5) + corpus = _make_corpus("BFS", max_trie_depth=6) seq = list(range(1, 20)) corpus.batch_put([seq]) corpus.synchronize() @@ -447,6 +441,23 @@ def test_context_longer_than_max_trie_depth(self): self.assertEqual(ids_list[0], 15, "First token should be last context token") self.assertIn(16, ids_list, "Should match via suffix despite long context") + def test_matches_longest_stored_suffix(self): + corpus = _make_corpus("BFS", max_trie_depth=6, draft_token_num=4) + corpus.batch_put([[1, 2, 3, 4, 5, 6, 7]]) + corpus.batch_put([[99, 3, 4, 5, 6, 8]]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[2, 3, 4, 5, 6]]) + ids_list = ids.tolist() + self.assertIn( + 7, ids_list, "Longest stored suffix should contribute a continuation" + ) + self.assertIn( + 8, + ids_list, + "Shorter matching suffixes should still contribute continuations", + ) + class TestDraftBudgetSaturation(CustomTestCase): """Verify the draft tree uses exactly draft_token_num slots.""" @@ -469,41 +480,39 @@ def test_full_budget_used(self): class TestTruncate(CustomTestCase): - """Verify the Result.truncate method via the Python binding.""" + """Verify truncation logic on batch_get output.""" def test_truncate_reduces_output(self): corpus = _make_corpus("BFS", draft_token_num=8) corpus.batch_put(SEED_SEQUENCES) corpus.synchronize() - result = corpus._ngram.batchMatch([[1, 2, 3]]) - original_len = len(result.token) - self.assertEqual(original_len, 8) + ids, masks = corpus.batch_get([[1, 2, 3]]) + ids = ids.reshape(8) + self.assertEqual(len(ids), 8) - result.truncate(4) - self.assertEqual(len(result.token), 4) - self.assertEqual(len(result.mask), 4 * 4) + # Simulate truncate to 4 + trunc_n = 4 + trunc_ids = 
ids[:trunc_n] + self.assertEqual(len(trunc_ids), trunc_n) def test_truncate_preserves_mask_structure(self): corpus = _make_corpus("BFS", draft_token_num=8) corpus.batch_put(SEED_SEQUENCES) corpus.synchronize() - result = corpus._ngram.batchMatch([[1, 2, 3]]) - full_ids = list(result.token) - full_mask = list(result.mask) - n = len(full_ids) + ids, masks = corpus.batch_get([[1, 2, 3]]) + n = 8 + full_mask = masks.reshape(n, n) - result_copy = corpus._ngram.batchMatch([[1, 2, 3]]) trunc_n = 4 - result_copy.truncate(trunc_n) - trunc_mask = list(result_copy.mask) + trunc_mask = full_mask[:trunc_n, :trunc_n] for i in range(trunc_n): for j in range(trunc_n): self.assertEqual( - trunc_mask[i * trunc_n + j], - full_mask[i * n + j], + trunc_mask[i, j], + full_mask[i, j], f"Mask mismatch at ({i},{j})", ) @@ -538,9 +547,7 @@ class TestSqueezeEvictsOld(CustomTestCase): """Verify that squeeze actually evicts old data, not just preserves recent.""" def test_old_data_evicted(self): - corpus = _make_corpus( - "BFS", capacity=150, max_trie_depth=6, max_match_window_size=5 - ) + corpus = _make_corpus("BFS", capacity=150, max_trie_depth=6) old_seq = list(range(5000, 5030)) corpus.batch_put([old_seq]) diff --git a/test/registered/unit/README.md b/test/registered/unit/README.md index 7d5b36fd04d0..f5023b43194a 100644 --- a/test/registered/unit/README.md +++ b/test/registered/unit/README.md @@ -32,7 +32,9 @@ Tests can use CPU or GPU — the key criterion is **no server process**. diff-cover coverage.xml --compare-branch=origin/main --fail-under=60 ``` -## Example +## Examples + +### Basic unit test ```python """Unit tests for — no server, no model loading.""" @@ -57,6 +59,30 @@ if __name__ == "__main__": unittest.main() ``` +### Stubbing GPU-only imports for CPU tests + +Some modules (e.g. `scheduler.py`, `io_struct.py`) transitively import packages like +`sgl_kernel` that require a GPU to initialize. 
To run pure-mock tests against these +modules on CPU-only CI, stub the problematic package **before** importing it. + +`maybe_stub_sgl_kernel()` in `test_utils.py` does this for `sgl_kernel`: it's a no-op +on GPU machines, and on CPU it installs a `sys.meta_path` finder that auto-creates empty +stub modules for all `sgl_kernel.*` submodules. + +```python +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import maybe_stub_sgl_kernel + +maybe_stub_sgl_kernel() # must precede any import that pulls in sgl_kernel + +from sglang.srt.managers.io_struct import FlushCacheReqInput +from sglang.srt.managers.scheduler import Scheduler + +register_cpu_ci(est_time=2, suite="stage-a-test-cpu") +``` + +The same pattern can be applied to other GPU-only packages: try importing the real package, and if it fails, register a `sys.meta_path` finder that stubs it. See `maybe_stub_sgl_kernel()` in `python/sglang/test/test_utils.py` for the implementation. + ## Rules - **No** `popen_launch_server()` or `Engine(...)`. 
diff --git a/test/registered/unit/function_call/test_function_call_parser.py b/test/registered/unit/function_call/test_function_call_parser.py index c7b3fb173156..c418b0866d0e 100644 --- a/test/registered/unit/function_call/test_function_call_parser.py +++ b/test/registered/unit/function_call/test_function_call_parser.py @@ -18,7 +18,7 @@ from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(1.0, "stage-a-test-cpu") +register_cpu_ci(15, "stage-a-test-cpu") class TestPythonicDetector(unittest.TestCase): @@ -3853,5 +3853,160 @@ def test_streaming_function_call_marker_json_split_at_quotes(self): self.assertEqual(params["city"], "Rome") +class TestQwen25Detector(unittest.TestCase): + """Test Qwen25Detector streaming and non-streaming multi-tool-call parsing.""" + + def setUp(self): + from sglang.srt.function_call.qwen25_detector import Qwen25Detector + + self.detector = Qwen25Detector() + self.tools = [ + Tool( + type="function", + function=Function( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + }, + "state": { + "type": "string", + "description": "Two-letter state abbreviation", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "state", "unit"], + }, + ), + ), + ] + + # -- Non-streaming tests -- + + def test_detect_and_parse_single_tool_call(self): + text = '\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}\n' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_weather") + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "NYC") + + def 
test_detect_and_parse_multiple_tool_calls(self): + text = ( + '\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}\n\n' + '\n{"name": "get_current_weather", "arguments": {"city": "Baltimore", "state": "MD", "unit": "fahrenheit"}}\n\n' + '\n{"name": "get_current_weather", "arguments": {"city": "Minneapolis", "state": "MN", "unit": "fahrenheit"}}\n\n' + '\n{"name": "get_current_weather", "arguments": {"city": "Los Angeles", "state": "CA", "unit": "fahrenheit"}}\n' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 4) + cities = [json.loads(c.parameters)["city"] for c in result.calls] + self.assertEqual(cities, ["NYC", "Baltimore", "Minneapolis", "Los Angeles"]) + + def test_detect_and_parse_with_normal_text_prefix(self): + text = ( + "Sure, let me check the weather.\n" + '\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "celsius"}}\n' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertIn("let me check", result.normal_text) + + # -- Streaming tests -- + + def _collect_streaming_tool_calls(self, chunks): + """Helper: feed chunks through streaming parser and collect tool calls by index.""" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + return tool_calls_by_index + + def test_streaming_single_tool_call(self): + chunks = [ + "\n", + '{"name": "get_current_weather",', + ' "arguments": {"city": "NYC",', + ' "state": "NY",', + ' "unit": 
"fahrenheit"}}', + "\n", + ] + result = self._collect_streaming_tool_calls(chunks) + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["name"], "get_current_weather") + params = json.loads(result[0]["parameters"]) + self.assertEqual(params["city"], "NYC") + + def test_streaming_multiple_tool_calls(self): + """Core regression test: multiple tool calls must all be parsed in streaming mode.""" + chunks = [ + "\n", + '{"name": "get_current_weather",', + ' "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}', + "\n\n", + "\n", + '{"name": "get_current_weather",', + ' "arguments": {"city": "Baltimore", "state": "MD", "unit": "fahrenheit"}}', + "\n\n", + "\n", + '{"name": "get_current_weather",', + ' "arguments": {"city": "LA", "state": "CA", "unit": "fahrenheit"}}', + "\n", + ] + result = self._collect_streaming_tool_calls(chunks) + self.assertEqual(len(result), 3, f"Expected 3 tool calls, got {len(result)}") + cities = [json.loads(result[i]["parameters"])["city"] for i in sorted(result)] + self.assertEqual(cities, ["NYC", "Baltimore", "LA"]) + + def test_streaming_multiple_tool_calls_fused_chunks(self): + """Test when separator and next bot_token arrive in a single chunk.""" + chunks = [ + '\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}', + '\n\n\n{"name": "get_current_weather",', + ' "arguments": {"city": "LA", "state": "CA", "unit": "fahrenheit"}}', + "\n", + ] + result = self._collect_streaming_tool_calls(chunks) + self.assertEqual(len(result), 2, f"Expected 2 tool calls, got {len(result)}") + cities = [json.loads(result[i]["parameters"])["city"] for i in sorted(result)] + self.assertEqual(cities, ["NYC", "LA"]) + + def test_streaming_multiple_tool_calls_char_by_char_separator(self): + """Test when the separator between tool calls arrives character by character.""" + call1 = '{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}' + call2 = 
'{"name": "get_current_weather", "arguments": {"city": "LA", "state": "CA", "unit": "celsius"}}' + separator = "\n\n\n" + + chunks = ["\n", call1] + for ch in separator: + chunks.append(ch) + chunks.append(call2) + chunks.append("\n") + + result = self._collect_streaming_tool_calls(chunks) + self.assertEqual(len(result), 2, f"Expected 2 tool calls, got {len(result)}") + cities = [json.loads(result[i]["parameters"])["city"] for i in sorted(result)] + self.assertEqual(cities, ["NYC", "LA"]) + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/unit/function_call/test_glm47_moe_detector.py b/test/registered/unit/function_call/test_glm47_moe_detector.py index 357514e19e85..e0c1921192b0 100644 --- a/test/registered/unit/function_call/test_glm47_moe_detector.py +++ b/test/registered/unit/function_call/test_glm47_moe_detector.py @@ -10,7 +10,7 @@ ) from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(1.0, "stage-a-test-cpu") +register_cpu_ci(5, "stage-a-test-cpu") class TestGlm47MoeDetector(unittest.TestCase): diff --git a/test/registered/unit/function_call/test_json_schema_constraint.py b/test/registered/unit/function_call/test_json_schema_constraint.py index bc6a9fa13913..f66943718298 100644 --- a/test/registered/unit/function_call/test_json_schema_constraint.py +++ b/test/registered/unit/function_call/test_json_schema_constraint.py @@ -18,7 +18,7 @@ ) from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(1.0, "stage-a-test-cpu") +register_cpu_ci(5, "stage-a-test-cpu") class TestJsonSchemaConstraint(unittest.TestCase): @@ -102,7 +102,7 @@ def test_specific_tool_choice_schema(self): self.assertEqual(schema["type"], "array") self.assertEqual(schema["minItems"], 1) - self.assertEqual(schema["maxItems"], 1) + self.assertNotIn("maxItems", schema) # Should only have schema for the specific tool item_schema = schema["items"] @@ -121,13 +121,72 @@ def test_specific_tool_choice_dict_schema(self): 
self.assertEqual(schema["type"], "array") self.assertEqual(schema["minItems"], 1) - self.assertEqual(schema["maxItems"], 1) + self.assertNotIn("maxItems", schema) # Should only have schema for the specific tool item_schema = schema["items"] self.assertEqual(item_schema["properties"]["name"]["enum"], ["search"]) self.assertIn("parameters", item_schema["properties"]) + def test_specific_tool_choice_allows_multiple_calls(self): + """Test that specific tool choice schema allows multiple calls. + + Regression test for https://github.com/sgl-project/sglang/issues/17998: + maxItems: 1 caused the model to stall on whitespace when the prompt + implied multiple calls to the same function. + """ + tool_choice = ToolChoice( + type="function", function=ToolChoiceFuncName(name="get_weather") + ) + schema = get_json_schema_constraint(self.tools, tool_choice) + + single_call = [ + {"name": "get_weather", "parameters": {"location": "NYC"}}, + ] + multi_call = [ + {"name": "get_weather", "parameters": {"location": "NYC"}}, + {"name": "get_weather", "parameters": {"location": "LA"}}, + {"name": "get_weather", "parameters": {"location": "Chicago"}}, + ] + + validator = jsonschema.Draft202012Validator(schema) + validator.validate(single_call) + validator.validate(multi_call) + + def test_specific_tool_choice_no_parallel(self): + """Test that parallel_tool_calls=False sets maxItems=1""" + tool_choice = ToolChoice( + type="function", function=ToolChoiceFuncName(name="get_weather") + ) + schema = get_json_schema_constraint( + self.tools, tool_choice, parallel_tool_calls=False + ) + + self.assertIsNotNone(schema) + self.assertEqual(schema["maxItems"], 1) + + single_call = [ + {"name": "get_weather", "parameters": {"location": "NYC"}}, + ] + multi_call = [ + {"name": "get_weather", "parameters": {"location": "NYC"}}, + {"name": "get_weather", "parameters": {"location": "LA"}}, + ] + + validator = jsonschema.Draft202012Validator(schema) + validator.validate(single_call) + with 
self.assertRaises(jsonschema.ValidationError): + validator.validate(multi_call) + + def test_required_tool_choice_no_parallel(self): + """Test that required + parallel_tool_calls=False sets maxItems=1""" + schema = get_json_schema_constraint( + self.tools, "required", parallel_tool_calls=False + ) + + self.assertIsNotNone(schema) + self.assertEqual(schema["maxItems"], 1) + def test_nonexistent_tool_choice(self): """Test schema generation for nonexistent tool""" tool_choice = ToolChoice( diff --git a/test/registered/unit/function_call/test_parallel_tool_calls.py b/test/registered/unit/function_call/test_parallel_tool_calls.py index bf1e18a7baa8..2f5af4d9ae00 100644 --- a/test/registered/unit/function_call/test_parallel_tool_calls.py +++ b/test/registered/unit/function_call/test_parallel_tool_calls.py @@ -23,7 +23,7 @@ from sglang.srt.function_call.json_array_parser import JsonArrayParser from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(1.0, "stage-a-test-cpu") +register_cpu_ci(5, "stage-a-test-cpu") class TestParallelToolCalls(unittest.TestCase): diff --git a/test/registered/unit/function_call/test_unknown_tool_name.py b/test/registered/unit/function_call/test_unknown_tool_name.py index e7a8394ff2cb..ce96c698a0cd 100644 --- a/test/registered/unit/function_call/test_unknown_tool_name.py +++ b/test/registered/unit/function_call/test_unknown_tool_name.py @@ -9,7 +9,7 @@ from sglang.srt.function_call.core_types import StreamingParseResult from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(1.0, "stage-a-test-cpu") +register_cpu_ci(5, "stage-a-test-cpu") class DummyDetector(BaseFormatDetector): diff --git a/test/registered/unit/managers/test_prefill_adder.py b/test/registered/unit/managers/test_prefill_adder.py index e153d9d9c635..2a111171263e 100644 --- a/test/registered/unit/managers/test_prefill_adder.py +++ b/test/registered/unit/managers/test_prefill_adder.py @@ -12,7 +12,7 @@ from sglang.test.ci.ci_register import 
register_amd_ci, register_cuda_ci from sglang.test.test_utils import CustomTestCase -register_cuda_ci(est_time=1, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=8, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=2, suite="stage-b-test-1-gpu-small-amd") diff --git a/test/registered/unit/managers/test_scheduler_flush_cache.py b/test/registered/unit/managers/test_scheduler_flush_cache.py index 76c610ea4fba..828854ed487e 100644 --- a/test/registered/unit/managers/test_scheduler_flush_cache.py +++ b/test/registered/unit/managers/test_scheduler_flush_cache.py @@ -1,11 +1,15 @@ import unittest from unittest.mock import MagicMock, patch +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import maybe_stub_sgl_kernel + +maybe_stub_sgl_kernel() + from sglang.srt.managers.io_struct import FlushCacheReqInput from sglang.srt.managers.scheduler import Scheduler -from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(est_time=2, suite="stage-a-cpu-only") +register_cpu_ci(est_time=2, suite="stage-a-test-cpu") class TestSchedulerFlushCache(unittest.TestCase): diff --git a/test/registered/unit/managers/test_scheduler_pause_generation.py b/test/registered/unit/managers/test_scheduler_pause_generation.py new file mode 100644 index 000000000000..210ba0aa6627 --- /dev/null +++ b/test/registered/unit/managers/test_scheduler_pause_generation.py @@ -0,0 +1,134 @@ +import unittest +from collections import deque +from unittest.mock import MagicMock + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import maybe_stub_sgl_kernel + +maybe_stub_sgl_kernel() + +from sglang.srt.managers.io_struct import PauseGenerationReqInput +from sglang.srt.managers.scheduler import Scheduler + +register_cpu_ci(est_time=2, suite="stage-a-test-cpu") + + +class TestSchedulerPauseGeneration(unittest.TestCase): + def _new_scheduler(self) -> Scheduler: + scheduler = Scheduler.__new__(Scheduler) + 
scheduler._engine_paused = False + scheduler.enable_overlap = False + scheduler.last_batch = None + scheduler.cur_batch = None + scheduler.chunked_req = None + scheduler.running_batch = MagicMock() + scheduler.running_batch.reqs = [] + scheduler.running_batch.is_empty.return_value = True + scheduler.running_batch.batch_is_full = False + scheduler.tree_cache = MagicMock() + scheduler.tree_cache.protected_size.return_value = 0 + scheduler.req_to_token_pool = MagicMock() + scheduler.result_queue = deque() + # Support _kv_snap diagnostic logging in patched schedulers + scheduler.token_to_kv_pool_allocator = MagicMock() + scheduler.token_to_kv_pool_allocator.available_size.return_value = 1000 + scheduler.max_total_num_tokens = 1000 + scheduler._get_token_info = MagicMock(return_value=(0, 0, 1000, 0)) + return scheduler + + def test_inplace_only_sets_flag(self): + """in_place pause should only set _engine_paused and return.""" + scheduler = self._new_scheduler() + scheduler.last_batch = MagicMock() + scheduler.cur_batch = MagicMock() + scheduler.chunked_req = MagicMock() + + original_last_batch = scheduler.last_batch + original_cur_batch = scheduler.cur_batch + original_chunked_req = scheduler.chunked_req + + scheduler.pause_generation(PauseGenerationReqInput(mode="in_place")) + + self.assertTrue(scheduler._engine_paused) + # All state must be preserved — no mutation + self.assertIs(scheduler.last_batch, original_last_batch) + self.assertIs(scheduler.cur_batch, original_cur_batch) + self.assertIs(scheduler.chunked_req, original_chunked_req) + + def test_inplace_does_not_drain_overlap_queue(self): + """in_place should not process the overlap result_queue.""" + scheduler = self._new_scheduler() + scheduler.enable_overlap = True + scheduler.last_batch = MagicMock() + scheduler.result_queue = deque([(MagicMock(), MagicMock())]) + + scheduler.pause_generation(PauseGenerationReqInput(mode="in_place")) + + self.assertTrue(scheduler._engine_paused) + 
self.assertEqual(len(scheduler.result_queue), 1) + + def test_inplace_does_not_merge_batch(self): + """in_place should not filter or merge last_batch into running_batch.""" + scheduler = self._new_scheduler() + last_batch = MagicMock() + last_batch.forward_mode.is_extend.return_value = True + scheduler.last_batch = last_batch + + scheduler.pause_generation(PauseGenerationReqInput(mode="in_place")) + + last_batch.filter_batch.assert_not_called() + scheduler.running_batch.merge_batch.assert_not_called() + + def test_abort_clears_state(self): + """abort mode should clear last_batch and cur_batch.""" + scheduler = self._new_scheduler() + scheduler.last_batch = MagicMock() + scheduler.last_batch.forward_mode.is_extend.return_value = False + scheduler.cur_batch = MagicMock() + + scheduler.pause_generation(PauseGenerationReqInput(mode="abort")) + + self.assertTrue(scheduler._engine_paused) + self.assertIsNone(scheduler.last_batch) + self.assertIsNone(scheduler.cur_batch) + + def test_retract_clears_running_batch(self): + """retract mode should retract all requests from running_batch.""" + scheduler = self._new_scheduler() + scheduler.last_batch = None + scheduler.running_batch.reqs = [MagicMock(), MagicMock()] + scheduler.running_batch.__len__ = lambda self: len(self.reqs) + scheduler.running_batch.is_empty.return_value = False + scheduler.waiting_queue = [] + scheduler._add_request_to_queue = MagicMock() + + retracted = [MagicMock(), MagicMock()] + scheduler.running_batch.retract_all.return_value = retracted + scheduler.running_batch.filter_batch = MagicMock() + scheduler.server_args = MagicMock() + + scheduler.pause_generation(PauseGenerationReqInput(mode="retract")) + + self.assertTrue(scheduler._engine_paused) + scheduler.running_batch.retract_all.assert_called_once() + self.assertEqual(scheduler._add_request_to_queue.call_count, 2) + self.assertIsNone(scheduler.chunked_req) + + def test_abort_drains_overlap_queue(self): + """abort with overlap enabled should drain 
the result_queue.""" + scheduler = self._new_scheduler() + scheduler.enable_overlap = True + mock_batch = MagicMock() + mock_batch.forward_mode.is_extend.return_value = False + scheduler.last_batch = mock_batch + scheduler.result_queue = deque([(MagicMock(), MagicMock())]) + scheduler.process_batch_result = MagicMock() + + scheduler.pause_generation(PauseGenerationReqInput(mode="abort")) + + scheduler.process_batch_result.assert_called_once() + self.assertEqual(len(scheduler.result_queue), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/unit/mem_cache/test_evict_policy.py b/test/registered/unit/mem_cache/test_evict_policy.py index 0dd2eed5d9b5..c59af743f9e7 100644 --- a/test/registered/unit/mem_cache/test_evict_policy.py +++ b/test/registered/unit/mem_cache/test_evict_policy.py @@ -2,7 +2,7 @@ from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(est_time=5, suite="stage-a-test-cpu") +register_cpu_ci(est_time=2, suite="stage-a-test-cpu") import unittest from unittest.mock import MagicMock diff --git a/test/registered/unit/mem_cache/test_mamba_unittest.py b/test/registered/unit/mem_cache/test_mamba_unittest.py index 6955724991ec..c6219a46dee5 100644 --- a/test/registered/unit/mem_cache/test_mamba_unittest.py +++ b/test/registered/unit/mem_cache/test_mamba_unittest.py @@ -99,6 +99,7 @@ def test_mamba_pool(self): device=device, enable_memory_saver=False, cache_params=mamba2_cache_params, + mamba_layer_ids=mamba_layers, enable_mamba_extra_buffer=False, speculative_num_draft_tokens=3, ) @@ -340,6 +341,7 @@ def _setup_tree_and_allocator(self): device=device, enable_memory_saver=False, cache_params=mamba2_cache_params, + mamba_layer_ids=mamba_layers, enable_mamba_extra_buffer=False, speculative_num_draft_tokens=3, ) diff --git a/test/registered/unit/mem_cache/test_nsa_pool_host_unit.py b/test/registered/unit/mem_cache/test_nsa_pool_host_unit.py index f75819557e1f..64ca27493081 100644 --- 
a/test/registered/unit/mem_cache/test_nsa_pool_host_unit.py +++ b/test/registered/unit/mem_cache/test_nsa_pool_host_unit.py @@ -11,7 +11,7 @@ from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu from sglang.test.ci.ci_register import register_cuda_ci -register_cuda_ci(est_time=3, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=10, suite="stage-b-test-1-gpu-small") class TestNSAHiCacheTransfer(unittest.TestCase): diff --git a/test/registered/unit/mem_cache/test_radix_cache_unit.py b/test/registered/unit/mem_cache/test_radix_cache_unit.py index 33dc31f41c64..f7b8a43cea6a 100644 --- a/test/registered/unit/mem_cache/test_radix_cache_unit.py +++ b/test/registered/unit/mem_cache/test_radix_cache_unit.py @@ -21,7 +21,7 @@ from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci # CPU-based unit test, runs quickly on any GPU runner -register_cuda_ci(est_time=5, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=20, suite="stage-b-test-1-gpu-small") register_amd_ci(est_time=5, suite="stage-b-test-1-gpu-small-amd") import random diff --git a/test/registered/unit/model_loader/test_modelopt_loader.py b/test/registered/unit/model_loader/test_modelopt_loader.py index 7f9652c0e5db..9ad6183a0b0a 100644 --- a/test/registered/unit/model_loader/test_modelopt_loader.py +++ b/test/registered/unit/model_loader/test_modelopt_loader.py @@ -646,7 +646,11 @@ def test_mixed_precision_override_does_not_hijack_w4afp8(self): ) def test_mixed_precision_uses_nvfp4_min_capability(self): - self.assertEqual(ModelOptMixedPrecisionConfig.get_min_capability(), 100) + """NVFP4 supports SM75+ (Turing) via Marlin fallback; min_capability must be >= 75.""" + cap = ModelOptMixedPrecisionConfig.get_min_capability() + self.assertGreaterEqual( + cap, 75, f"NVFP4 requires SM75+ (Marlin fallback); got min_capability={cap}" + ) def test_mixed_precision_quant_layer_resolution_after_mapping(self): quant_config = ModelOptMixedPrecisionConfig.from_config( diff --git 
a/test/registered/unit/models/test_llava.py b/test/registered/unit/models/test_llava.py new file mode 100644 index 000000000000..a929dcfbd811 --- /dev/null +++ b/test/registered/unit/models/test_llava.py @@ -0,0 +1,91 @@ +import unittest +from unittest.mock import patch + +from sglang.srt.models.llava import AutoModel, LlavaForConditionalGeneration +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=1, suite="stage-b-test-1-gpu-small") + + +class PixtralVisionConfig: + pass + + +class VoxtralRealtimeTextConfig: + pass + + +class GoodConfig: + pass + + +class PixtralVisionModel: + pass + + +class GoodArch: + pass + + +class FakeMapping: + def __init__(self, voxtral_error): + self.voxtral_error = voxtral_error + + def keys(self): + return [VoxtralRealtimeTextConfig, PixtralVisionConfig, GoodConfig] + + def get(self, config_cls, default=None): + if config_cls is VoxtralRealtimeTextConfig: + raise self.voxtral_error + if config_cls is PixtralVisionConfig: + return (PixtralVisionModel,) + if config_cls is GoodConfig: + return GoodArch + return default + + +KNOWN_VOXTRAL_ERROR = ValueError( + "Could not find VoxtralRealtimeTextModel neither in " + " nor in " + "!" 
+) + + +class TestLlavaForConditionalGeneration(CustomTestCase): + def setUp(self): + LlavaForConditionalGeneration._config_cls_name_to_arch_name_mapping.cache_clear() + + def _build_mapping(self, mapping): + with patch.object(AutoModel, "_model_mapping", mapping): + llava_model = object.__new__(LlavaForConditionalGeneration) + return llava_model._config_cls_name_to_arch_name_mapping(AutoModel) + + @patch("sglang.srt.models.llava.logger.warning") + def test_skip_known_broken_voxtral_automodel_mapping_entry(self, mock_warning): + mapping = self._build_mapping(FakeMapping(KNOWN_VOXTRAL_ERROR)) + + self.assertEqual(mapping[GoodConfig.__name__], GoodArch.__name__) + self.assertEqual( + mapping[PixtralVisionConfig.__name__], (PixtralVisionModel.__name__,) + ) + self.assertNotIn(VoxtralRealtimeTextConfig.__name__, mapping) + + mock_warning.assert_called_once() + self.assertEqual( + mock_warning.call_args.args, + ( + "Skipping broken %s mapping for config %s: %s", + AutoModel.__name__, + VoxtralRealtimeTextConfig.__name__, + unittest.mock.ANY, + ), + ) + + def test_other_voxtral_mapping_failures_still_raise(self): + with self.assertRaisesRegex(ValueError, "some other failure"): + self._build_mapping(FakeMapping(ValueError("some other failure"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/unit/observability/test_request_metrics_exporter.py b/test/registered/unit/observability/test_request_metrics_exporter.py index f808f3f7c682..22a4496fa394 100644 --- a/test/registered/unit/observability/test_request_metrics_exporter.py +++ b/test/registered/unit/observability/test_request_metrics_exporter.py @@ -58,6 +58,7 @@ def __init__(self, **kwargs): import unittest from unittest.mock import MagicMock, patch +from sglang.srt.constants import HEALTH_CHECK_RID_PREFIX from sglang.srt.observability.request_metrics_exporter import ( FileRequestMetricsExporter, RequestMetricsExporter, @@ -243,7 +244,7 @@ def test_write_record(self): def 
test_write_record_skips_health_check(self): exporter = self._make_exporter() - obj = _GenerateReqInput(rid="HEALTH_CHECK_123", text="ping") + obj = _GenerateReqInput(rid=f"{HEALTH_CHECK_RID_PREFIX}_123", text="ping") asyncio.run(exporter.write_record(obj, {})) files = os.listdir(self.tmp_dir) diff --git a/test/registered/unit/sampling/test_sampling_batch_info.py b/test/registered/unit/sampling/test_sampling_batch_info.py index bc923932cd87..7b018915381f 100644 --- a/test/registered/unit/sampling/test_sampling_batch_info.py +++ b/test/registered/unit/sampling/test_sampling_batch_info.py @@ -142,10 +142,10 @@ def test_lhs_none_rhs_present(self): # apply_logits_bias class TestApplyLogitsBias(CustomTestCase): - def test_applies_linear_penalties(self): - """Test that pre-accumulated linear penalties are added to logits.""" + def test_applies_additive_penalties(self): + """Test that pre-accumulated additive penalties are added to logits.""" info = _make_info(batch_size=1) - info.acc_linear_penalties = torch.tensor([[-1.0] * VOCAB_SIZE]) + info.acc_additive_penalties = torch.tensor([[-1.0] * VOCAB_SIZE]) logits = torch.zeros(1, VOCAB_SIZE) info.apply_logits_bias(logits) self.assertAlmostEqual(logits[0, 0].item(), -1.0, places=5) @@ -181,7 +181,7 @@ def test_applies_penalizer_orchestrator(self): def test_no_bias_no_change(self): """Test that logits stay unchanged when no bias sources are set.""" info = _make_info(batch_size=1) - info.acc_linear_penalties = None + info.acc_additive_penalties = None info.logit_bias = None info.vocab_mask = None logits = torch.zeros(1, VOCAB_SIZE) @@ -194,20 +194,24 @@ def test_no_bias_no_change(self): class TestUpdatePenalties(CustomTestCase): def test_required_creates_penalties_tensor(self): - """Test that update_penalties allocates a zero tensor and calls orchestrator.apply.""" + """Test that update_penalties allocates a zero tensor and calls orchestrator methods.""" orch = MagicMock(is_required=True) + 
orch.accumulate_scaling_penalties.return_value = None info = _make_info(batch_size=2, penalizer_orchestrator=orch) info.update_penalties() - self.assertIsNotNone(info.acc_linear_penalties) - self.assertEqual(info.acc_linear_penalties.shape, (2, VOCAB_SIZE)) - orch.apply.assert_called_once() + self.assertIsNotNone(info.acc_additive_penalties) + self.assertEqual(info.acc_additive_penalties.shape, (2, VOCAB_SIZE)) + orch.accumulate_additive_penalties.assert_called_once_with( + info.acc_additive_penalties + ) + orch.accumulate_scaling_penalties.assert_called_once() def test_not_required_sets_none(self): - """Test that update_penalties sets acc_linear_penalties to None when not required.""" + """Test that update_penalties sets acc_additive_penalties to None when not required.""" orch = MagicMock(is_required=False) info = _make_info(batch_size=2, penalizer_orchestrator=orch) info.update_penalties() - self.assertIsNone(info.acc_linear_penalties) + self.assertIsNone(info.acc_additive_penalties) # update_regex_vocab_mask diff --git a/test/registered/unit/server_args/test_server_args.py b/test/registered/unit/server_args/test_server_args.py index 1276493d3a95..ff381bd21883 100644 --- a/test/registered/unit/server_args/test_server_args.py +++ b/test/registered/unit/server_args/test_server_args.py @@ -10,7 +10,7 @@ CustomTestCase, ) -register_cpu_ci(est_time=1, suite="stage-a-test-cpu") +register_cpu_ci(est_time=10, suite="stage-a-test-cpu") # Mock get_device() so all tests run on CPU-only CI runners _mock_device = patch("sglang.srt.server_args.get_device", return_value="cuda") diff --git a/test/registered/unit/test_runai_utils.py b/test/registered/unit/test_runai_utils.py new file mode 100644 index 000000000000..4735b5d7f7ec --- /dev/null +++ b/test/registered/unit/test_runai_utils.py @@ -0,0 +1,57 @@ +import unittest +from pathlib import Path + +from sglang.srt.configs.load_config import LoadFormat +from sglang.srt.utils.runai_utils import ObjectStorageModel, 
is_runai_obj_uri +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=5, suite="stage-a-test-cpu") + + +class TestRunaiUtils(CustomTestCase): + def test_is_runai_obj_uri_s3(self): + self.assertTrue(is_runai_obj_uri("s3://bucket/model/")) + self.assertTrue(is_runai_obj_uri("S3://Bucket/Model/")) + + def test_is_runai_obj_uri_gs(self): + self.assertTrue(is_runai_obj_uri("gs://bucket/model/")) + self.assertTrue(is_runai_obj_uri("GS://Bucket/Model/")) + + def test_is_runai_obj_uri_az(self): + self.assertTrue(is_runai_obj_uri("az://container/model/")) + self.assertTrue(is_runai_obj_uri("AZ://Container/Model/")) + + def test_is_runai_obj_uri_local_paths(self): + self.assertFalse(is_runai_obj_uri("/path/to/model")) + self.assertFalse(is_runai_obj_uri("./relative/path")) + self.assertFalse(is_runai_obj_uri("meta-llama/Llama-3.2-1B")) + + def test_is_runai_obj_uri_other_schemes(self): + self.assertFalse(is_runai_obj_uri("http://example.com/model")) + self.assertFalse(is_runai_obj_uri("https://example.com/model")) + self.assertFalse(is_runai_obj_uri("ftp://example.com/model")) + + def test_is_runai_obj_uri_pathlib(self): + self.assertFalse(is_runai_obj_uri(Path("/local/model"))) + + def test_get_path_deterministic(self): + path1 = ObjectStorageModel.get_path("s3://bucket/model/") + path2 = ObjectStorageModel.get_path("s3://bucket/model/") + self.assertEqual(path1, path2) + + def test_get_path_different_uris(self): + path1 = ObjectStorageModel.get_path("s3://bucket/model-a/") + path2 = ObjectStorageModel.get_path("s3://bucket/model-b/") + self.assertNotEqual(path1, path2) + + def test_get_path_contains_model_streamer(self): + path = ObjectStorageModel.get_path("s3://bucket/model/") + self.assertIn("model_streamer", path) + + def test_load_format_enum(self): + self.assertEqual(LoadFormat.RUNAI_STREAMER.value, "runai_streamer") + + +if __name__ == "__main__": + unittest.main() diff --git 
a/test/registered/unit/utils/test_json_response.py b/test/registered/unit/utils/test_json_response.py index f201ae735485..accc71fc5ba7 100644 --- a/test/registered/unit/utils/test_json_response.py +++ b/test/registered/unit/utils/test_json_response.py @@ -10,7 +10,7 @@ ) from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(est_time=2, suite="stage-a-test-cpu") +register_cpu_ci(est_time=5, suite="stage-a-test-cpu") class TestJSONResponseUtils(unittest.TestCase): diff --git a/test/registered/unit/utils/test_subprocess_watchdog.py b/test/registered/unit/utils/test_subprocess_watchdog.py new file mode 100644 index 000000000000..075bec79d26f --- /dev/null +++ b/test/registered/unit/utils/test_subprocess_watchdog.py @@ -0,0 +1,137 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for SubprocessWatchdog in watchdog.py""" + +import multiprocessing as mp +import os +import signal +import threading +import time +import unittest.mock + +from sglang.srt.utils.watchdog import SubprocessWatchdog +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-a-cpu-only") + + +def healthy_worker(): + time.sleep(10) + + +def crashing_worker(): + os._exit(1) + + +def slow_crash_worker(delay: float = 0.5): + time.sleep(delay) + os._exit(42) + + +class TestSubprocessWatchdog(CustomTestCase): + def setUp(self): + self.sigquit_triggered = threading.Event() + self._procs = [] + self._monitor = None + + original_kill = os.kill + + def mock_kill(pid, sig): + if sig == signal.SIGQUIT: + self.sigquit_triggered.set() + else: + original_kill(pid, sig) + + self._patcher = unittest.mock.patch("os.kill", side_effect=mock_kill) + self._patcher.start() + + def tearDown(self): + if self._monitor is not None: + self._monitor.stop() + self._patcher.stop() + for p in self._procs: + if p.is_alive(): + p.terminate() + p.join(timeout=1) + + def _spawn(self, target, args=()): + proc = mp.Process(target=target, args=args) + proc.start() + self._procs.append(proc) + return proc + + def _watch(self, procs, names=None, interval=0.1): + if not isinstance(procs, list): + procs = [procs] + self._monitor = SubprocessWatchdog( + processes=procs, + process_names=names, + interval=interval, + ) + self._monitor.start() + return self._monitor + + def test_healthy_processes_no_sigquit(self): + proc = self._spawn(healthy_worker) + self._watch(proc) + time.sleep(0.5) + self.assertFalse(self.sigquit_triggered.is_set()) + + def test_crashed_process_triggers_sigquit(self): + proc = self._spawn(slow_crash_worker, args=(0.2,)) + self._watch(proc) + self.assertTrue( + self.sigquit_triggered.wait(timeout=2.0), + "SIGQUIT 
was not triggered within timeout", + ) + + def test_immediate_crash_detection(self): + proc = self._spawn(crashing_worker) + self._watch(proc, interval=0.05) + self.assertTrue( + self.sigquit_triggered.wait(timeout=1.0), + "Immediate crash was not detected", + ) + + def test_multiple_processes_one_crashes(self): + healthy = self._spawn(healthy_worker) + crashing = self._spawn(slow_crash_worker, args=(0.2,)) + self._watch([healthy, crashing], names=["healthy", "crashing"]) + self.assertTrue( + self.sigquit_triggered.wait(timeout=2.0), + "Crash was not detected when one of multiple processes crashed", + ) + + def test_empty_processes_list(self): + self._watch([], interval=0.1) + time.sleep(0.3) + self.assertFalse(self.sigquit_triggered.is_set()) + + def test_normal_exit_no_sigquit(self): + proc = self._spawn(lambda: None) + proc.join(timeout=2) + self._watch(proc) + time.sleep(0.3) + self.assertFalse( + self.sigquit_triggered.is_set(), + "SIGQUIT should not be triggered for normal exit (exitcode=0)", + ) + + +if __name__ == "__main__": + mp.set_start_method("spawn", force=True) + import unittest + + unittest.main() diff --git a/test/registered/utils/test_log_utils.py b/test/registered/utils/test_log_utils.py index 810c12f36b7c..74341c13be60 100644 --- a/test/registered/utils/test_log_utils.py +++ b/test/registered/utils/test_log_utils.py @@ -9,7 +9,7 @@ from sglang.srt.utils.log_utils import create_log_targets, log_json from sglang.test.ci.ci_register import register_cpu_ci -register_cpu_ci(est_time=1, suite="stage-a-test-cpu") +register_cpu_ci(est_time=4, suite="stage-a-test-cpu") class TestLogUtils(unittest.TestCase): diff --git a/test/registered/utils/test_network_address.py b/test/registered/utils/test_network_address.py index 967345fae36c..ff05206c28b4 100644 --- a/test/registered/utils/test_network_address.py +++ b/test/registered/utils/test_network_address.py @@ -6,7 +6,7 @@ from sglang.srt.utils.network import NetworkAddress from sglang.test.ci.ci_register 
import register_cpu_ci -register_cpu_ci(est_time=1, suite="stage-a-test-cpu") +register_cpu_ci(est_time=7, suite="stage-a-test-cpu") # Mock get_device() so ServerArgs tests run on CPU-only CI runners _mock_device = patch("sglang.srt.server_args.get_device", return_value="cuda") diff --git a/test/registered/utils/test_numa_utils.py b/test/registered/utils/test_numa_utils.py new file mode 100644 index 000000000000..01209292ed56 --- /dev/null +++ b/test/registered/utils/test_numa_utils.py @@ -0,0 +1,311 @@ +import ctypes +import unittest +from unittest.mock import MagicMock, patch + +from sglang.srt.utils.numa_utils import ( + _is_numa_available, + _query_numa_node_for_gpu, + get_numa_node_if_available, +) +from sglang.test.ci.ci_register import register_cpu_ci, register_cuda_ci + +register_cpu_ci(est_time=1, suite="stage-a-cpu-only") +register_cuda_ci(est_time=10, suite="stage-c-test-4-gpu-gb200") +register_cuda_ci(est_time=10, suite="stage-c-test-8-gpu-b200") + + +class TestIsNumaAvailable(unittest.TestCase): + """Tests for _is_numa_available on both NUMA and non-NUMA systems.""" + + @patch("sglang.srt.utils.numa_utils._is_cuda", False) + def test_returns_false_when_not_cuda(self): + self.assertFalse(_is_numa_available()) + + @patch("sglang.srt.utils.numa_utils._is_cuda", True) + @patch("os.path.isdir", return_value=False) + def test_returns_false_when_no_numa_nodes(self, _mock_isdir): + self.assertFalse(_is_numa_available()) + + @patch("sglang.srt.utils.numa_utils._is_cuda", True) + @patch("os.path.isdir", return_value=True) + @patch("sglang.srt.utils.numa_utils.psutil") + def test_returns_false_when_affinity_constrained(self, mock_psutil, _mock_isdir): + mock_process = MagicMock() + mock_process.cpu_affinity.return_value = [0, 1] + mock_psutil.Process.return_value = mock_process + mock_psutil.cpu_count.return_value = 128 + + self.assertFalse(_is_numa_available()) + + @patch("sglang.srt.utils.numa_utils._can_set_mempolicy", return_value=True) + 
@patch("sglang.srt.utils.numa_utils.shutil.which", return_value="/usr/bin/numactl") + @patch("sglang.srt.utils.numa_utils._is_cuda", True) + @patch("os.path.isdir", return_value=True) + @patch("sglang.srt.utils.numa_utils.psutil") + def test_returns_true_on_numa_system_with_full_affinity( + self, mock_psutil, _mock_isdir, _mock_which, _mock_mempolicy + ): + all_cpus = list(range(128)) + mock_process = MagicMock() + mock_process.cpu_affinity.return_value = all_cpus + mock_psutil.Process.return_value = mock_process + mock_psutil.cpu_count.return_value = 128 + + self.assertTrue(_is_numa_available()) + + @patch("sglang.srt.utils.numa_utils._can_set_mempolicy", return_value=False) + @patch("sglang.srt.utils.numa_utils.shutil.which", return_value="/usr/bin/numactl") + @patch("sglang.srt.utils.numa_utils._is_cuda", True) + @patch("os.path.isdir", return_value=True) + @patch("sglang.srt.utils.numa_utils.psutil") + def test_returns_false_when_mempolicy_not_permitted( + self, mock_psutil, _mock_isdir, _mock_which, _mock_mempolicy + ): + all_cpus = list(range(128)) + mock_process = MagicMock() + mock_process.cpu_affinity.return_value = all_cpus + mock_psutil.Process.return_value = mock_process + mock_psutil.cpu_count.return_value = 128 + + self.assertFalse(_is_numa_available()) + + @patch("sglang.srt.utils.numa_utils._can_set_mempolicy", return_value=True) + @patch("sglang.srt.utils.numa_utils.shutil.which", return_value="/usr/bin/numactl") + @patch("sglang.srt.utils.numa_utils._is_cuda", True) + @patch("os.path.isdir", return_value=True) + @patch("sglang.srt.utils.numa_utils.psutil") + def test_isdir_called_with_node1_path( + self, mock_psutil, mock_isdir, _mock_which, _mock_mempolicy + ): + all_cpus = list(range(8)) + mock_process = MagicMock() + mock_process.cpu_affinity.return_value = all_cpus + mock_psutil.Process.return_value = mock_process + mock_psutil.cpu_count.return_value = 8 + + _is_numa_available() + mock_isdir.assert_called_with("/sys/devices/system/node/node1") 
+ + +class TestQueryNumaNodeForGpu(unittest.TestCase): + """Tests for _query_numa_node_for_gpu with mocked pynvml.""" + + @patch( + "sglang.srt.utils.numa_utils.glob.glob", + return_value=[ + "/sys/devices/system/node/node0", + "/sys/devices/system/node/node1", + ], + ) + def test_single_node_affinity(self, _mock_glob): + c_ulong_bits = ctypes.sizeof(ctypes.c_ulong) * 8 + # Bitmask: bit 0 set -> node 0 + node_set = [1] + + mock_pynvml = MagicMock() + mock_pynvml.nvmlDeviceGetMemoryAffinity.return_value = node_set + mock_pynvml.NVML_AFFINITY_SCOPE_NODE = 0 + + with patch.dict("sys.modules", {"pynvml": mock_pynvml}): + result = _query_numa_node_for_gpu(0) + + self.assertEqual(result, [0]) + mock_pynvml.nvmlInit.assert_called_once() + mock_pynvml.nvmlShutdown.assert_called_once() + + @patch( + "sglang.srt.utils.numa_utils.glob.glob", + return_value=[ + "/sys/devices/system/node/node0", + "/sys/devices/system/node/node1", + ], + ) + def test_second_node_affinity(self, _mock_glob): + # Bitmask: bit 1 set -> node 1 + node_set = [2] + + mock_pynvml = MagicMock() + mock_pynvml.nvmlDeviceGetMemoryAffinity.return_value = node_set + mock_pynvml.NVML_AFFINITY_SCOPE_NODE = 0 + + with patch.dict("sys.modules", {"pynvml": mock_pynvml}): + result = _query_numa_node_for_gpu(1) + + self.assertEqual(result, [1]) + + @patch( + "sglang.srt.utils.numa_utils.glob.glob", + return_value=[ + "/sys/devices/system/node/node0", + "/sys/devices/system/node/node1", + "/sys/devices/system/node/node2", + "/sys/devices/system/node/node3", + ], + ) + def test_multiple_node_affinity(self, _mock_glob): + # Bitmask: bits 1 and 3 set -> nodes 1, 3 (binary: ...1010 = 10) + node_set = [0b1010] + + mock_pynvml = MagicMock() + mock_pynvml.nvmlDeviceGetMemoryAffinity.return_value = node_set + mock_pynvml.NVML_AFFINITY_SCOPE_NODE = 0 + + with patch.dict("sys.modules", {"pynvml": mock_pynvml}): + result = _query_numa_node_for_gpu(0) + + self.assertEqual(result, [1, 3]) + + @patch( + 
"sglang.srt.utils.numa_utils.glob.glob", + return_value=[ + "/sys/devices/system/node/node0", + "/sys/devices/system/node/node1", + ], + ) + def test_no_affinity(self, _mock_glob): + node_set = [0] + + mock_pynvml = MagicMock() + mock_pynvml.nvmlDeviceGetMemoryAffinity.return_value = node_set + mock_pynvml.NVML_AFFINITY_SCOPE_NODE = 0 + + with patch.dict("sys.modules", {"pynvml": mock_pynvml}): + result = _query_numa_node_for_gpu(0) + + self.assertEqual(result, []) + + @patch( + "sglang.srt.utils.numa_utils.glob.glob", + return_value=[ + "/sys/devices/system/node/node0", + "/sys/devices/system/node/node1", + ], + ) + def test_nvml_shutdown_called_on_success(self, _mock_glob): + node_set = [1] + mock_pynvml = MagicMock() + mock_pynvml.nvmlDeviceGetMemoryAffinity.return_value = node_set + mock_pynvml.NVML_AFFINITY_SCOPE_NODE = 0 + + with patch.dict("sys.modules", {"pynvml": mock_pynvml}): + _query_numa_node_for_gpu(0) + + mock_pynvml.nvmlShutdown.assert_called_once() + + +class TestGetNumaNodeIfAvailable(unittest.TestCase): + """Tests for get_numa_node_if_available combining _is_numa_available + _query_numa_node_for_gpu.""" + + def _make_server_args(self, numa_node=None): + args = MagicMock() + args.numa_node = numa_node + return args + + def test_returns_explicit_numa_node_from_server_args(self): + args = self._make_server_args(numa_node=[2, 3, 0, 1]) + self.assertEqual(get_numa_node_if_available(args, 0), 2) + self.assertEqual(get_numa_node_if_available(args, 1), 3) + self.assertEqual(get_numa_node_if_available(args, 2), 0) + self.assertEqual(get_numa_node_if_available(args, 3), 1) + + @patch("sglang.srt.utils.numa_utils._is_numa_available", return_value=False) + def test_returns_none_when_numa_not_available(self, _mock_avail): + args = self._make_server_args(numa_node=None) + self.assertIsNone(get_numa_node_if_available(args, 0)) + + @patch("sglang.srt.utils.numa_utils._query_numa_node_for_gpu", return_value=[]) + 
@patch("sglang.srt.utils.numa_utils._is_numa_available", return_value=True) + def test_returns_none_when_query_returns_empty(self, _mock_avail, _mock_gpu): + args = self._make_server_args(numa_node=None) + self.assertIsNone(get_numa_node_if_available(args, 0)) + + @patch("sglang.srt.utils.numa_utils._query_numa_node_for_gpu", return_value=[1]) + @patch("sglang.srt.utils.numa_utils._is_numa_available", return_value=True) + def test_returns_queried_single_node(self, _mock_avail, _mock_gpu): + args = self._make_server_args(numa_node=None) + self.assertEqual(get_numa_node_if_available(args, 0), 1) + + @patch("sglang.srt.utils.numa_utils._query_numa_node_for_gpu", return_value=[0, 2]) + @patch("sglang.srt.utils.numa_utils._is_numa_available", return_value=True) + def test_returns_first_node_when_multiple_found(self, _mock_avail, _mock_gpu): + args = self._make_server_args(numa_node=None) + self.assertEqual(get_numa_node_if_available(args, 0), 0) + + @patch("sglang.srt.utils.numa_utils._query_numa_node_for_gpu", return_value=[0, 2]) + @patch("sglang.srt.utils.numa_utils._is_numa_available", return_value=True) + def test_logs_warning_when_multiple_nodes(self, _mock_avail, _mock_gpu): + args = self._make_server_args(numa_node=None) + with self.assertLogs("sglang.srt.utils.numa_utils", level="WARNING") as cm: + get_numa_node_if_available(args, 0) + self.assertTrue(any("Multiple NUMA nodes" in msg for msg in cm.output)) + + @patch("sglang.srt.utils.numa_utils._is_numa_available", return_value=True) + @patch("sglang.srt.utils.numa_utils._query_numa_node_for_gpu", return_value=[1]) + def test_explicit_server_args_takes_precedence(self, _mock_gpu, _mock_avail): + args = self._make_server_args(numa_node=[5, 6]) + result = get_numa_node_if_available(args, 0) + self.assertEqual(result, 5) + _mock_avail.assert_not_called() + _mock_gpu.assert_not_called() + + +def _get_gpu_name(): + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + name = 
pynvml.nvmlDeviceGetName(handle) + pynvml.nvmlShutdown() + return name + except Exception: + return "" + + +_gpu_name = _get_gpu_name() + + +@unittest.skipUnless("GB200" in _gpu_name, "Requires GB200 hardware") +class TestGB200NumaTopology(unittest.TestCase): + """Hardware test validating expected NUMA topology on GB200 (2 NUMA nodes, 4 GPUs).""" + + def _make_server_args(self): + args = MagicMock() + args.numa_node = None + return args + + def test_gpu_numa_mapping(self): + expected = {0: 0, 1: 0, 2: 1, 3: 1} + args = self._make_server_args() + for gpu_id, expected_node in expected.items(): + result = get_numa_node_if_available(args, gpu_id) + self.assertEqual( + result, + expected_node, + f"GPU {gpu_id}: expected NUMA node {expected_node}, got {result}", + ) + + +@unittest.skipUnless("B200" in _gpu_name, "Requires B200 hardware") +class TestB200NumaTopology(unittest.TestCase): + """Hardware test validating expected NUMA topology on B200 (2 NUMA nodes, 8 GPUs).""" + + def _make_server_args(self): + args = MagicMock() + args.numa_node = None + return args + + def test_gpu_numa_mapping(self): + expected = {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1} + args = self._make_server_args() + for gpu_id, expected_node in expected.items(): + result = get_numa_node_if_available(args, gpu_id) + self.assertEqual( + result, + expected_node, + f"GPU {gpu_id}: expected NUMA node {expected_node}, got {result}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/utils/test_request_logger.py b/test/registered/utils/test_request_logger.py index 28af9f9495c0..c87a88448c09 100644 --- a/test/registered/utils/test_request_logger.py +++ b/test/registered/utils/test_request_logger.py @@ -8,6 +8,7 @@ import requests +from sglang.srt.constants import HEALTH_CHECK_RID_PREFIX from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.test_utils import ( @@ -197,7 +198,7 @@ def 
_verify_logs(self, content: str, source_name: str): continue rid = data.get("rid", "") - if rid.startswith("HEALTH_CHECK"): + if rid.startswith(HEALTH_CHECK_RID_PREFIX): continue if data.get("event") == "request.received": diff --git a/test/registered/utils/test_socket_utils.py b/test/registered/utils/test_socket_utils.py index 42c12fd88dc4..63a42ac39237 100644 --- a/test/registered/utils/test_socket_utils.py +++ b/test/registered/utils/test_socket_utils.py @@ -15,7 +15,7 @@ from sglang.test.test_utils import CustomTestCase from sglang.utils import normalize_base_url, release_port, reserve_port -register_cpu_ci(est_time=1, suite="stage-a-test-cpu") +register_cpu_ci(est_time=7, suite="stage-a-test-cpu") class TestTryBindSocket(CustomTestCase): diff --git a/test/registered/vlm/test_patch_embed_perf.py b/test/registered/vlm/test_patch_embed_perf.py index ce2f38ca2de6..3ef3a645750d 100644 --- a/test/registered/vlm/test_patch_embed_perf.py +++ b/test/registered/vlm/test_patch_embed_perf.py @@ -8,7 +8,7 @@ from sglang.srt.models.glm4v import Glm4vVisionPatchEmbed from sglang.test.ci.ci_register import register_cuda_ci -register_cuda_ci(est_time=120, suite="stage-b-test-1-gpu-large") +register_cuda_ci(est_time=12, suite="stage-b-test-1-gpu-large") PATCH_SIZE = 14 TEMPORAL_PATCH_SIZE = 2 diff --git a/test/registered/vlm/test_vlm_tp4.py b/test/registered/vlm/test_vlm_tp4.py new file mode 100644 index 000000000000..d62df9b0b985 --- /dev/null +++ b/test/registered/vlm/test_vlm_tp4.py @@ -0,0 +1,82 @@ +""" +VLM TP=4 per-commit test using Qwen3.5-27B with MMMU evaluation. 
+""" + +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=200, suite="stage-c-test-4-gpu-h100") + +QWEN35_27B_MODEL = "Qwen/Qwen3.5-27B" +MMMU_ACCURACY_THRESHOLD = 0.65 +MMMU_NUM_EXAMPLES = 32 + + +class TestVLMTP4(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = QWEN35_27B_MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--tp-size", + "4", + "--cuda-graph-max-bs", + "32", + "--mem-fraction-static", + "0.8", + "--trust-remote-code", + "--mamba-scheduler-strategy", + "extra_buffer", + "--mamba-track-interval", + "128", + "--mamba-ssm-dtype", + "bfloat16", + "--chunked-prefill-size", + "2048", + "--max-running-requests", + "128", + ], + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def test_mmmu_accuracy(self): + args = SimpleNamespace( + model=self.model, + eval_name="mmmu", + num_examples=MMMU_NUM_EXAMPLES, + num_threads=16, + max_tokens=2048, + chat_template_kwargs={"enable_thinking": False}, + base_url=self.base_url, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"MMMU score: {metrics['score']}") + self.assertGreaterEqual( + metrics["score"], + MMMU_ACCURACY_THRESHOLD, + f"MMMU accuracy {metrics['score']:.4f} below threshold {MMMU_ACCURACY_THRESHOLD}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/run_suite.py b/test/run_suite.py index 97b39ee2ea71..1062f7d2fbd8 100644 --- a/test/run_suite.py +++ b/test/run_suite.py @@ -84,6 +84,8 @@ 
"nightly-eval-vlm-2-gpu", "nightly-perf-text-2-gpu", "nightly-perf-vlm-2-gpu", + # GB300 (4x B200 NVL4) nightly suite + "nightly-4-gpu-gb300", ], HWBackend.AMD: [ "nightly-amd", @@ -103,6 +105,11 @@ "nightly-4-npu-a3", "nightly-8-npu-a3", "nightly-16-npu-a3", + "full-1-npu-a3", + "full-2-npu-a3", + "full-4-npu-a3", + "full-8-npu-a3", + "full-16-npu-a3", ], } diff --git a/test/srt/cpu/test_rope.py b/test/srt/cpu/test_rope.py index 97475a60724a..e4648da4d91c 100644 --- a/test/srt/cpu/test_rope.py +++ b/test/srt/cpu/test_rope.py @@ -9,6 +9,7 @@ ) from sglang.srt.layers.rotary_embedding.rope_variant import ( DeepseekScalingRotaryEmbedding, + apply_rotary_pos_emb_native, ) from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler from sglang.test.test_utils import CustomTestCase @@ -18,6 +19,7 @@ class TestROPE(CustomTestCase): def test_mrope(self): + torch.manual_seed(100) head_size = 128 seq_len = 512 num_heads = 16 @@ -254,6 +256,27 @@ def single_test( num_kv_heads, ) + def test_apply_rotary_pos_emb(self): + num_tokens = 1024 + num_heads = 8 + head_size = 72 + qkv = torch.randn(num_tokens, num_heads * head_size * 3).to(torch.bfloat16) + query, key, _ = qkv.split( + [num_heads * head_size, num_heads * head_size, num_heads * head_size], + dim=-1, + ) + query = query.view(num_tokens, num_heads, head_size) + key = key.view(num_tokens, num_heads, head_size) + for sincos_dtype in [torch.float32, torch.bfloat16]: + cos = torch.rand(num_tokens, head_size).to(sincos_dtype) + sin = torch.rand(num_tokens, head_size).to(sincos_dtype) + q_out_ref, k_out_ref = apply_rotary_pos_emb_native(query, key, cos, sin) + q_out_sgl, k_out_sgl = torch.ops.sgl_kernel.apply_rotary_pos_emb_cpu( + query, key, cos, sin + ) + torch.testing.assert_close(q_out_ref, q_out_sgl, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(k_out_ref, k_out_sgl, atol=1e-2, rtol=1e-2) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py 
index a7d5b6744b3a..49077fcd6e67 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -13,7 +13,9 @@ suites = { # quantization_test suite migrated to test/registered/quant/ # All CUDA tests migrated to test/registered/ - "__not_in_ci__": [], + "__not_in_ci__": [ + TestFile("ascend/test_embed_interpolate_unittest.py"), + ], } # Add AMD tests @@ -81,44 +83,8 @@ ], } -# Add Ascend NPU tests -# TODO: Set accurate estimate time -# NOTE: please sort the test cases alphabetically by the test file name -suite_ascend = { - "per-commit-1-npu-a2": [ - TestFile("ascend/test_ascend_autoround_dense.py", 400), - TestFile("ascend/test_ascend_autoround_moe.py", 400), - TestFile("ascend/test_ascend_gptq_moe.py", 400), - TestFile("ascend/test_ascend_graph_tp1_bf16.py", 400), - TestFile("ascend/test_ascend_piecewise_graph_prefill.py", 400), - TestFile("ascend/test_ascend_hicache_mha.py", 400), - TestFile("ascend/test_ascend_sampling_backend.py", 400), - TestFile("ascend/test_ascend_tp1_bf16.py", 400), - TestFile("ascend/test_ascend_compile_graph_tp1_bf16.py", 400), - TestFile("ascend/test_ascend_w8a8_quantization.py", 400), - TestFile("ascend/test_embed_interpolate_unittest.py", 400), - ], - "per-commit-2-npu-a2": [ - TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400), - TestFile("ascend/test_ascend_mla_fia_w8a8int8.py", 400), - TestFile("ascend/test_ascend_tp2_bf16.py", 400), - TestFile("ascend/test_ascend_tp2_fia_bf16.py", 400), - ], - "per-commit-4-npu-a3": [ - TestFile("ascend/test_ascend_mla_w8a8int8.py", 400), - TestFile("ascend/test_ascend_hicache_mla.py", 400), - TestFile("ascend/test_ascend_tp4_bf16.py", 400), - TestFile("ascend/test_ascend_w4a4_quantization.py", 600), - TestFile("ascend/test_llada2_mini_ascend.py", 800), - ], - "per-commit-16-npu-a3": [ - TestFile("ascend/test_ascend_deepep.py", 3600), - ], -} - suites.update(suite_amd) suites.update(suite_xeon) -suites.update(suite_ascend) suites.update(suite_xpu) From 
5bf5c4a2173092285a27967cb5bbaf75f0afc7d6 Mon Sep 17 00:00:00 2001 From: khalilzhk Date: Fri, 3 Apr 2026 17:25:21 +0800 Subject: [PATCH 21/42] BugFix for MLAPO for Deepseek eagle3 on Ascend (#222) --- .../srt/hardware_backend/npu/attention/ascend_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index d93cc5f62f1c..e7acdee86e90 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -860,7 +860,7 @@ def forward_extend( sinks: Optional[torch.Tensor] = None, slopes: Optional[torch.Tensor] = None, ): - if is_mla_preprocess_enabled(): + if is_mla_preprocess_enabled() and self.use_mla: # MLAPO and MLAPROLOG do save kv_cache save_kv_cache = False if self.is_dllm_model: @@ -1773,7 +1773,7 @@ def forward_decode( sinks: Optional[torch.Tensor] = None, slopes: Optional[torch.Tensor] = None, ): - if is_mla_preprocess_enabled(): + if is_mla_preprocess_enabled() and self.use_mla: # MLAPO does saving kv_cache save_kv_cache = False if topk_indices is not None: From bb5c38663ee773350f3ecbcc4fe5625689f1ff00 Mon Sep 17 00:00:00 2001 From: silencejade <13120475055@163.com> Date: Mon, 6 Apr 2026 15:18:01 +0800 Subject: [PATCH 22/42] adapt mtp + prefix for ascend gdn backend (#202) --- .../npu/attention/ascend_gdn_backend.py | 14 ++++++---- .../attention/hybrid_linear_attn_backend.py | 28 ++++++++----------- python/sglang/srt/server_args.py | 2 +- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py index fcc8c6316d7a..00158319fb46 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py +++ 
b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py @@ -38,6 +38,10 @@ class AscendGDNAttnBackend(GDNAttnBackend): def __init__(self, model_runner: ModelRunner): super().__init__(model_runner) + # transpose last two dim for _init_npu_conv_state + self.conv_states_shape = torch.Size( + (*self.conv_states_shape[:-2], self.conv_states_shape[-1], self.conv_states_shape[-2]) + ) decode_backend = get_linear_attn_decode_backend() prefill_backend = get_linear_attn_prefill_backend() self.kernel_dispatcher = AscendGDNKernelDispatcher( @@ -60,10 +64,10 @@ def prepare_gdn_inputs( if forward_mode.is_target_verify(): seq_len = spec_info.draft_token_num self.actual_seq_lengths = self.actual_seq_lengths * seq_len - start_indices = cache_indices * seq_len - offset = torch.arange(seq_len, device=start_indices.device) - ranges = start_indices.unsqueeze(1) + offset - self.ssm_state_indices = ranges.flatten().to(torch.int32) + # indices + self.ssm_state_indices = torch.arange( + cache_indices.shape[0] * seq_len, dtype=torch.int32, device=cache_indices.device + ) else: self.ssm_state_indices = cache_indices @@ -262,7 +266,7 @@ def forward_extend( :, forward_metadata.track_conv_indices ].transpose(0, 1) mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0] - conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track + conv_states.transpose(1, 2)[conv_dst[mask_indices]] = mixed_qkv_to_track kernel_size = layer.conv_weights.shape[-1] conv_states_for_prefill = conv_states[:, -(kernel_size - 1) :, :] conv_states_tmp = conv_states_for_prefill.contiguous() diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 75d0fac2b6d4..4ca421a8694e 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -487,12 +487,9 @@ def _capture_metadata( self.query_start_loc_list[bs - 
1].copy_( self.cached_cuda_graph_verify_query_start_loc[: bs + 1] ) - start_indices = mamba_indices * spec_info.draft_token_num - offset = torch.arange( - spec_info.draft_token_num, device=start_indices.device + ssm_state_indices = torch.arange( + mamba_indices.shape[0] * spec_info.draft_token_num, dtype=torch.int32, device=mamba_indices.device ) - ranges = start_indices.unsqueeze(1) + offset - ssm_state_indices = ranges.flatten().to(torch.int32) self.state_indices_list_gdn[bs - 1][ : len(mamba_indices) * spec_info.draft_token_num ].copy_(ssm_state_indices) @@ -547,14 +544,10 @@ def _replay_metadata( bs - num_padding ) elif forward_mode.is_target_verify(): - start_indices = ( - mamba_indices[: bs - num_padding] * spec_info.draft_token_num + ssm_state_indices = torch.arange( + len(mamba_indices[:bs - num_padding]) * spec_info.draft_token_num, + dtype=torch.int32, device=mamba_indices.device ) - offset = torch.arange( - spec_info.draft_token_num, device=start_indices.device - ) - ranges = start_indices.unsqueeze(1) + offset - ssm_state_indices = ranges.flatten().to(torch.int32) self.state_indices_list_gdn[bs - 1][ : len(mamba_indices[: bs - num_padding]) * spec_info.draft_token_num ].copy_(ssm_state_indices) @@ -1001,18 +994,21 @@ def update_mamba_state_after_mtp_verify( intermediate_state_cache = mamba_caches.intermediate_ssm intermediate_conv_window_cache = mamba_caches.intermediate_conv_window[0] if is_npu(): - valid_state_indices = state_indices_tensor.to(torch.int64) # [N] + dst_indices_tensor = state_indices_tensor.to(torch.int64) # [N] + src_indices_tensor = torch.arange(dst_indices_tensor.shape[0], + device=dst_indices_tensor.device, + dtype=torch.int64) last_steps = accepted_steps.to(torch.int64) # [N] move_intermediate_cache( - ssm_states, intermediate_state_cache, valid_state_indices, last_steps + ssm_states, intermediate_state_cache, dst_indices_tensor, src_indices_tensor, last_steps ) draft_token_num = intermediate_state_cache.shape[2] - if 
valid_state_indices.numel() > 0: + if dst_indices_tensor.numel() > 0: conv_state_rollback( conv_states, - valid_state_indices, + dst_indices_tensor, last_steps, draft_token_num, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6c5e559ce313..2a98753e875c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2172,7 +2172,7 @@ def _handle_mamba_radix_cache( ) assert ( - is_cuda() + is_cuda() or is_npu() ), "Mamba extra_buffer is only supported on CUDA devices with FLA backend" if self.speculative_num_draft_tokens is not None: assert ( From c4366c11fa36d4903676bd62638c7f03b8007b27 Mon Sep 17 00:00:00 2001 From: shadowxz109 Date: Tue, 7 Apr 2026 09:20:28 +0800 Subject: [PATCH 23/42] Minimax 2.5 optimization (#237) --- .../srt/hardware_backend/npu/moe/topk.py | 23 ++- python/sglang/srt/models/minimax_m2.py | 181 +++++++++++++----- 2 files changed, 159 insertions(+), 45 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/moe/topk.py b/python/sglang/srt/hardware_backend/npu/moe/topk.py index 6447d9a67194..3ceb0eaad10b 100644 --- a/python/sglang/srt/hardware_backend/npu/moe/topk.py +++ b/python/sglang/srt/hardware_backend/npu/moe/topk.py @@ -26,6 +26,7 @@ def fused_topk_npu( renormalize = topk_config.renormalize correction_bias = topk_config.correction_bias + # Fast path: simple top-k without grouped routing and bias if not use_grouped_topk and correction_bias is None: topk_weights, topk_ids, _ = torch.ops.npu.npu_moe_gating_top_k_softmax( router_logits, @@ -40,8 +41,8 @@ def fused_topk_npu( ) topk_weights = topk_weights.to(torch.float32) + # Grouped top-k with correction bias elif use_grouped_topk and correction_bias is not None: - # Force set routed_scaling_factor = 1 to optimize renormalize topk_weights, topk_ids, _ = torch.ops.npu.npu_moe_gating_top_k( router_logits.to(torch.float32), k=topk_config.top_k, @@ -57,6 +58,26 @@ def fused_topk_npu( eps=float(1e-20), ) + # 
npu_moe_gating_top_k is not yet supported custom_routing_function + # torch native is not yet supported num_token_non_padded + elif ( + topk_config.custom_routing_function is None + and num_token_non_padded is not None + and correction_bias is not None + ): + topk_weights, topk_ids, _ = torch.ops.npu.npu_moe_gating_top_k( + router_logits.to(torch.float32), + k=topk_config.top_k, + bias=correction_bias.to(torch.float32), + renorm=0, + norm_type=1, + routed_scaling_factor=( + 1 if renormalize else topk_config.routed_scaling_factor + ), + eps=float(1e-20), + ) + + # Fallback to torch native implementation else: topk_config.torch_native = True return select_experts( diff --git a/python/sglang/srt/models/minimax_m2.py b/python/sglang/srt/models/minimax_m2.py index e5ef2d75c7dc..5ade336780bf 100644 --- a/python/sglang/srt/models/minimax_m2.py +++ b/python/sglang/srt/models/minimax_m2.py @@ -30,7 +30,6 @@ from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, get_pp_group, - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) @@ -41,6 +40,13 @@ LayerScatterModes, ScatterMode, ) +from sglang.srt.layers.dp_attention import ( + attn_tp_all_reduce, + get_attention_tp_rank, + get_attention_tp_size, + is_dp_attention_enabled, + get_attention_tp_group, +) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, @@ -74,6 +80,12 @@ make_layers, ) from sglang.srt.utils.hf_transformers_utils import get_rope_config +from sglang.srt.utils import is_npu + +_is_npu = is_npu() + +if _is_npu: + from sgl_kernel_npu.norm.split_qkv_tp_rmsnorm_rope import split_qkv_tp_rmsnorm_rope logger = logging.getLogger(__name__) @@ -250,11 +262,11 @@ class MiniMaxM2RMSNormTP(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() - self.tp_world = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + 
self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() # Weight parameter is sharded across TP ranks - self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world))) + self.weight = nn.Parameter(torch.ones(int(hidden_size / self.attn_tp_size))) self.weight.weight_loader = self.weight_loader self.variance_epsilon = eps @@ -264,11 +276,11 @@ def weight_loader( loaded_weight: torch.Tensor, ) -> None: """Custom weight loader that handles TP sharding.""" - tp_world = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() + attn_tp_size = get_attention_tp_size() + attn_tp_rank = get_attention_tp_rank() - shard_size = loaded_weight.shape[0] // tp_world - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + shard_size = loaded_weight.shape[0] // attn_tp_size + shard = slice(attn_tp_rank * shard_size, (attn_tp_rank + 1) * shard_size) param.data.copy_(loaded_weight[shard]) @torch.compile(dynamic=True, backend=get_compiler_backend()) @@ -286,9 +298,9 @@ def forward( # Compute variance across the full dimension (not just local shard) variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) - if self.tp_world > 1: + if self.attn_tp_size > 1: # All-reduce variance across TP ranks to get global variance - variance = tensor_model_parallel_all_reduce(variance) / self.tp_world + variance = attn_tp_all_reduce(variance) / self.attn_tp_size # Normalize and apply local weight shard x = x * torch.rsqrt(variance + self.variance_epsilon) @@ -304,8 +316,8 @@ def forward_qk( k: torch.Tensor, ) -> torch.Tensor: sum_sq = rms_sumsq_serial(q, k) - if q_norm.tp_world > 1: - sum_sq = tensor_model_parallel_all_reduce(sum_sq) + if q_norm.attn_tp_size > 1: + sum_sq = attn_tp_all_reduce(sum_sq) q, k = rms_apply_serial( q, @@ -313,7 +325,7 @@ def forward_qk( q_norm.weight, k_norm.weight, sum_sq, - q_norm.tp_world, + q_norm.attn_tp_size, q_norm.variance_epsilon, ) @@ -387,14 +399,28 @@ def 
ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> Non param.data.copy_(loaded_weight.to(torch.float32)) def forward( - self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + self, + hidden_states: torch.Tensor, + forward_batch: Optional[ForwardBatch] = None, + should_allreduce_fusion: bool = False, + use_reduce_scatter: bool = False, ) -> torch.Tensor: - if get_moe_a2a_backend().is_deepep(): - return self.forward_deepep(hidden_states, forward_batch) + if ( + not get_moe_a2a_backend().is_deepep() + and not get_moe_a2a_backend().is_ascend_fuseep() + ): + return self.forward_normal( + hidden_states, should_allreduce_fusion, use_reduce_scatter + ) else: - return self.forward_normal(hidden_states) + return self.forward_deepep(hidden_states, forward_batch) - def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward_normal( + self, + hidden_states: torch.Tensor, + should_allreduce_fusion: bool = False, + use_reduce_scatter: bool = False, + ) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -403,7 +429,7 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: topk_output = self.topk(hidden_states, router_logits) final_hidden_states = self.experts(hidden_states, topk_output) - if self.tp_size > 1: + if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) @@ -543,23 +569,26 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - tp_size = get_tensor_model_parallel_world_size() + + # Use attention TP rank/size for dp-attention support + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() # Get dimensions from config self.total_num_heads = config.num_attention_heads - assert self.total_num_heads % tp_size == 0 - 
self.num_heads = self.total_num_heads // tp_size + assert self.total_num_heads % attn_tp_size == 0 + self.num_heads = self.total_num_heads // attn_tp_size self.total_num_kv_heads = config.num_key_value_heads - if self.total_num_kv_heads >= tp_size: + if self.total_num_kv_heads >= attn_tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 + assert self.total_num_kv_heads % attn_tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + assert attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) # Use head_dim from config if available, otherwise calculate self.head_dim = getattr( @@ -588,6 +617,8 @@ def __init__( self.total_num_kv_heads, bias=False, quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, prefix=add_prefix("qkv_proj", prefix), ) @@ -597,6 +628,8 @@ def __init__( bias=False, reduce_results=False, quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, prefix=add_prefix("o_proj", prefix), ) @@ -653,6 +686,42 @@ def forward_prepare( inner_state = q, k, v, forward_batch return None, forward_batch, inner_state + def forward_prepare_npu( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ): + qkv, _ = self.qkv_proj(hidden_states) + if self.use_qk_norm: + # q = self.q_norm(q.contiguous()) + # k = self.k_norm(k.contiguous()) + cos_sin = self.rotary_emb.cos_sin_cache.index_select( + 0, positions.flatten() + ) + cos, sin = cos_sin.chunk(2, dim=-1) + q, k, v = split_qkv_tp_rmsnorm_rope( + input=qkv, + cos=cos, + sin=sin, + q_weight=self.q_norm.weight, + k_weight=self.k_norm.weight, + q_hidden_size=self.q_size, + 
kv_hidden_size=self.kv_size, + head_dim=self.head_dim, + rotary_dim=self.rotary_dim, + eps=self.q_norm.variance_epsilon, + tp_world=self.q_norm.attn_tp_size, + tp_group=get_attention_tp_group().device_group, + ) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = q.contiguous(), k.contiguous() + q, k = self.rotary_emb(positions, q, k) + + inner_state = q, k, v, forward_batch + return None, forward_batch, inner_state + def forward_core(self, intermediate_state): _, _, inner_state = intermediate_state attn_output = self.attn(*inner_state) @@ -665,11 +734,18 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: - s = self.forward_prepare( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) + if _is_npu: + s = self.forward_prepare_npu( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + else: + s = self.forward_prepare( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) return self.forward_core(s) def op_prepare(self, state): @@ -751,12 +827,12 @@ def forward( hidden_states, residual = self.layer_communicator.prepare_attn( hidden_states, residual, forward_batch ) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) + if not forward_batch.forward_mode.is_idle(): + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) # Fully Connected (MLP or MoE) @@ -764,12 +840,27 @@ def forward( hidden_states, residual, forward_batch ) - hidden_states = self.block_sparse_moe(hidden_states, forward_batch) + should_allreduce_fusion = ( + self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer( + forward_batch + ) + ) - hidden_states, residual = self.layer_communicator.postprocess_layer( - hidden_states, residual, forward_batch + use_reduce_scatter = 
self.layer_communicator.should_use_reduce_scatter( + forward_batch ) + hidden_states = self.block_sparse_moe( + hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter + ) + + if should_allreduce_fusion: + hidden_states._sglang_needs_allreduce_fusion = True + else: + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + return hidden_states, residual # TBO Operations for MiniMax Decoder Layer @@ -851,6 +942,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + use_attn_tp_group=is_dp_attention_enabled(), ) def layer_fn(idx, prefix: str) -> nn.Module: @@ -932,10 +1024,11 @@ def forward( {"hidden_states": hidden_states, "residual": residual} ) - if residual is not None: - hidden_states, _ = self.norm(hidden_states, residual) - else: - hidden_states = self.norm(hidden_states) + if hidden_states.shape[0] != 0: + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) if len(aux_hidden_states) == 0: return hidden_states From fae90abf6e15aaffb6fd924a439253674771487d Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Tue, 7 Apr 2026 21:56:39 +0800 Subject: [PATCH 24/42] Move ring test to nightly (#22267) --- test/registered/8-gpu-models/test_ring_2_5_1t.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/registered/8-gpu-models/test_ring_2_5_1t.py b/test/registered/8-gpu-models/test_ring_2_5_1t.py index 71b2a4f2609e..29c64160cf96 100644 --- a/test/registered/8-gpu-models/test_ring_2_5_1t.py +++ b/test/registered/8-gpu-models/test_ring_2_5_1t.py @@ -5,8 +5,7 @@ from sglang.test.run_combined_tests import run_combined_tests from sglang.test.test_utils import ModelLaunchSettings -# register_cuda_ci(est_time=1000, suite="nightly-8-gpu-common", nightly=True) -register_cuda_ci(est_time=1000, suite="stage-c-test-8-gpu-h200") +register_cuda_ci(est_time=1800, 
suite="nightly-8-gpu-common", nightly=True) RING_2_5_1T_MODEL_PATH = "inclusionAI/Ring-2.5-1T" @@ -25,6 +24,8 @@ def test_ring_2_5_1t(self): '{"enable_multithread_load": true, "num_threads": 64}', "--watchdog-timeout", "1800", + "--soft-watchdog-timeout", + "1800", ] variants = [ From e7bc23cdab6009a564df91efbe79a057b8642c11 Mon Sep 17 00:00:00 2001 From: Mick Date: Tue, 7 Apr 2026 23:43:18 +0800 Subject: [PATCH 25/42] [diffusion] CI: fix consistency check (#22251) --- .github/workflows/diffusion-ci-gt-gen.yml | 54 ++++++++++++++----- .../test/scripts/gen_diffusion_ci_outputs.py | 8 ++- .../test/server/consistency_threshold.json | 6 --- .../test/server/test_server_common.py | 27 ++++------ .../test/server/testcase_configs.py | 11 ---- 5 files changed, 57 insertions(+), 49 deletions(-) diff --git a/.github/workflows/diffusion-ci-gt-gen.yml b/.github/workflows/diffusion-ci-gt-gen.yml index 92844245bde4..9dad8ed00c16 100644 --- a/.github/workflows/diffusion-ci-gt-gen.yml +++ b/.github/workflows/diffusion-ci-gt-gen.yml @@ -22,6 +22,10 @@ permissions: contents: write actions: read +env: + SGLANG_IS_IN_CI: true + SGLANG_CUDA_COREDUMP: "1" + jobs: multimodal-diffusion-gen-1gpu: if: github.repository == 'sgl-project/sglang' @@ -40,6 +44,8 @@ jobs: run: bash scripts/ci/cuda/ci_install_dependency.sh diffusion - name: Generate outputs + env: + RUNAI_STREAMER_MEMORY_LIMIT: 0 run: | cd python python -m sglang.multimodal_gen.test.scripts.gen_diffusion_ci_outputs \ @@ -56,6 +62,11 @@ jobs: path: python/diffusion-ci-outputs retention-days: 7 + - name: Publish GT images to sglang-bot/sglang-ci-data + env: + GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }} + run: python scripts/ci/utils/diffusion/publish_diffusion_gt.py --source-dir python/diffusion-ci-outputs + multimodal-diffusion-gen-2gpu: if: github.repository == 'sgl-project/sglang' runs-on: 2-gpu-h100 @@ -73,6 +84,8 @@ jobs: run: bash scripts/ci/cuda/ci_install_dependency.sh diffusion - name: Generate outputs + env: + 
RUNAI_STREAMER_MEMORY_LIMIT: 0 run: | cd python python -m sglang.multimodal_gen.test.scripts.gen_diffusion_ci_outputs \ @@ -89,27 +102,42 @@ jobs: path: python/diffusion-ci-outputs retention-days: 7 - diffusion-ci-push: - needs: [multimodal-diffusion-gen-1gpu, multimodal-diffusion-gen-2gpu] + - name: Publish GT images to sglang-bot/sglang-ci-data + env: + GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }} + run: python scripts/ci/utils/diffusion/publish_diffusion_gt.py --source-dir python/diffusion-ci-outputs + + multimodal-diffusion-gen-b200: if: github.repository == 'sgl-project/sglang' - runs-on: ubuntu-latest + runs-on: 4-gpu-b200 + timeout-minutes: 240 steps: - name: Checkout code uses: actions/checkout@v4 - - - name: Download artifacts - uses: actions/download-artifact@v4 with: - pattern: diffusion-gen-* - path: combined - merge-multiple: true + ref: ${{ inputs.ref || github.ref }} + + - name: Install dependencies + run: bash scripts/ci/cuda/ci_install_dependency.sh diffusion - - name: Collect image files + - name: Generate outputs + env: + RUNAI_STREAMER_MEMORY_LIMIT: 0 run: | - mkdir -p gt_images - find combined \( -name "*.png" -o -name "*.jpg" -o -name "*.jpeg" -o -name "*.webp" \) -type f -exec cp -f {} gt_images/ \; + cd python + python -m sglang.multimodal_gen.test.scripts.gen_diffusion_ci_outputs \ + --suite 1-gpu-b200 \ + --out-dir ./diffusion-ci-outputs \ + ${{ inputs.case_ids != '' && format('--case-ids {0}', inputs.case_ids) || '' }} + + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: diffusion-gen-b200 + path: python/diffusion-ci-outputs + retention-days: 7 - name: Publish GT images to sglang-bot/sglang-ci-data env: GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }} - run: python scripts/ci/utils/diffusion/publish_diffusion_gt.py --source-dir gt_images + run: python scripts/ci/utils/diffusion/publish_diffusion_gt.py --source-dir python/diffusion-ci-outputs diff --git 
a/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py b/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py index 645a9cac5486..f36e803dd11e 100755 --- a/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py +++ b/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py @@ -17,7 +17,12 @@ from pathlib import Path from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger -from sglang.multimodal_gen.test.run_suite import SUITES, collect_test_items, run_pytest +from sglang.multimodal_gen.test.run_suite import ( + SUITES, + _maybe_pin_update_weights_model_pair, + collect_test_items, + run_pytest, +) logger = init_logger(__name__) @@ -95,6 +100,7 @@ def main(): # Get files from suite (same as run_suite.py) suite_files_rel = SUITES[args.suite] + _maybe_pin_update_weights_model_pair(suite_files_rel) suite_files_abs = [] for f_rel in suite_files_rel: f_abs = target_dir / f_rel diff --git a/python/sglang/multimodal_gen/test/server/consistency_threshold.json b/python/sglang/multimodal_gen/test/server/consistency_threshold.json index 3795a9f6e28a..596e98166ee0 100644 --- a/python/sglang/multimodal_gen/test/server/consistency_threshold.json +++ b/python/sglang/multimodal_gen/test/server/consistency_threshold.json @@ -49,12 +49,6 @@ "psnr_threshold": 19.0, "mean_abs_diff_threshold": 10.0 }, - "sana_image_t2i": { - "clip_threshold": 0.91, - "ssim_threshold": 0.88, - "psnr_threshold": 21.0, - "mean_abs_diff_threshold": 8.4 - }, "qwen_image_edit_2509_ti2i": { "clip_threshold": 0.92, "ssim_threshold": 0.65, diff --git a/python/sglang/multimodal_gen/test/server/test_server_common.py b/python/sglang/multimodal_gen/test/server/test_server_common.py index dd48e7e0c7b8..f8ac02c2c761 100644 --- a/python/sglang/multimodal_gen/test/server/test_server_common.py +++ b/python/sglang/multimodal_gen/test/server/test_server_common.py @@ -51,14 +51,6 @@ logger = init_logger(__name__) -def _is_lora_case(case: 
DiffusionTestCase) -> bool: - return bool( - case.server_args.lora_path - or case.server_args.dynamic_lora_path - or case.server_args.second_lora_path - ) - - @pytest.fixture def diffusion_server(case: DiffusionTestCase) -> ServerContext: """Start a diffusion server for a single case and tear it down afterwards.""" @@ -81,11 +73,6 @@ def diffusion_server(case: DiffusionTestCase) -> ServerContext: sampling_params = case.sampling_params extra_args = os.environ.get("SGLANG_TEST_SERVE_ARGS", "") - # Keep LoRA GT on the normal backend path so adapter state matches CI. - if os.environ.get("SGLANG_GEN_GT", "0") == "1": - if not _is_lora_case(case) and "--backend" not in extra_args: - extra_args = "--backend diffusers " + extra_args.strip() - extra_args += f" --num-gpus {server_args.num_gpus}" if server_args.tp_size is not None: @@ -235,18 +222,21 @@ def run_and_collect( ctx: ServerContext, case_id: str, generate_fn: Callable[[str, openai.Client], tuple[str, bytes]], - ) -> tuple[RequestPerfRecord, bytes]: - """Run generation and collect performance records. + collect_perf: bool = True, + ) -> tuple[RequestPerfRecord | None, bytes]: + """Run generation and optionally collect performance records. 
Returns: Tuple of (performance_record, content_bytes) """ - log_path = ctx.perf_log_path - log_wait_timeout = 30 - client = self._client(ctx) rid, content = generate_fn(case_id, client) + if not collect_perf: + return None, content + + log_path = ctx.perf_log_path + log_wait_timeout = 30 req_perf_record = wait_for_req_perf_record( rid, log_path, @@ -1024,6 +1014,7 @@ def test_diffusion_generation( diffusion_server, case.id, generate_fn, + collect_perf=not is_gt_gen_mode, ) if is_gt_gen_mode: diff --git a/python/sglang/multimodal_gen/test/server/testcase_configs.py b/python/sglang/multimodal_gen/test/server/testcase_configs.py index d879adce616b..e1c837691fb9 100644 --- a/python/sglang/multimodal_gen/test/server/testcase_configs.py +++ b/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -500,15 +500,6 @@ def from_req_perf_record( run_lora_dynamic_switch_check=True, run_multi_lora_api_check=True, ), - DiffusionTestCase( - "sana_image_t2i", - DiffusionServerArgs( - model_path="Efficient-Large-Model/Sana_600M_1024px_diffusers", - modality="image", - ), - T2I_sampling_params, - run_perf_check=False, - ), # === Text and Image to Image (TI2I) === DiffusionTestCase( "qwen_image_edit_ti2i", @@ -804,7 +795,6 @@ def from_req_perf_record( modality="image", ), T2I_sampling_params, - run_consistency_check=False, ) ] @@ -945,7 +935,6 @@ def from_req_perf_record( extras=["--pipeline-class-name LTX2TwoStagePipeline"], ), T2V_sampling_params, - run_consistency_check=False, ), ] From 5ae00ecd48b11f943d8ba319f3b0e828d8a41116 Mon Sep 17 00:00:00 2001 From: YAMY <74099316+YAMY1234@users.noreply.github.com> Date: Tue, 7 Apr 2026 08:47:31 -0700 Subject: [PATCH 26/42] [Disagg][NIXL] Support Mamba state slice transfer for heterogeneous TP (Step 2/2 for Qwen3.5) (#22240) --- python/sglang/srt/disaggregation/nixl/conn.py | 145 +++++++++++++++++- 1 file changed, 143 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py 
b/python/sglang/srt/disaggregation/nixl/conn.py index f84353f6dc12..005d5b05c286 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -84,6 +84,8 @@ class KVArgsRegisterInfo: decode_tp_size: int decode_tp_rank: int dst_kv_item_len: int + dst_state_item_lens: list[int] = dataclasses.field(default_factory=list) + dst_state_dim_per_tensor: list[int] = dataclasses.field(default_factory=list) @classmethod def from_zmq(cls, msg: List[bytes]): @@ -93,6 +95,15 @@ def from_zmq(cls, msg: List[bytes]): else: dst_state_data_ptrs = [] + dst_state_item_lens = [] + dst_state_dim_per_tensor = [] + if len(msg) > 12 and len(msg[12]) > 0: + dst_state_item_lens = list(struct.unpack(f"{len(msg[12]) // 4}I", msg[12])) + if len(msg) > 13 and len(msg[13]) > 0: + dst_state_dim_per_tensor = list( + struct.unpack(f"{len(msg[13]) // 4}I", msg[13]) + ) + return cls( room=str(msg[0].decode("ascii")), endpoint=msg[1].decode("ascii"), @@ -106,6 +117,8 @@ def from_zmq(cls, msg: List[bytes]): decode_tp_size=int(msg[9].decode("ascii")), decode_tp_rank=int(msg[10].decode("ascii")), dst_kv_item_len=int(msg[11].decode("ascii")), + dst_state_item_lens=dst_state_item_lens, + dst_state_dim_per_tensor=dst_state_dim_per_tensor, ) @@ -681,6 +694,106 @@ def _send_mamba_state( raise Exception("Failed to post Mamba state transfer") return xfer_handle + def _send_mamba_state_slice( + self, + peer_name: str, + prefill_state_indices: List[int], + dst_state_data_ptrs: list[int], + dst_state_indices: List[int], + dst_gpu_id: int, + notif: str, + dst_state_item_lens: list[int], + dst_state_dim_per_tensor: list[int], + decode_tp_size: int, + decode_tp_rank: int, + ): + """Transfer Mamba states with TP slice support via RDMA. + + When prefill and decode have different attn_tp_size, we slice the + TP-sharded dimension (3rd dim) of conv_state and temporal_state + accordingly, mirroring Mooncake's _send_mamba_state_slice. 
+ """ + logger.warning_once( + "Using Mamba state slice transfer for different TP sizes. " + f"Prefill attn_tp_size={self.attn_tp_size}, " + f"Decode attn_tp_size={decode_tp_size}." + ) + assert len(prefill_state_indices) == 1, "Mamba should have single state index" + + prefill_state_data_ptrs = self.kv_args.state_data_ptrs + prefill_state_item_lens = self.kv_args.state_item_lens + src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", []) + + if not src_state_dim_per_tensor or not dst_state_dim_per_tensor: + return self._send_mamba_state( + peer_name, + prefill_state_indices, + dst_state_data_ptrs, + dst_state_indices, + dst_gpu_id, + notif, + ) + + local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size + dst_tp_rank_in_group = decode_tp_rank % decode_tp_size + + src_addrs = [] + dst_addrs = [] + + for i, dst_state_ptr in enumerate(dst_state_data_ptrs): + src_item_len = prefill_state_item_lens[i] + dst_item_len = dst_state_item_lens[i] + src_dim = src_state_dim_per_tensor[i] + dst_dim = dst_state_dim_per_tensor[i] + + src_bytes_per_dim = src_item_len // src_dim + dst_bytes_per_dim = dst_item_len // dst_dim + + if self.attn_tp_size > decode_tp_size: + src_dim_start = 0 + num_dims_to_send = src_dim + writers_per_decode = self.attn_tp_size // decode_tp_size + local_writer_idx = local_tp_rank_in_group % writers_per_decode + dst_dim_start = local_writer_idx * src_dim + else: + src_dim_start = (dst_tp_rank_in_group * dst_dim) % src_dim + num_dims_to_send = dst_dim + dst_dim_start = 0 + + src_dim_offset = src_dim_start * src_bytes_per_dim + dst_dim_offset = dst_dim_start * dst_bytes_per_dim + bytes_to_send = num_dims_to_send * src_bytes_per_dim + + src_addr = ( + prefill_state_data_ptrs[i] + + src_item_len * int(prefill_state_indices[0]) + + src_dim_offset + ) + dst_addr = ( + dst_state_ptr + + dst_item_len * int(dst_state_indices[0]) + + dst_dim_offset + ) + src_addrs.append((src_addr, bytes_to_send, self.kv_args.gpu_id)) + 
dst_addrs.append((dst_addr, bytes_to_send, dst_gpu_id)) + + src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM") + dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM") + + xfer_handle = self.agent.initialize_xfer( + "WRITE", + src_descs, + dst_descs, + peer_name, + notif.encode("ascii"), + ) + if not xfer_handle: + raise Exception("Failed to create Mamba state slice transfer") + state = self.agent.transfer(xfer_handle) + if state == "ERR": + raise Exception("Failed to post Mamba state slice transfer") + return xfer_handle + def maybe_send_extra( self, peer_name: str, @@ -690,14 +803,26 @@ def maybe_send_extra( dst_gpu_id: int, notif: str, decode_tp_size: int, + decode_tp_rank: int = 0, + dst_state_item_lens: list[int] | None = None, + dst_state_dim_per_tensor: list[int] | None = None, ): """Send state or extra pool data with type-specific handling.""" state_type = getattr(self.kv_args, "state_type", "none") if state_type == "mamba": if self.attn_tp_size != decode_tp_size: - raise RuntimeError( - "PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet." 
+ return self._send_mamba_state_slice( + peer_name, + prefill_state_indices, + dst_state_data_ptrs, + dst_state_indices, + dst_gpu_id, + notif, + dst_state_item_lens or [], + dst_state_dim_per_tensor or [], + decode_tp_size, + decode_tp_rank, ) return self._send_mamba_state( peer_name, @@ -803,6 +928,9 @@ def add_transfer_request( dst_info.gpu_id, f"{req.room}_state_{self.kv_args.engine_rank}", decode_tp_size, + decode_tp_rank=dst_info.decode_tp_rank, + dst_state_item_lens=dst_info.dst_state_item_lens, + dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, ) if state_xfer_handle is not None: handles.append(state_xfer_handle) @@ -1080,6 +1208,17 @@ def _register_kv_args(self): struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs ) + packed_state_item_lens = b"".join( + struct.pack("I", item_len) + for item_len in self.kv_mgr.kv_args.state_item_lens + ) + state_dim_per_tensor = getattr( + self.kv_mgr.kv_args, "state_dim_per_tensor", [] + ) + packed_state_dim_per_tensor = b"".join( + struct.pack("I", dim) for dim in state_dim_per_tensor + ) + with lock: sock.send_multipart( [ @@ -1096,6 +1235,8 @@ def _register_kv_args(self): str(self.kv_mgr.attn_tp_size).encode("ascii"), str(self.kv_mgr.kv_args.engine_rank).encode("ascii"), str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"), + packed_state_item_lens, + packed_state_dim_per_tensor, ] ) From 727a182067f05f70924f24c259345693b761c7e6 Mon Sep 17 00:00:00 2001 From: Henson-Zh-Ali Date: Wed, 8 Apr 2026 00:17:26 +0800 Subject: [PATCH 27/42] [Mamba] eliminate D2H if tracking mamba states (#20522) Co-authored-by: hzh0425 --- .../attention/hybrid_linear_attn_backend.py | 11 +++++---- .../layers/attention/linear/gdn_backend.py | 23 +++++++++++++------ .../layers/attention/mamba/mamba2_metadata.py | 4 ++++ 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py 
index 91194c494396..5be0b7dc85fa 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -217,6 +217,11 @@ def _forward_metadata(self, forward_batch: ForwardBatch): else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode=}") + has_mamba_track_mask = bool( + forward_batch.mamba_track_mask is not None + and forward_batch.mamba_track_mask.any() + ) + return ForwardMetadata( query_start_loc=query_start_loc, mamba_cache_indices=mamba_cache_indices, @@ -228,6 +233,7 @@ def _forward_metadata(self, forward_batch: ForwardBatch): track_ssm_h_dst=track_ssm_h_dst, track_ssm_final_src=track_ssm_final_src, track_ssm_final_dst=track_ssm_final_dst, + has_mamba_track_mask=has_mamba_track_mask, ) def init_forward_metadata(self, forward_batch: ForwardBatch): @@ -613,10 +619,7 @@ def _track_mamba_state_extend( Note: Conv state tracking for extend is handled separately via gather operations using indices computed by `_init_track_conv_indices`. 
""" - if ( - forward_batch.mamba_track_mask is not None - and forward_batch.mamba_track_mask.any() - ): + if forward_metadata.has_mamba_track_mask: h = h.squeeze(0) if forward_metadata.track_ssm_h_src.numel() > 0: diff --git a/python/sglang/srt/layers/attention/linear/gdn_backend.py b/python/sglang/srt/layers/attention/linear/gdn_backend.py index 700ccfdf6aa3..4dad415b1c31 100644 --- a/python/sglang/srt/layers/attention/linear/gdn_backend.py +++ b/python/sglang/srt/layers/attention/linear/gdn_backend.py @@ -256,6 +256,18 @@ def __init__(self, model_runner: ModelRunner): prefill_backend = get_linear_attn_prefill_backend() self.kernel_dispatcher = GDNKernelDispatcher(decode_backend, prefill_backend) + def init_forward_metadata(self, forward_batch: ForwardBatch): + super().init_forward_metadata(forward_batch) + if self.forward_metadata.has_mamba_track_mask: + self.forward_metadata.mamba_track_mask_indices = ( + forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0] + ) + self.forward_metadata.conv_states_mask_indices = ( + forward_batch.mamba_track_indices[ + self.forward_metadata.mamba_track_mask_indices + ] + ) + def forward_decode( self, layer: RadixLinearAttention, @@ -394,16 +406,13 @@ def forward_extend( mixed_qkv = mixed_qkv_processed.transpose(1, 2).view(seq_len, -1) else: mixed_qkv = mixed_qkv.transpose(0, 1) - if ( - forward_batch.mamba_track_mask is not None - and forward_batch.mamba_track_mask.any() - ): - conv_dst = forward_batch.mamba_track_indices + if forward_metadata.has_mamba_track_mask: mixed_qkv_to_track = mixed_qkv[ :, forward_metadata.track_conv_indices ].transpose(0, 1) - mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0] - conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track + conv_states[forward_metadata.conv_states_mask_indices] = ( + mixed_qkv_to_track + ) mixed_qkv = causal_conv1d_fn( mixed_qkv, diff --git a/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py 
b/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py index 5eeb2b65e307..3c1548e9b552 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py +++ b/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py @@ -41,6 +41,10 @@ class ForwardMetadata: is_target_verify: bool = False draft_token_num: int = 1 + has_mamba_track_mask: bool = False + mamba_track_mask_indices: Optional[torch.Tensor] = None + conv_states_mask_indices: Optional[torch.Tensor] = None + @dataclass(kw_only=True) class Mamba2Metadata(ForwardMetadata): From ec5742f4ab463d2316a6a3324f6e4552cfc94c15 Mon Sep 17 00:00:00 2001 From: shuwenn <47200617+alphabetc1@users.noreply.github.com> Date: Wed, 8 Apr 2026 00:19:31 +0800 Subject: [PATCH 28/42] fix: Auto-correct page_size for Mamba no_buffer radix cache mode (#20538) --- python/sglang/srt/server_args.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 74445c9cd1f2..f6757a1bddba 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2204,6 +2204,13 @@ def _handle_mamba_radix_cache( == 0 ), f"For SSM models with extra buffer, either FLA_CHUNK_SIZE or page_size must be divisible by the other, got {FLA_CHUNK_SIZE=}, {self.page_size=}" elif not self.disable_radix_cache: # no_buffer + if self.page_size is not None and self.page_size != 1: + logger.warning( + f"{model_arch} with radix cache requires page_size=1 in the current " + f"Mamba scheduling mode (no_buffer), but got {self.page_size}. " + "Automatically setting page_size=1." 
+ ) + self.page_size = 1 if self.speculative_algorithm is None: logger.warning( "Disabling overlap schedule since mamba no_buffer is not compatible with " From be42fbbbd74122a3f01b7adb2a61d38df7f0c937 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Wed, 8 Apr 2026 00:42:52 +0800 Subject: [PATCH 29/42] Support HTTP2 server (#21700) --- python/pyproject.toml | 6 + python/sglang/srt/entrypoints/http_server.py | 106 +++++++++++++++++ python/sglang/srt/environ.py | 3 + .../srt/managers/multi_tokenizer_mixin.py | 11 +- python/sglang/srt/server_args.py | 29 +++++ .../openai_server/basic/test_http2_server.py | 112 ++++++++++++++++++ 6 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 test/registered/openai_server/basic/test_http2_server.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 8e96b44afe3c..99c6e1b3cd02 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -129,6 +129,10 @@ tracing = [ "opentelemetry-sdk", ] +http2 = [ + "granian>=2.6.0", +] + test = [ "accelerate", "addict", @@ -146,6 +150,7 @@ test = [ "diff-cover", "sentence_transformers", "tabulate", + "granian>=2.6.0", ] dev = ["sglang[test]"] @@ -153,6 +158,7 @@ dev = ["sglang[test]"] all = [ "sglang[diffusion]", "sglang[tracing]", + "sglang[http2]", ] [tool.uv.extra-build-dependencies] diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 61f7513f567f..43afc1577a56 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -206,6 +206,32 @@ def get_global_state() -> _GlobalState: return _global_state +async def _init_granian_worker() -> ServerArgs: + main_pid = get_main_process_id() + port_args, server_args, scheduler_info = read_from_shared_memory( + f"multi_tokenizer_args_{main_pid}" + ) + + tokenizer_manager = TokenizerManager(server_args, port_args) + template_manager = TemplateManager() + template_manager.initialize_templates( + 
tokenizer_manager=tokenizer_manager, + model_path=server_args.model_path, + chat_template=server_args.chat_template, + completion_template=server_args.completion_template, + ) + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + + set_global_state( + _GlobalState( + tokenizer_manager=tokenizer_manager, + template_manager=template_manager, + scheduler_info=scheduler_info, + ) + ) + return server_args + + async def init_multi_tokenizer() -> ServerArgs: """ Initialization function for multi-process tokenizer mode. @@ -263,6 +289,10 @@ async def lifespan(fast_api_app: FastAPI): server_args = fast_api_app.server_args warmup_thread_kwargs = fast_api_app.warmup_thread_kwargs thread_label = "Tokenizer" + elif envs.SGLANG_GRANIAN_PARENT_PID.get() is not None: + server_args = await _init_granian_worker() + warmup_thread_kwargs = dict(server_args=server_args) + thread_label = "Tokenizer" else: # Initialize multi-tokenizer support for worker processes server_args = await init_multi_tokenizer() @@ -2017,6 +2047,53 @@ def _wait_weights_ready(): ) +def _close_main_process_sockets(): + """Close the main process's ZMQ sockets before spawning Granian workers. + + Granian workers create their own TokenizerManager with fresh ZMQ sockets. + The main process must release its sockets first to avoid binding conflicts + on the same IPC addresses. 
+ """ + if _global_state is None or _global_state.tokenizer_manager is None: + return + tm = _global_state.tokenizer_manager + for attr in ("recv_from_detokenizer", "send_to_scheduler"): + sock = getattr(tm, attr, None) + if sock is None: + continue + inner = getattr(sock, "socket", None) + if inner is not None: + inner.close() + elif hasattr(sock, "close"): + sock.close() + setattr(tm, attr, None) + + +def _run_granian_server(server_args: ServerArgs): + """Launch Granian with HTTP/2 support""" + from granian import Granian + from granian.constants import HTTPModes, Interfaces, Loops + + granian_kwargs = dict( + target="sglang.srt.entrypoints.http_server:app", + address=server_args.host, + port=server_args.port, + interface=Interfaces.ASGI, + http=HTTPModes.auto, + loop=Loops.uvloop, + log_level=server_args.log_level_http or server_args.log_level or "info", + workers=1, + ) + + ssl_enabled = server_args.ssl_certfile and server_args.ssl_keyfile + if ssl_enabled: + granian_kwargs["ssl_cert"] = server_args.ssl_certfile + granian_kwargs["ssl_key"] = server_args.ssl_keyfile + + server = Granian(**granian_kwargs) + server.serve() + + def _setup_and_run_http_server( server_args: ServerArgs, tokenizer_manager, @@ -2047,6 +2124,35 @@ def _setup_and_run_http_server( if server_args.enable_metrics: add_prometheus_track_response_middleware(app) + # Use Granian for HTTP/2 server + if server_args.enable_http2: + # Reuse the multi-tokenizer shared memory mechanism to pass + # init args (port_args, server_args, scheduler_info) to + # Granian workers, which are independent processes. 
+ multi_tokenizer_args_shm = write_data_for_multi_tokenizer( + port_args, server_args, scheduler_infos[0] + ) + try: + if server_args.ssl_certfile: + logger.info( + f"SSL enabled: certfile={server_args.ssl_certfile}, " + f"keyfile={server_args.ssl_keyfile}" + ) + logger.info( + f"Starting Granian HTTP/2 server on " + f"{server_args.host}:{server_args.port}" + ) + # Propagate the main process PID via os.environ so Granian + # workers (forked or spawned) can locate the shared memory + # segment created above. + envs.SGLANG_GRANIAN_PARENT_PID.set(os.getpid()) + _close_main_process_sockets() + _run_granian_server(server_args) + finally: + if multi_tokenizer_args_shm is not None: + multi_tokenizer_args_shm.unlink() + return + # Pass additional arguments to the lifespan function. # They will be used for additional initialization setups. if server_args.tokenizer_worker_num == 1: diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index dfc5507de0ba..e6eb48212b2a 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -489,6 +489,9 @@ class Envs: # HTTP Server SGLANG_TIMEOUT_KEEP_ALIVE = EnvInt(5) + # HTTP/2 Server + SGLANG_GRANIAN_PARENT_PID = EnvInt(None) + # Health Check SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION = EnvBool(True) diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index e0a1669fb3e6..8da3d3b0de60 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -433,7 +433,16 @@ async def print_exception_wrapper(func): def get_main_process_id() -> int: - """Get the main process ID""" + """Get the main process ID. + + Supports override via SGLANG_GRANIAN_PARENT_PID for workers whose + multiprocessing parent PID differs from the shared-memory owner. 
+ """ + from sglang.srt.environ import envs + + override = envs.SGLANG_GRANIAN_PARENT_PID.get() + if override is not None: + return override return multiprocessing.current_process()._parent_pid diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f6757a1bddba..86de1b32713c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -326,6 +326,7 @@ class ServerArgs: ssl_ca_certs: Optional[str] = None ssl_keyfile_password: Optional[str] = None enable_ssl_refresh: bool = False + enable_http2: bool = False # Quantization and data type dtype: str = "auto" @@ -923,6 +924,26 @@ def _handle_ssl_validation(self): "to be specified." ) + if self.enable_http2: + try: + import granian # noqa: F401 + except ImportError: + raise ValueError( + "--enable-http2 requires the 'granian' package. " + 'Install it with: pip install "sglang[http2]"' + ) + if self.enable_ssl_refresh: + raise ValueError( + "--enable-ssl-refresh is not supported with --enable-http2. " + "Granian does not support SSL certificate hot-reloading. " + "Use Uvicorn (the default) or handle certificate rotation externally." + ) + if self.tokenizer_worker_num > 1: + raise ValueError( + "--enable-http2 does not yet support --tokenizer-worker-num > 1. " + "Multi-worker HTTP/2 support will be added in a future release." + ) + def _handle_deprecated_args(self): # Handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} @@ -3865,6 +3886,14 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Enable automatic SSL certificate hot-reloading when cert/key " "files change on disk. Requires --ssl-certfile and --ssl-keyfile.", ) + parser.add_argument( + "--enable-http2", + action="store_true", + default=ServerArgs.enable_http2, + help="Use Granian instead of Uvicorn as the ASGI server, enabling HTTP/1.1 and " + "HTTP/2 auto-negotiation. Clients may use h2c (cleartext HTTP/2) or plain HTTP/1.1. 
" + "Requires 'pip install sglang[http2]'.", + ) # Quantization and data type parser.add_argument( diff --git a/test/registered/openai_server/basic/test_http2_server.py b/test/registered/openai_server/basic/test_http2_server.py new file mode 100644 index 000000000000..6cfc3ee7e7e0 --- /dev/null +++ b/test/registered/openai_server/basic/test_http2_server.py @@ -0,0 +1,112 @@ +""" +Test HTTP/2 server (Granian) with basic OpenAI-compatible endpoints. + +Verifies that --enable-http2 launches successfully and serves requests +via both HTTP/1.1 and HTTP/2 (h2c). +""" + +import subprocess +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +try: + import granian # noqa: F401 + + _HAS_GRANIAN = True +except ImportError: + _HAS_GRANIAN = False + +register_cuda_ci(est_time=120, suite="stage-b-test-1-gpu-small") + + +@unittest.skipUnless(_HAS_GRANIAN, "granian not installed (pip install sglang[http2])") +class TestHTTP2Server(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--enable-http2"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_health(self): + resp = requests.get(f"{self.base_url}/health") + self.assertEqual(resp.status_code, 200) + + def test_get_model_info(self): + resp = requests.get(f"{self.base_url}/get_model_info") + self.assertEqual(resp.status_code, 200) + self.assertIn("model_path", resp.json()) + + def test_completion(self): + resp = requests.post( + f"{self.base_url}/v1/completions", + json={ + "model": self.model, + 
"prompt": "The capital of France is", + "max_tokens": 8, + "temperature": 0, + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("choices", data) + self.assertGreater(len(data["choices"][0]["text"]), 0) + + def test_chat_completion(self): + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json={ + "model": self.model, + "messages": [{"role": "user", "content": "Say hello"}], + "max_tokens": 16, + "temperature": 0, + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("choices", data) + self.assertGreater(len(data["choices"][0]["message"]["content"]), 0) + + def test_h2c_with_curl(self): + """Verify the server actually speaks HTTP/2 via h2c.""" + result = subprocess.run( + [ + "curl", + "--http2-prior-knowledge", + "-s", + "-o", + "/dev/null", + "-w", + "%{http_version}", + f"{self.base_url}/health", + ], + capture_output=True, + text=True, + timeout=10, + ) + self.assertEqual( + result.stdout.strip(), "2", "Server should respond with HTTP/2" + ) + + +if __name__ == "__main__": + unittest.main(verbosity=3) From 6131fb58827309d805f37dc9abe011e8bce3f741 Mon Sep 17 00:00:00 2001 From: khalilzhk Date: Wed, 8 Apr 2026 00:49:16 +0800 Subject: [PATCH 30/42] [NPU] enable mla prepare fused kernel only when being mla attn (#22024) --- .../srt/hardware_backend/npu/attention/ascend_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index d93cc5f62f1c..e7acdee86e90 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -860,7 +860,7 @@ def forward_extend( sinks: Optional[torch.Tensor] = None, slopes: Optional[torch.Tensor] = None, ): - if is_mla_preprocess_enabled(): + if is_mla_preprocess_enabled() and self.use_mla: # MLAPO and 
MLAPROLOG do save kv_cache save_kv_cache = False if self.is_dllm_model: @@ -1773,7 +1773,7 @@ def forward_decode( sinks: Optional[torch.Tensor] = None, slopes: Optional[torch.Tensor] = None, ): - if is_mla_preprocess_enabled(): + if is_mla_preprocess_enabled() and self.use_mla: # MLAPO does saving kv_cache save_kv_cache = False if topk_indices is not None: From 0c204fbd57a0deb12e0736b9ba783ac1d24ece85 Mon Sep 17 00:00:00 2001 From: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com> Date: Wed, 8 Apr 2026 01:34:58 +0800 Subject: [PATCH 31/42] [HiSparse] Optimize the scheduling of decode backup. (#21932) Co-authored-by: hzh0425 Co-authored-by: Zhiqiang Xie --- .../srt/managers/hisparse_coordinator.py | 45 +++++++++++++++---- .../sglang/srt/model_executor/model_runner.py | 6 +++ 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py index 89740f73682e..9336571976ba 100644 --- a/python/sglang/srt/managers/hisparse_coordinator.py +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -78,8 +78,11 @@ def __init__( ) self.write_staging_stream = device_module.Stream() + self.decode_backup_stream = device_module.Stream() self.ack_staging_queue: List[HiSparseAct] = [] self.decode_producer_stream = None + self._backup_done_event = device_module.Event() + self._has_pending_backup = False self.tp_group = tp_group self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) @@ -391,9 +394,6 @@ def _eager_backup_previous_token( The only exception is the first decode step right after staging: all prefill tokens were already backed up during staging, so there is nothing new to save yet. """ - if self.decode_producer_stream is not None: - device_module.current_stream().wait_stream(self.decode_producer_stream) - # Build the list of batch positions that need a host backup. 
# Skip the first decode step after staging (prefill already backed up). backup_indices = [] @@ -431,12 +431,36 @@ def _eager_backup_previous_token( host_locs = host_locs.to(device=self.device) self.req_to_host_pool[backup_req_indices, actual_token_pos] = host_locs - self.mem_pool_host.backup_from_device_all_layer( - self.mem_pool_device, - host_locs, - device_locs.contiguous(), - io_backend="kernel", - ) + if self._has_pending_backup: + self._backup_done_event.wait(device_module.current_stream()) + self._has_pending_backup = False + schedule_stream = device_module.current_stream() + with device_module.stream(self.decode_backup_stream): + self.decode_backup_stream.wait_stream(schedule_stream) + if self.decode_producer_stream is not None: + self.decode_backup_stream.wait_stream(self.decode_producer_stream) + self.mem_pool_host.backup_from_device_all_layer( + self.mem_pool_device, + host_locs, + device_locs, + io_backend="kernel", + ) + self._backup_done_event.record() + if host_locs.is_cuda: + host_locs.record_stream(self.decode_backup_stream) + if backup_req_indices.is_cuda: + backup_req_indices.record_stream(self.decode_backup_stream) + if actual_token_pos.is_cuda: + actual_token_pos.record_stream(self.decode_backup_stream) + if device_locs.is_cuda: + device_locs.record_stream(self.decode_backup_stream) + self._has_pending_backup = True + + def wait_for_pending_backup(self) -> None: + if not self._has_pending_backup: + return + self._backup_done_event.wait(device_module.current_stream()) + self._has_pending_backup = False def get_front_topk_tokens( self, @@ -569,6 +593,9 @@ def request_finished(self, req: Req): # release resources only after the execution of a potential overlapped batch if self.decode_producer_stream is not None: device_module.current_stream().wait_stream(self.decode_producer_stream) + if self._has_pending_backup: + self._backup_done_event.wait(device_module.current_stream()) + self._has_pending_backup = False # release memory — only free 
actually-allocated buffer indices current_cap = int(self.req_device_buffer_size[req.req_pool_idx]) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 669cab133c49..26d8bd82a2f1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2817,6 +2817,12 @@ def _forward_raw( and self.graph_runner.can_run(forward_batch) ) + if ( + self.hisparse_coordinator is not None + and forward_batch.forward_mode.is_decode() + ): + self.hisparse_coordinator.wait_for_pending_backup() + if can_run_graph: ret = self.graph_runner.replay( forward_batch, From 1a8eb890f625c1dadf85c7376c94a7de311f0088 Mon Sep 17 00:00:00 2001 From: Rain Jiang <96632942+rainj-me@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:48:44 -0700 Subject: [PATCH 32/42] Kernels community fa3 (#20796) --- .gitignore | 1 + docker/Dockerfile | 7 + docs/references/environment_variables.md | 2 + python/pyproject.toml | 4 + python/sglang/jit_kernel/flash_attention.py | 286 ++++ .../sglang/jit_kernel/flash_attention_v3.py | 222 +++ .../sglang/jit_kernel/flash_attention_v4.py | 1 - .../tests/test_flash_attention_3.py | 1373 +++++++++++++++++ .../tests/test_flash_attention_4.py | 4 +- .../layers/attention/backends/flash_attn.py | 22 +- python/sglang/srt/compilation/backend.py | 5 +- python/sglang/srt/environ.py | 6 + .../dual_chunk_flashattention_backend.py | 5 +- .../attention/flashattention_backend.py | 25 +- .../srt/layers/attention/nsa_backend.py | 5 +- python/sglang/srt/layers/attention/vision.py | 24 +- .../srt/layers/attention/xpu_backend.py | 6 +- python/sglang/srt/utils/runai_utils.py | 10 +- scripts/ci/cuda/ci_install_dependency.sh | 4 + test/srt/cpu/test_flash_attn.py | 5 +- 20 files changed, 1956 insertions(+), 61 deletions(-) create mode 100644 python/sglang/jit_kernel/flash_attention.py create mode 100644 python/sglang/jit_kernel/flash_attention_v3.py create mode 100644 
python/sglang/jit_kernel/tests/test_flash_attention_3.py

diff --git a/.gitignore b/.gitignore
index b5917c299ecf..a8aa903e28f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -258,6 +258,7 @@ inputs/
 
 # setuptools-scm generated version file
 python/sglang/_version.py
+python/kernels.lock
 
 # MUSA section
 # Generated source files by torchada
diff --git a/docker/Dockerfile b/docker/Dockerfile
index d7f4ead4579c..57842c53564b 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -219,6 +219,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install flashinfer-jit-cache==${FLASHINFER_VERSION} --index-url https://flashinfer.ai/whl/cu${CUINDEX} ; \
     fi \
     && FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin
+RUN kernels download python \
+    && kernels lock python \
+    && mv python/kernels.lock /root/.cache/sglang

 # DeepEP
 # We use Tom's DeepEP fork for GB200 for now; the 1fd57b0276311d035d16176bb0076426166e52f3 commit is https://github.com/fzyzcjy/DeepEP/tree/gb200_blog_part_2
@@ -561,6 +564,10 @@ COPY --from=framework /usr/local/lib/python3.12/dist-packages /usr/local/lib/pyt
 # Copy SGLang workspace
 COPY --from=framework /sgl-workspace /sgl-workspace
 
+# Copy cache for kernels from kernels community
+COPY --from=framework /root/.cache/huggingface /root/.cache/huggingface
+COPY --from=framework /root/.cache/sglang /root/.cache/sglang
+
 # Fix Triton to use system ptxas for Blackwell (sm_103a) support (CUDA 13+ only)
 RUN if [ "${CUDA_VERSION%%.*}" = "13" ] && [ -d /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin ]; then \
     rm -f /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas && \
diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md
index e2e93b177b9c..b7ac94a71245 100644
--- a/docs/references/environment_variables.md
+++ b/docs/references/environment_variables.md
@@ -19,6 +19,7 @@ SGLang supports
various environment variables that can be used to configure its | `SGLANG_FORWARD_UNKNOWN_TOOLS` | Forward unknown tool calls to clients instead of dropping them | `false` (drop unknown tools) | | `SGLANG_REQ_WAITING_TIMEOUT` | Timeout (in seconds) for requests waiting in the queue before being scheduled | `-1` | | `SGLANG_REQ_RUNNING_TIMEOUT` | Timeout (in seconds) for requests running in the decode batch | `-1` | +| `SGLANG_CACHE_DIR` | Cache directory for model weights and other data | `~/.cache/sglang` | ## Performance Tuning @@ -47,6 +48,7 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_CUSTOM_ALLREDUCE_ALGO` | The algorithm of custom all-reduce. Set to `oneshot` or `1stage` to force use one-shot. Set to `twoshot` or `2stage` to force use two-shot. | `` | | `SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM prefill attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` | | `SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM decode attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` | +| `SGLANG_USE_SGL_FA3_KERNEL` | Use sgl-kernel implementation for FlashAttention v3 | `true` | ## DeepGEMM Configuration (Advanced Optimization) diff --git a/python/pyproject.toml b/python/pyproject.toml index 99c6e1b3cd02..9fab1de2e1e2 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -77,6 +77,7 @@ dependencies = [ "watchfiles", "xgrammar==0.1.32", "smg-grpc-servicer>=0.5.0", + "kernels", ] [[tool.uv.index]] @@ -207,3 +208,6 @@ version_file = "sglang/_version.py" git_describe_command = ["python3", "python/tools/get_version_tag.py", "--tag-only"] # Allow editable installs even when .git metadata is not available. 
fallback_version = "0.0.0.dev0" + +[tool.kernels.dependencies] +"kernels-community/sgl-flash-attn3" = 1 diff --git a/python/sglang/jit_kernel/flash_attention.py b/python/sglang/jit_kernel/flash_attention.py new file mode 100644 index 000000000000..633863d0a648 --- /dev/null +++ b/python/sglang/jit_kernel/flash_attention.py @@ -0,0 +1,286 @@ +from typing import Optional, Union + +import torch + +from .flash_attention_v3 import flash_attn_varlen_func as fa3_flash_attn_varlen_func +from .flash_attention_v3 import flash_attn_with_kvcache as fa3_flash_attn_with_kvcache +from .flash_attention_v4 import flash_attn_varlen_func as fa4_flash_attn_varlen_func +from .flash_attention_v4 import flash_attn_with_kvcache as fa4_flash_attn_with_kvcache + + +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + attention_chunk: Optional[int] = None, + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + sinks=None, + score_mod=None, + aux_tensors=None, + ver=3, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. 
This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. 
Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. 
Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + score_mod [optional]: A callable that takes the attention scores and applies a modification. + aux_tensors [optional]: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). 
+ """ + + if ver == 3: + return fa3_flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=k, + v=v, + qv=qv, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + return_softmax_lse=return_softmax_lse, + sinks=sinks, + ) + elif ver == 4: + return fa4_flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=k, + v=v, + qv=qv, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sinks=sinks, + score_mod=score_mod, + aux_tensors=aux_tensors, + return_softmax_lse=return_softmax_lse, + ) + else: + raise RuntimeError(f"Unknown flash attention version {ver}") + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=None, + max_seqlen_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + sm_margin=0, + return_softmax_lse=False, + 
sinks=None, + score_mod=None, + aux_tensors=None, + ver=3, +): + + if ver == 3: + return fa3_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + return_softmax_lse=return_softmax_lse, + sinks=sinks, + ) + elif ver == 4: + return fa4_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + softcap=softcap, + window_size=window_size, + sinks=sinks, + num_splits=num_splits, + pack_gqa=pack_gqa, + score_mod=score_mod, + aux_tensors=aux_tensors, + return_softmax_lse=return_softmax_lse, + ) + else: + raise RuntimeError(f"Unknown flash attention version {ver}") diff --git a/python/sglang/jit_kernel/flash_attention_v3.py b/python/sglang/jit_kernel/flash_attention_v3.py new file mode 100644 index 000000000000..23018961d998 --- /dev/null +++ b/python/sglang/jit_kernel/flash_attention_v3.py @@ -0,0 +1,222 @@ +import logging +import os +from typing import Optional, Union + +import torch + +from sglang.jit_kernel.utils import cache_once +from sglang.kernel_api_logging import debug_kernel_api +from sglang.srt.environ import envs + +logger = logging.getLogger(__name__) + +SGL_FA3_KERNEL_REPO = "kernels-community/sgl-flash-attn3" +SGL_FA3_KERNEL_REVISION = "v1" +DEFAULT_FA3_KERNEL_LOCKFILE = "kernels.lock" + + +@cache_once +def _load_fa3_kernels(): + # By default, we use the implementation from sgl-kernel, + # which is expected to be more stable and compatible + if 
envs.SGLANG_USE_SGL_FA3_KERNEL.get(): + logger.debug( + f"SGLANG_USE_SGL_FA3_KERNEL=True, use sgl-kernel implementation for FlashAttention v3 " + ) + return _load_fa3_kernel_from_sgl() + + # Otherwise, we try to load the kernels from the kernels community cache directory or kernels community repo + lockfile_path = os.path.join( + envs.SGLANG_CACHE_DIR.get(), DEFAULT_FA3_KERNEL_LOCKFILE + ) + + try: + from kernels import get_kernel, load_kernel + + # When the lock file provided, load from the kernel cache directory, + # otherwise, load from the repo, which require download from huggingface hub + # but always works as long as the repo is accessible. + if os.path.exists(lockfile_path): + ops = load_kernel(SGL_FA3_KERNEL_REPO, lockfile_path) + else: + ops = get_kernel(SGL_FA3_KERNEL_REPO, revision=SGL_FA3_KERNEL_REVISION) + + return { + "flash_attn_with_kvcache": ops.flash_attn_with_kvcache, + "flash_attn_varlen_func": ops.flash_attn_varlen_func, + } + except Exception as e: + # When the kernels from the repo or the cache directory cannot be loaded + # we catch the exception and log a warning, and then fallback to the implementation + # from sgl-kernel, which is expected to be less efficient but more compatible. + logger.warning( + f"Rollback to implementation from sgl-kernel since loading FlashAttention v3 " + f"kernels from {SGL_FA3_KERNEL_REPO} with lockfile {lockfile_path} failed: {e}" + ) + return _load_fa3_kernel_from_sgl() + + +def _load_fa3_kernel_from_sgl(): + from sgl_kernel.flash_attn import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + + return { + "flash_attn_with_kvcache": flash_attn_with_kvcache, + "flash_attn_varlen_func": flash_attn_varlen_func, + } + + +@cache_once +def _is_fa3_supported(device=None) -> bool: + # There some fa3 FYI + # FA3 can fail without a enough shared memory for a some shapes, such as higher + # hidden_dim or some special cases. + # Right now, fa3 is supported for sm80/sm87 and sm86/sm89. 
The main different + # Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x + # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. + # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3. + return (torch.version.cuda >= "12.3") and ( + torch.cuda.get_device_capability(device)[0] == 9 + or torch.cuda.get_device_capability(device)[0] == 8 + ) + + +@debug_kernel_api +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + attention_chunk: Optional[int] = None, + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + sinks=None, +): + if not _is_fa3_supported(): + raise NotImplementedError( + "flash_attn at sgl-kernel is only supported on sm90 and above" + ) + + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + + return _load_fa3_kernels()["flash_attn_with_kvcache"]( + q, + k_cache, + v_cache, + k, + v, + 
qv, + rotary_cos, + rotary_sin, + cache_seqlens, + cache_batch_idx, + cache_leftpad, + page_table, + cu_seqlens_q, + cu_seqlens_k_new, + max_seqlen_q, + rotary_seqlens, + q_descale, + k_descale, + v_descale, + softmax_scale, + causal, + window_size, + attention_chunk, + softcap, + rotary_interleaved, + scheduler_metadata, + num_splits, + pack_gqa, + sm_margin, + return_softmax_lse, + sinks, + ) + + +@debug_kernel_api +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=None, + max_seqlen_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + sm_margin=0, + return_softmax_lse=False, + sinks=None, +): + + if not _is_fa3_supported(): + raise NotImplementedError( + "flash_attn at sgl-kernel is only supported on sm90 and above" + ) + + return _load_fa3_kernels()["flash_attn_varlen_func"]( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q, + seqused_k, + page_table, + softmax_scale, + causal, + qv, + q_descale, + k_descale, + v_descale, + window_size, + attention_chunk, + softcap, + num_splits, + pack_gqa, + sm_margin, + return_softmax_lse, + sinks, + ) diff --git a/python/sglang/jit_kernel/flash_attention_v4.py b/python/sglang/jit_kernel/flash_attention_v4.py index 0a79614ee075..46b49d177388 100644 --- a/python/sglang/jit_kernel/flash_attention_v4.py +++ b/python/sglang/jit_kernel/flash_attention_v4.py @@ -42,7 +42,6 @@ def flash_attn_varlen_func( score_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, return_softmax_lse: bool = False, - **_: object, ): if _flash_attn_varlen_func is None: # pragma: no cover raise ImportError( diff --git a/python/sglang/jit_kernel/tests/test_flash_attention_3.py b/python/sglang/jit_kernel/tests/test_flash_attention_3.py new file mode 100644 
index 000000000000..e4687da9c827 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_flash_attention_3.py @@ -0,0 +1,1373 @@ +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py +import itertools +import math +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +apply_rotary_emb = None + +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=120, suite="stage-b-kernel-unit-1-gpu-large") +register_cuda_ci(est_time=900, suite="nightly-kernel-1-gpu", nightly=True) + + +def is_hopper(): + # Only Hopper supports different V headdim + return torch.cuda.get_device_properties(0).major == 9 + + +def is_fa3_supported(device=None) -> bool: + # There some fa3 FYI + # FA3 can fail without a enough shared memory for a some shapes, such as higher + # hidden_dim or some special cases. + # Right now, fa3 is supported for sm80/sm87 and sm86/sm89. The main different + # Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x + # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. + # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3. + return (torch.version.cuda >= "12.3") and ( + torch.cuda.get_device_capability(device)[0] == 9 + or torch.cuda.get_device_capability(device)[0] == 8 + ) + + +DISABLE_BACKWARD = True +# For CI test, we close them to True. 
+# DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +# DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +# DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +# DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" +# DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" +# DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" +# DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +# DISABLE_FP8 = ( +# os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" +# or torch.cuda.get_device_capability("cuda")[0] < 9 +# ) + +DISABLE_SPLIT = False +DISABLE_PAGEDKV = True +DISABLE_APPENDKV = False +DISABLE_LOCAL = False +DISABLE_SOFTCAP = True +DISABLE_PACKGQA = False +DISABLE_FP16 = True +DISABLE_FP8 = True + + +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/padding.py +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. 
+ """ + all_masks = ( + (attention_mask + unused_mask) if unused_mask is not None else attention_mask + ) + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. + return ( + rearrange(hidden_states, "b s ... -> (b s) ...")[indices], + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def generate_random_padding_mask( + max_seqlen, batch_size, device, mode="random", zero_lengths=False +): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full( + (batch_size, 1), max_seqlen, device=device, dtype=torch.int32 + ) + elif mode == "random": + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen - 20), + max_seqlen + 1, + (batch_size, 1), + device=device, + ) + elif mode == "third": + lengths = torch.randint( + max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device + ) + + if zero_lengths: + # Generate zero-lengths every 5 batches and the last batch. + for i in range(batch_size): + if i % 5 == 0: + lengths[i] = 0 + lengths[-1] = 0 + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) + < lengths + ) + return padding_mask + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. 
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros( + (batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype + ) + output[indices] = hidden_states + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + sink_token_length=0, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange( + torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" + ) + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + torch.logical_and( + col_idx < row_idx + sk - sq - window_size[0], + col_idx >= sink_token_length, + ), + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + 
window_size=(-1, -1), # -1 means infinite window size + sink_token_length=0, + sinks: Optional[torch.Tensor] = None, + softcap=0.0, + upcast=True, + reorder_ops=False, + intermediate_dtype=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim_v) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / 
math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") + ) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + if sinks is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + sinks = rearrange(sinks, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(sinks, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( + sinks - logits_or_sinks_max + ) + attention = (unnormalized_scores / normalizer).to(v.dtype) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 + ) + # Without this we might get NaN in dv + if key_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0 + ) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill( + torch.all(local_mask, dim=-1, keepdim=True), 0.0 + ) + 
dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + kvpacked=False, + qkvpacked=False, + add_unused_qkv=False, + query_unused_mask=None, + key_unused_mask=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, + query_padding_mask, + query_unused_mask, + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q_unpad.device, + ) + seqused_q = None + max_seqlen_q = seqlen_q + 
output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=k_unpad.device, + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input( + dqkv_unpad, indices_q, batch_size, seqlen_q + ) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input( + dkv_unpad, indices_k, batch_size, seqlen_k + ) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input( + dk_unpad, indices_k, batch_size, 
seqlen_k + ) + else: + dk_pad_fn = lambda dk_unpad: rearrange( + dk_unpad, "(b s) h d -> b s h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +@pytest.mark.skipif( + not is_fa3_supported(), + reason="flash_attn at sgl-kernel is only supported on sm90 or sm80", +) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) +) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_sink", [False, True]) +# @pytest.mark.parametrize("has_sink", [False]) +@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize( +# "causal,local", +# [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []), +# ) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +@pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize( + "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True] +) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize( + "rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False] +) +# @pytest.mark.parametrize("rotary_interleaved", [True]) 
+@pytest.mark.parametrize( + "rotary_fraction", + ( + [0.0, 0.5, 1.0] + if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) + else [0.0] + ), +) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize( + "page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []) +) +# @pytest.mark.parametrize("page_size", [None]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, + has_sink, +): + from sgl_kernel.flash_attn import flash_attn_with_kvcache + + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + 
device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + + if has_sink: + sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + sinks = None + + if dtype == torch.float8_e4m3fn or not is_hopper(): + # for fp8 and ampere arch, we not support v head dim != qk head dim + dv_vals = [d] + for dv in dv_vals: + has_qv = d == 64 and dv >= 256 + q = ( + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + if has_qv: + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... 
-> (b s) ...")[indices_q] if has_qv else None + ) + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) + if new_kv + else 
(seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat( + [ + ( + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + ) + for i in range(batch_size) + ] + ) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + 
interleaved=rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + key_leftpad=cache_leftpad, + sinks=sinks, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + sinks=sinks, + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = 
qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product( + num_splits_vals, precompute_metadata_vals + ): + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True, + sinks=sinks, + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = 
torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) + else: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * ( + out_pt - out_ref + ).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * ( + out_pt - out_ref + ).abs().mean().item() + + +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref +): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + v_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.skipif( + not is_fa3_supported(), + reason="flash_attn at sgl-kernel is only supported on sm90 or sm80", +) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) +) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_sink", [False, True]) +# @pytest.mark.parametrize("has_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +# @pytest.mark.parametrize("d", COMPILED_HDIMS) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), 
+ (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local, + softcap, + deterministic, + has_qv, + mha_type, + dtype, + has_sink, +): + from sglang.jit_kernel.flash_attention import flash_attn_varlen_func + + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + if has_sink: + sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + sinks = None + + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device 
+ ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + sinks=sinks, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + sinks=sinks, + ) + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse, *rest = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + 
max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + return_softmax_lse=True, + sinks=sinks, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - 
dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/python/sglang/jit_kernel/tests/test_flash_attention_4.py b/python/sglang/jit_kernel/tests/test_flash_attention_4.py index e1453b8f2323..81b0f0b23d62 100644 --- a/python/sglang/jit_kernel/tests/test_flash_attention_4.py +++ b/python/sglang/jit_kernel/tests/test_flash_attention_4.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from einops import rearrange, repeat -from sglang.jit_kernel.flash_attention_v4 import flash_attn_varlen_func +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.test.ci.ci_register import register_cuda_ci register_cuda_ci(est_time=120, suite="stage-b-kernel-unit-1-gpu-large") @@ 
-826,6 +826,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): sinks=learnable_sink, # FA4 uses learnable_sink, not sinks pack_gqa=pack_gqa, return_softmax_lse=True, + ver=4, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -1384,6 +1385,7 @@ def test_flash_attn_kvcache( softcap=0.0, pack_gqa=None, return_softmax_lse=True, + ver=4, ) if varlen_q: out = output_pad_fn(out) diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py index 9c30a9798283..31372e2e16ce 100644 --- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py @@ -5,27 +5,13 @@ import torch +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.multimodal_gen.runtime.layers.utils import register_custom_op from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context from sglang.multimodal_gen.runtime.platforms import ( AttentionBackendEnum, ) -try: - from sgl_kernel.flash_attn import flash_attn_varlen_func - - from sglang.jit_kernel.flash_attention_v4 import ( - flash_attn_varlen_func as flash_attn_varlen_func_fa4, - ) - - def flash_attn_func(*args, ver: int = 3, **kwargs): - if ver == 4: - return flash_attn_varlen_func_fa4(*args, **kwargs) - return flash_attn_varlen_func(*args, **kwargs) - -except ImportError as e: - raise e - def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -207,7 +193,7 @@ def flash_attn_varlen_func_op( "flash_attn_varlen_func_op is out-only op; return_softmax_lse must be False. " "Use flash_attn_varlen_func_op_lse for (out, lse)." 
) - return flash_attn_func( + return flash_attn_varlen_func( q, k, v, @@ -271,7 +257,7 @@ def flash_attn_varlen_func_op_lse( "flash_attn_varlen_func_op_lse is out+lse op; return_softmax_lse must be True. " "Use flash_attn_varlen_func_op for out-only." ) - return flash_attn_func( + return flash_attn_varlen_func( q, k, v, @@ -409,7 +395,7 @@ def forward( # - fa_ver == 3: call python function (can return Tensor or (Tensor, Tensor) depending on flag) # - fa_ver == 4: call custom ops with FIXED return schema if fa_ver == 3: - flash_attn_op = flash_attn_func + flash_attn_op = flash_attn_varlen_func output = flash_attn_op( q=query, k=key, diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py index f9d376e959be..201123324068 100644 --- a/python/sglang/srt/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -21,6 +21,7 @@ from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend from sglang.srt.compilation.pass_manager import PostGradPassManager +from sglang.srt.environ import envs from sglang.srt.utils.common import is_npu logger = logging.getLogger(__name__) @@ -393,9 +394,7 @@ def configure_post_pass(self): self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: - base_cache_dir = os.path.expanduser( - os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/") - ) + base_cache_dir = envs.SGLANG_CACHE_DIR.get() cache_hash = self.compiler_manager.compute_hash() cache_dir = os.path.join( diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index e6eb48212b2a..c13d58575f6a 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -406,6 +406,9 @@ class Envs: # sgl-kernel SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False) + # Flash Attention + SGLANG_USE_SGL_FA3_KERNEL = EnvBool(True) + 
# vLLM dependencies (TODO: they have been deprecated, we can remove them safely) USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False) @@ -534,6 +537,9 @@ class Envs: # Elastic EP Backup Port SGLANG_BACKUP_PORT_BASE = EnvInt(10000) + # Sglang Cache Dir + SGLANG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/sglang")) + envs = Envs() EnvField._allow_set_name = False diff --git a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py index e522fbe4a934..a84015a803f8 100644 --- a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +++ b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py @@ -9,13 +9,16 @@ import torch import torch.nn.functional as F -from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from sgl_kernel.sparse_flash_attn import ( convert_vertical_slash_indexes, convert_vertical_slash_indexes_mergehead, sparse_attn_func, ) +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index ff170c390838..ad7f59c0d539 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -27,6 +27,11 @@ from sgl_kernel import merge_state_v2 +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) + @dataclass class FlashAttentionMetadata: @@ -616,9 +621,6 @@ def forward_extend( and not is_swa_layer ) - flash_attn_varlen_func = self.flash_attn_varlen_func - 
flash_attn_with_kvcache = self.flash_attn_with_kvcache - kwargs = {} if sinks is not None: kwargs["sinks"] = sinks @@ -696,6 +698,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) @@ -723,6 +726,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) @@ -750,6 +754,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) o, _ = merge_state_v2_wrapper( @@ -789,6 +794,7 @@ def _fa_cp_attn( softmax_scale=layer.scaling, causal=False, return_softmax_lse=True, + ver=self.fa_impl_ver, **kwargs, ) else: @@ -814,6 +820,7 @@ def _fa_cp_attn( softmax_scale=layer.scaling, causal=True, return_softmax_lse=forward_batch.mha_return_lse, + ver=self.fa_impl_ver, **kwargs, ) if forward_batch.mha_return_lse: @@ -822,7 +829,7 @@ def _fa_cp_attn( return output, lse return output else: - assert self.fa_impl_ver in [3], "Only FA3 support here" + assert self.fa_impl_ver == 3, "Only FA3 support here" # Do absorbed multi-latent attention kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( layer.layer_id @@ -865,6 +872,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) if use_cascade_attn: o, softmax_lse, *rest = result @@ -887,6 +895,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) ) o, _ = merge_state_v2_wrapper( @@ -964,8 +973,6 @@ def forward_decode( if sinks is not None: kwargs["sinks"] = sinks - flash_attn_with_kvcache = self.flash_attn_with_kvcache - k_descale, v_descale = None, None # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention # has corresponding quantization method so that layer.k_scale is not None, @@ -1009,6 +1016,7 @@ def forward_decode( 
k_descale=k_descale, v_descale=v_descale, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) elif use_local_attn: @@ -1029,6 +1037,7 @@ def forward_decode( k_descale=k_descale, v_descale=v_descale, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) else: @@ -1066,6 +1075,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) if use_cascade_attn: @@ -1088,6 +1098,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) ) @@ -1144,6 +1155,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states num_splits=self.num_splits, + ver=self.fa_impl_ver, ) if use_cascade_attn: o, softmax_lse, *rest = result @@ -1165,6 +1177,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) o, _ = merge_state_v2( o, diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index 862488e5f918..314c897ab313 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -61,7 +61,10 @@ "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." 
) else: - from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) # Reuse this workspace buffer across all NSA backend instances diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 23dba24584e9..a624ad06e022 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -38,21 +38,9 @@ if _is_cuda: from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache - try: - from sgl_kernel.flash_attn import flash_attn_varlen_func - - def flash_attn_func(*args, ver: int = 3, **kwargs): - if ver == 4: - from sglang.jit_kernel.flash_attention_v4 import ( - flash_attn_varlen_func as flash_attn_varlen_func_fa4, - ) - - return flash_attn_varlen_func_fa4(*args, **kwargs) - return flash_attn_varlen_func(*args, **kwargs) - - except ImportError as e: - raise e - + from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + ) if _is_npu: import torch_npu @@ -420,7 +408,7 @@ def forward( """ if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get(): max_seqlen = cu_seqlens[1] - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, @@ -436,7 +424,7 @@ def forward( seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, @@ -489,7 +477,7 @@ def forward( seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, diff --git a/python/sglang/srt/layers/attention/xpu_backend.py b/python/sglang/srt/layers/attention/xpu_backend.py index 4a40d25ee8c9..77e773d88d0c 100644 --- a/python/sglang/srt/layers/attention/xpu_backend.py +++ b/python/sglang/srt/layers/attention/xpu_backend.py @@ -20,7 +20,11 @@ from sglang.srt.model_executor.model_runner import 
ModelRunner from sgl_kernel import merge_state_v2 -from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) class XPUAttentionBackend(AttentionBackend): diff --git a/python/sglang/srt/utils/runai_utils.py b/python/sglang/srt/utils/runai_utils.py index 0424a6371bde..dd74efb6626d 100644 --- a/python/sglang/srt/utils/runai_utils.py +++ b/python/sglang/srt/utils/runai_utils.py @@ -5,6 +5,8 @@ import os from pathlib import Path +from sglang.srt.environ import envs + logger = logging.getLogger(__name__) SUPPORTED_SCHEMES = ["s3://", "gs://", "az://"] @@ -26,12 +28,6 @@ # This avoids file locks, race conditions, and duplicate downloads -def get_cache_dir() -> str: - # Expand user path (~) to ensure absolute paths for locking - path = os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/") - return os.path.expanduser(path) - - def list_safetensors(path: str = "") -> list[str]: """ List full file names from object path and filter by allow pattern. @@ -122,7 +118,7 @@ def get_path(cls, model_path: str) -> str: Returns the local directory path. """ model_hash = hashlib.sha256(str(model_path).encode()).hexdigest()[:16] - base_dir = get_cache_dir() + base_dir = envs.SGLANG_CACHE_DIR.get() # Ensure base cache dir exists os.makedirs(os.path.join(base_dir, "model_streamer"), exist_ok=True) diff --git a/scripts/ci/cuda/ci_install_dependency.sh b/scripts/ci/cuda/ci_install_dependency.sh index 5bfbea04ffeb..c10a79e62222 100755 --- a/scripts/ci/cuda/ci_install_dependency.sh +++ b/scripts/ci/cuda/ci_install_dependency.sh @@ -358,6 +358,10 @@ mark_step_done "Fix other dependencies" # can delete the .pth file without reliably recreating it (pip race condition). 
$PIP_CMD install "nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --no-deps --force-reinstall $PIP_INSTALL_SUFFIX || true +# Download kernels from kernels community +kernels download python || true +kernels lock python || true +mv python/kernels.lock ${HOME}/.cache/sglang || true # Install human-eval pip install "setuptools==70.0.0" diff --git a/test/srt/cpu/test_flash_attn.py b/test/srt/cpu/test_flash_attn.py index 8b1faa98b5cb..4e1968fa06e7 100644 --- a/test/srt/cpu/test_flash_attn.py +++ b/test/srt/cpu/test_flash_attn.py @@ -1,15 +1,12 @@ import unittest -import sgl_kernel # noqa: F401 import torch import torch.nn.functional as F from utils import parametrize, precision +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.test.test_utils import CustomTestCase -flash_attn_varlen_func = torch.ops.sgl_kernel.flash_attn_varlen_func - - torch.manual_seed(1234) From cc35714b034dba3724d2ca2977f607c998fb2ed9 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 7 Apr 2026 13:08:35 -0700 Subject: [PATCH 33/42] [tiny] migrate /get_server_info; print accept length in accuracy tests (#22282) --- python/sglang/test/kits/eval_accuracy_kit.py | 26 ++++++++++++------- .../spec/test_ngram_speculative_decoding.py | 2 +- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/python/sglang/test/kits/eval_accuracy_kit.py b/python/sglang/test/kits/eval_accuracy_kit.py index 25bf58151c9f..9757dc01523e 100644 --- a/python/sglang/test/kits/eval_accuracy_kit.py +++ b/python/sglang/test/kits/eval_accuracy_kit.py @@ -9,12 +9,16 @@ _THRESHOLD_NOT_SET = float("nan") -def _check_accept_length(test_case, base_url, threshold): - """Check speculative decoding accept length from server info.""" - server_info = requests.get(base_url + "/get_server_info").json() - avg_spec_accept_length = server_info["internal_states"][0]["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - test_case.assertGreater(avg_spec_accept_length, 
threshold) +def _check_accept_length(test_case, base_url, threshold=None): + """Print accept length; optionally assert it exceeds threshold.""" + try: + server_info = requests.get(base_url + "/server_info").json() + val = server_info["internal_states"][0]["avg_spec_accept_length"] + except (KeyError, IndexError, requests.RequestException): + return + print(f"avg_spec_accept_length={val:.4f}") + if threshold is not None: + test_case.assertGreater(val, threshold) class GSM8KMixin: @@ -57,8 +61,7 @@ def test_gsm8k(self): self.assertGreaterEqual(metrics["score"], self.gsm8k_accuracy_thres) - if self.gsm8k_accept_length_thres is not None: - _check_accept_length(self, self.base_url, self.gsm8k_accept_length_thres) + _check_accept_length(self, self.base_url, self.gsm8k_accept_length_thres) class MMLUMixin: @@ -95,8 +98,7 @@ def test_mmlu(self): self.assertGreaterEqual(metrics["score"], self.mmlu_score_threshold) - if self.mmlu_accept_length_thres is not None: - _check_accept_length(self, self.base_url, self.mmlu_accept_length_thres) + _check_accept_length(self, self.base_url, self.mmlu_accept_length_thres) class HumanEvalMixin: @@ -136,6 +138,8 @@ def test_human_eval(self): self.assertGreaterEqual(metrics["score"], threshold) + _check_accept_length(self, self.base_url) + class MGSMEnMixin: """Mixin for MGSM English evaluation. 
@@ -169,3 +173,5 @@ def test_mgsm_en(self): write_github_step_summary(f"### test_mgsm_en\n{metrics['score']=:.4f}\n") self.assertGreaterEqual(metrics["score"], self.mgsm_en_score_threshold) + + _check_accept_length(self, self.base_url) diff --git a/test/registered/spec/test_ngram_speculative_decoding.py b/test/registered/spec/test_ngram_speculative_decoding.py index f80b1e646dea..d8e0c467b6b4 100644 --- a/test/registered/spec/test_ngram_speculative_decoding.py +++ b/test/registered/spec/test_ngram_speculative_decoding.py @@ -111,7 +111,7 @@ def generate_batch(): return outputs def get_accept_length(): - info = requests.get(self.base_url + "/get_server_info").json() + info = requests.get(self.base_url + "/server_info").json() return info["internal_states"][0]["avg_spec_accept_length"] # Phase 1: baseline — no SAM corpus loaded, only trie From e14876742a08842681bfe4a13d3e719ba58cc319 Mon Sep 17 00:00:00 2001 From: YC Yen-Ching Tseng Date: Wed, 8 Apr 2026 04:48:37 +0800 Subject: [PATCH 34/42] [AMD] Fix test_kimi_k25_mxfp4.py : stage-c-test-large-8-gpu-amd-mi35x (linux-mi35x-gpu-8, 1) (#22188) --- test/registered/amd/test_kimi_k25_mxfp4.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/registered/amd/test_kimi_k25_mxfp4.py b/test/registered/amd/test_kimi_k25_mxfp4.py index a4ef774304c8..1ce83f8eb928 100644 --- a/test/registered/amd/test_kimi_k25_mxfp4.py +++ b/test/registered/amd/test_kimi_k25_mxfp4.py @@ -27,6 +27,7 @@ register_amd_ci(est_time=3600, suite="stage-c-test-large-8-gpu-amd-mi35x") KIMI_K25_MXFP4_MODEL_PATH = "amd/Kimi-K2.5-MXFP4" +KIMI_K25_MXFP4_REVISION = "b071bc6f8eb042e093e14f3b8bdbad71c18e09d3" SERVER_LAUNCH_TIMEOUT = 3600 @@ -36,6 +37,8 @@ def setUpClass(cls): cls.model = KIMI_K25_MXFP4_MODEL_PATH cls.base_url = DEFAULT_URL_FOR_TEST other_args = [ + "--revision", + KIMI_K25_MXFP4_REVISION, "--tp", "8", "--attention-backend", From f08726fd56c7ff6d8bd258f1545f98148fa4ef58 Mon Sep 17 00:00:00 2001 From: David Wang 
<21328423+dcw02@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:48:51 -0400 Subject: [PATCH 35/42] [Feature] Add DFLASH speculative decoding support (#22077) Co-authored-by: Jian Chen <141193260+jianc99@users.noreply.github.com> Co-authored-by: Zhijian Liu <5782437+zhijian-liu@users.noreply.github.com> Co-authored-by: Richard Gong <8001209+gongy@users.noreply.github.com> Co-authored-by: David Wang <21328423+dcw02@users.noreply.github.com> Co-authored-by: yilian49 <43861414+yilian49@users.noreply.github.com> Co-authored-by: xm:D <38322020+xiaomin-d@users.noreply.github.com> --- .../layers/attention/flashinfer_backend.py | 30 +- python/sglang/srt/managers/scheduler.py | 26 + .../srt/model_executor/cuda_graph_runner.py | 92 +- .../sglang/srt/model_executor/model_runner.py | 94 +- .../model_runner_kv_cache_mixin.py | 16 + python/sglang/srt/models/dflash.py | 399 ++++++ python/sglang/srt/models/llama.py | 12 + python/sglang/srt/server_args.py | 147 +- python/sglang/srt/speculative/dflash_info.py | 501 +++++++ python/sglang/srt/speculative/dflash_utils.py | 637 +++++++++ .../sglang/srt/speculative/dflash_worker.py | 1245 +++++++++++++++++ python/sglang/srt/speculative/spec_info.py | 25 +- .../srt/speculative/triton_ops/__init__.py | 20 + .../triton_ops/fused_kv_materialize.py | 303 ++++ python/sglang/test/test_utils.py | 4 + test/registered/spec/dflash/test_dflash.py | 152 ++ 16 files changed, 3666 insertions(+), 37 deletions(-) create mode 100644 python/sglang/srt/models/dflash.py create mode 100644 python/sglang/srt/speculative/dflash_info.py create mode 100644 python/sglang/srt/speculative/dflash_utils.py create mode 100644 python/sglang/srt/speculative/dflash_worker.py create mode 100644 python/sglang/srt/speculative/triton_ops/__init__.py create mode 100644 python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py create mode 100644 test/registered/spec/dflash/test_dflash.py diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py 
b/python/sglang/srt/layers/attention/flashinfer_backend.py index 4fe8aec31301..c1e2ea4fcdab 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -596,8 +596,24 @@ def init_forward_metadata_capture_cuda_graph( fast_decode_plan, decode_wrappers[i] ) elif forward_mode.is_target_verify(): + # FlashInfer's prefill wrapper decides mask mode based on whether + # `custom_mask_buf` is initialized (not whether a custom mask is provided). + # For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a + # custom mask, so we must avoid initializing `custom_mask_buf`, otherwise + # FlashInfer will treat the (zero) buffer as a real mask and block attention. + use_custom_mask = ( + spec_info is not None + and getattr(spec_info, "custom_mask", None) is not None + ) prefill_wrappers = [] for i in range(self.num_wrappers): + wrapper_kwargs = {} + if use_custom_mask: + wrapper_kwargs = { + "custom_mask_buf": self.cuda_graph_custom_mask, + "mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1], + } + prefill_wrappers.append( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, @@ -608,8 +624,7 @@ def init_forward_metadata_capture_cuda_graph( paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], paged_kv_indices_buf=self.cuda_graph_kv_indices[i], paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], - custom_mask_buf=self.cuda_graph_custom_mask, - mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + **wrapper_kwargs, ) ) seq_lens_sum = seq_lens.sum().item() @@ -783,10 +798,14 @@ def forward_extend( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) + causal = ( + not layer.is_cross_attention + and layer.attn_type != AttentionType.ENCODER_ONLY + ) o = prefill_wrapper_paged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=not layer.is_cross_attention, + causal=causal, sm_scale=layer.scaling, # Disable 
sliding window attention for multi-item scoring: # - Sliding window could cut across item boundaries, breaking semantic coherence @@ -838,11 +857,6 @@ def forward_extend( ) else: - if not self.is_dllm_model: - # TODO: design a better interface - # For other models, use causal attention for the ragged part as previously - causal = True - o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim), diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 36c55826d821..377a7ec749fb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -276,6 +276,24 @@ def copy_to_cpu(self): self.copy_done.record() +def validate_dflash_request(req: Req) -> Optional[str]: + if req.return_logprob: + return "DFLASH speculative decoding does not support return_logprob yet." + + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None + or req.sampling_params.structural_tag is not None + ): + return ( + "DFLASH speculative decoding does not support " + "grammar-constrained decoding yet." 
+ ) + + return None + + class Scheduler( SchedulerOutputProcessorMixin, SchedulerUpdateWeightsMixin, @@ -1861,6 +1879,14 @@ def handle_generate_request( self._add_request_to_queue(req) return + if self.spec_algorithm.is_dflash(): + error_msg = validate_dflash_request(req) + if error_msg is not None: + req.set_finish_with_abort(error_msg) + self.init_req_max_new_tokens(req) + self._add_request_to_queue(req) + return + # Handle multimodal inputs if recv_req.mm_inputs is not None: image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c7c7d6b5ec0b..69cb176efbdc 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -547,18 +547,15 @@ def __init__(self, model_runner: ModelRunner): self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 - if ( - model_runner.spec_algorithm.is_eagle() - or model_runner.spec_algorithm.is_standalone() - or model_runner.spec_algorithm.is_ngram() - ): + if model_runner.spec_algorithm.is_speculative(): if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen") - else: - self.capture_forward_mode = ForwardMode.TARGET_VERIFY - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_num_draft_tokens - ) + # DFLASH draft workers reuse this runner for TARGET_VERIFY mode. 
+ if not self.model_runner.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) elif self.is_dllm: self.capture_forward_mode = ForwardMode.DLLM_EXTEND self.num_tokens_per_bs = self.dllm_config.block_size @@ -646,6 +643,18 @@ def __init__(self, model_runner: ModelRunner): and model_runner.eagle_use_aux_hidden_state ): self.model_runner.model.set_eagle3_layers_to_capture() + if ( + model_runner.spec_algorithm.is_dflash() + and model_runner.dflash_use_aux_hidden_state + ): + if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH aux hidden capture." + ) + self.model_runner.model.set_dflash_layers_to_capture( + self.model_runner.dflash_target_layer_ids + ) # Capture try: @@ -671,6 +680,7 @@ def can_run(self, forward_batch: ForwardBatch): max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or self.model_runner.spec_algorithm.is_dflash() else max(forward_batch.global_num_tokens_cpu) ) else: @@ -1007,6 +1017,12 @@ def run_once(): kwargs["pp_proxy_tensors"] = PPProxyTensors( {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and "input_embeds" in inspect.signature(forward).parameters + ): + kwargs["input_embeds"] = buffers.input_embeds[:num_tokens] logits_output_or_pp_proxy_tensors = forward( input_ids, @@ -1083,6 +1099,7 @@ def replay_prepare( max_num_tokens / self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or 
self.model_runner.spec_algorithm.is_dflash() else max_num_tokens ) index = bisect.bisect_left(self.capture_bs, max_batch_size) @@ -1104,6 +1121,13 @@ def replay_prepare( ), pp_proxy_tensors=pp_proxy_tensors, ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds) + # Padded tokens aren't read, so skip zeroing them. if self.enable_two_batch_overlap: self.tbo_plugin.replay_prepare( forward_mode=self.capture_forward_mode, @@ -1152,6 +1176,14 @@ def replay( # In speculative decoding, these two fields are still needed. self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + self.buffers.input_embeds[: self.raw_num_token].copy_( + forward_batch.input_embeds + ) # Replay if self.enable_pdmux: @@ -1164,10 +1196,18 @@ def replay( if isinstance(output, LogitsProcessorOutput): if self.is_dllm: next_token_logits = None - full_logits = output.full_logits[: self.raw_num_token] + full_logits = ( + output.full_logits[: self.raw_num_token] + if output.full_logits is not None + else None + ) else: full_logits = None - next_token_logits = output.next_token_logits[: self.raw_num_token] + next_token_logits = ( + output.next_token_logits[: self.raw_num_token] + if output.next_token_logits is not None + else None + ) return LogitsProcessorOutput( next_token_logits=next_token_logits, @@ -1209,6 +1249,32 @@ def get_spec_info(self, num_tokens: int): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.model_runner.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_verify_mask_policy, 
+ ) + + # Avoid enabling custom-mask modes during graph capture for backends that + # can express DFLASH verify via their built-in causal path. + _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) + spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + custom_mask=( + None + if (self.model_runner.is_draft_worker or not build_custom_mask) + else self.buffers.custom_mask + ), + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.model_runner.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.model_runner.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 26d8bd82a2f1..e2dadafe62c3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -354,6 +354,9 @@ def __init__( self.remote_instance_transfer_engine_weight_info = None # auxiliary hidden capture mode. TODO: expose this to server args? self.eagle_use_aux_hidden_state = False + self.dflash_use_aux_hidden_state = False + self.dflash_target_layer_ids = None + self.dflash_draft_num_layers = None if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: # load draft config draft_model_config = ModelConfig.from_server_args( @@ -379,6 +382,52 @@ def __init__( # if there is no aux layer, set to None self.eagle_aux_hidden_state_layer_ids = None + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, + ) + + # Select target layers to capture for building DFlash context features. 
+ draft_model_config = ModelConfig.from_server_args( + server_args, + model_path=(server_args.speculative_draft_model_path), + model_revision=server_args.speculative_draft_model_revision, + is_draft_model=True, + ) + dflash_draft_config = parse_dflash_draft_config( + draft_hf_config=draft_model_config.hf_config + ) + draft_num_layers = dflash_draft_config.require_num_layers() + trained_target_layers = dflash_draft_config.num_target_layers + + target_num_layers = getattr( + self.model_config.hf_text_config, "num_hidden_layers", None + ) + if target_num_layers is None: + raise ValueError( + "DFLASH requires target num_hidden_layers in config. " + f"Got target={target_num_layers}." + ) + target_num_layers = int(target_num_layers) + + if ( + trained_target_layers is not None + and trained_target_layers != target_num_layers + ): + logger.warning( + "DFLASH draft config num_target_layers=%s differs from runtime target num_hidden_layers=%s; " + "selecting capture layers based on the runtime target model.", + trained_target_layers, + target_num_layers, + ) + + self.dflash_use_aux_hidden_state = True + self.dflash_draft_num_layers = int(draft_num_layers) + self.dflash_target_layer_ids = dflash_draft_config.resolve_target_layer_ids( + target_num_layers=int(target_num_layers), + draft_num_layers=int(draft_num_layers), + ) + # Apply the rank zero filter to logger if server_args.show_time_cost: enable_show_time_cost() @@ -670,6 +719,14 @@ def initialize(self, pre_model_load_memory: float): self.eagle_aux_hidden_state_layer_ids ) + if self.dflash_use_aux_hidden_state: + if not hasattr(self.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH." 
+ ) + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) + # Initialize piecewise CUDA graph self.init_piecewise_cuda_graphs() @@ -2100,11 +2157,7 @@ def _should_run_flashinfer_autotune(self) -> bool: if major < 9: return False - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.spec_algorithm.is_speculative(): return not self.is_draft_worker return True @@ -2134,16 +2187,12 @@ def _dummy_run(self, batch_size: int, run_ctx=None): capture_forward_mode = ForwardMode.EXTEND capture_hidden_mode = CaptureHiddenMode.NULL num_tokens_per_bs = 1 - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.spec_algorithm.is_speculative(): if self.is_draft_worker: - raise RuntimeError("This should not happen") - else: - capture_forward_mode = ForwardMode.TARGET_VERIFY - num_tokens_per_bs = self.server_args.speculative_num_draft_tokens + if not self.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + capture_forward_mode = ForwardMode.TARGET_VERIFY + num_tokens_per_bs = self.server_args.speculative_num_draft_tokens if self.server_args.enable_return_hidden_states: capture_hidden_mode = CaptureHiddenMode.FULL @@ -2173,6 +2222,8 @@ def _dummy_run(self, batch_size: int, run_ctx=None): if self.eagle_use_aux_hidden_state: self.model.set_eagle3_layers_to_capture() + if self.dflash_use_aux_hidden_state: + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) require_mlp_tp_gather_ = require_mlp_tp_gather(self.server_args) if require_gathered_buffer(self.server_args): @@ -2286,6 +2337,21 @@ def get_spec_info(): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + + # Dummy warmup only needs shape metadata; avoid forcing custom-mask mode. 
+ spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.server_args.speculative_num_draft_tokens, + custom_mask=None, + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index a6baa4817ace..bca2baca64f9 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -167,6 +167,22 @@ def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int): num_layers = self.num_effective_layers cell_size = self.get_cell_size_per_token(num_layers) + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + scale_kv_cell_size_per_token_for_dflash, + ) + + draft_num_layers = getattr(self, "dflash_draft_num_layers", None) + if ( + draft_num_layers is not None + and int(draft_num_layers) > 0 + and int(num_layers) > 0 + ): + cell_size = scale_kv_cell_size_per_token_for_dflash( + target_cell_size_per_token=cell_size, + target_num_layers=int(num_layers), + draft_num_layers=int(draft_num_layers), + ) rest_memory = post_model_load_memory - pre_model_load_memory * ( 1 - self.mem_fraction_static diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py new file mode 100644 index 000000000000..27f5cdbf539d --- /dev/null +++ b/python/sglang/srt/models/dflash.py @@ -0,0 +1,399 @@ +# Adapted from the DFlash reference implementation (HF) but implemented with +# SGLang primitives (RadixAttention + SGLang KV cache). This model intentionally +# does not include token embeddings or an LM head; DFlash uses the target model's +# embedding/lm_head. 
+ +from __future__ import annotations + +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.radix_attention import AttentionType, RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.utils import apply_qk_norm +from sglang.srt.speculative.dflash_utils import ( + can_dflash_slice_qkv_weight, + parse_dflash_draft_config, +) + +logger = logging.getLogger(__name__) + + +class DFlashAttention(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + tp_size = int(get_tensor_model_parallel_world_size()) + total_num_heads = int(config.num_attention_heads) + total_num_kv_heads = int( + getattr(config, "num_key_value_heads", total_num_heads) + ) + head_dim = int(getattr(config, "head_dim", hidden_size // total_num_heads)) + + self.hidden_size = hidden_size + self.total_num_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + assert self.total_num_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_heads divisible by tp_size. " + f"total_num_heads={self.total_num_heads}, tp_size={tp_size}." + ) + self.num_heads = self.total_num_heads // tp_size + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_kv_heads divisible by tp_size when >= tp_size. 
" + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + else: + assert tp_size % self.total_num_kv_heads == 0, ( + f"DFlashAttention requires tp_size divisible by total_num_kv_heads when total_num_kv_heads < tp_size. " + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * head_dim + self.kv_size = self.num_kv_heads * head_dim + + attention_bias = bool(getattr(config, "attention_bias", False)) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=attention_bias, + prefix="qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * head_dim, + hidden_size, + bias=attention_bias, + prefix="o_proj", + ) + + # Per-head Q/K RMSNorm, matching HF Qwen3. + self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) + + rope_theta = float(getattr(config, "rope_theta", 1000000)) + rope_scaling = getattr(config, "rope_scaling", None) + rope_is_neox_style = bool( + getattr( + config, "rope_is_neox_style", getattr(config, "is_neox_style", True) + ) + ) + max_position_embeddings = int(getattr(config, "max_position_embeddings", 32768)) + self.rotary_emb = get_rope( + head_dim, + rotary_dim=head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + + self.scaling = head_dim**-0.5 + # DFlash uses non-causal attention over the draft block. 
+ self.attn = RadixAttention( + num_heads=self.num_heads, + head_dim=head_dim, + scaling=self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + attn_type=AttentionType.ENCODER_ONLY, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = apply_qk_norm(q, k, self.q_norm, self.k_norm, self.head_dim) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + def kv_proj_only( + self, hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Project hidden_states to K/V only (skip Q). + + This is used by DFlash to materialize ctx tokens into the draft KV cache: + we only need K/V for the cached tokens; Q is never consumed. + """ + # Fast path for unquantized weights: slice the fused QKV weight and run one GEMM. + can_slice_qkv_weight, _ = can_dflash_slice_qkv_weight(self.qkv_proj) + if can_slice_qkv_weight: + kv_slice = slice(self.q_size, self.q_size + 2 * self.kv_size) + weight = self.qkv_proj.weight[kv_slice] + bias = ( + self.qkv_proj.bias[kv_slice] if self.qkv_proj.bias is not None else None + ) + kv = F.linear(hidden_states, weight, bias) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + return k, v + + # Fallback: compute full QKV and discard Q (keeps compatibility with quantized weights). + qkv, _ = self.qkv_proj(hidden_states) + _, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + return k, v + + def apply_k_norm(self, k: torch.Tensor) -> torch.Tensor: + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + return k_by_head.view_as(k) + + def apply_k_rope(self, positions: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + # Use a minimal dummy query (1 head) to avoid doing full-Q work. 
+ dummy_q = k.new_empty((k.shape[0], self.head_dim)) + _, k = self.rotary_emb(positions, dummy_q, k) + return k + + +class DFlashMLP(nn.Module): + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + hidden_size = int(config.hidden_size) + intermediate_size = int(getattr(config, "intermediate_size", 0)) + if intermediate_size <= 0: + raise ValueError( + f"Invalid intermediate_size={intermediate_size} for DFlash MLP." + ) + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix="gate_up_proj" if not prefix else f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix="down_proj" if not prefix else f"{prefix}.down_proj", + ) + hidden_act = getattr(config, "hidden_act", "silu") + if hidden_act != "silu": + raise ValueError( + f"Unsupported DFlash activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DFlashDecoderLayer(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = DFlashAttention(config=config, layer_id=layer_id) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = DFlashMLP(config=config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.numel() == 0: + # Keep return types consistent for upstream callers. 
+ if residual is None: + residual = hidden_states + return hidden_states, residual + + # Pre-norm attention with fused residual+norm when possible (Qwen3-style). + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_out = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states, residual = self.post_attention_layernorm(attn_out, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DFlashDraftModel(nn.Module): + """SGLang DFlash draft model (no embedding / lm_head weights). + + The checkpoint provides: + - transformer weights for `layers.*` + - `fc.weight`, `hidden_norm.weight` for projecting target context features + - `norm.weight` for final normalization + """ + + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + self.config = config + + hidden_size = int(config.hidden_size) + num_layers = int(config.num_hidden_layers) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config=config, layer_id=i) for i in range(num_layers)] + ) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + # Project per-token target context features: + # concat(K * hidden_size) -> hidden_size, where K is the number of target-layer + # feature tensors concatenated per token (not necessarily equal to num_layers). 
+ draft_config = parse_dflash_draft_config(draft_hf_config=config) + target_num_layers = ( + int(draft_config.num_target_layers) + if draft_config.num_target_layers is not None + else num_layers + ) + target_layer_ids = draft_config.resolve_target_layer_ids( + target_num_layers=target_num_layers, draft_num_layers=num_layers + ) + num_context_features = len(target_layer_ids) + + self.num_context_features = int(num_context_features) + self.fc = nn.Linear( + self.num_context_features * hidden_size, hidden_size, bias=False + ) + self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + self.block_size = draft_config.resolve_block_size(default=16) + + def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: + """Project concatenated target-layer hidden states into draft hidden_size.""" + expected = int(self.fc.in_features) + if target_hidden.ndim != 2 or int(target_hidden.shape[-1]) != expected: + raise ValueError( + "DFLASH target_hidden feature dim mismatch. " + f"Expected shape [N, {expected}] " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got shape={tuple(target_hidden.shape)}. " + "This usually means the target model is capturing a different number of layer features than " + "the draft checkpoint/config expects." + ) + return self.hidden_norm(self.fc(target_hidden)) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + pp_proxy_tensors=None, + ) -> LogitsProcessorOutput: + if input_embeds is None: + raise ValueError( + "DFlashDraftModel requires `input_embeds` (use the target embedding)." 
+ ) + hidden_states = input_embeds + residual: Optional[torch.Tensor] = None + + for layer in self.layers: + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + + if hidden_states.numel() != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + return LogitsProcessorOutput( + next_token_logits=None, + hidden_states=hidden_states, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + def resolve_param_name(name: str) -> Optional[str]: + if name in params_dict: + return name + if name.startswith("model."): + stripped_name = name[len("model.") :] + if stripped_name in params_dict: + return stripped_name + else: + prefixed_name = f"model.{name}" + if prefixed_name in params_dict: + return prefixed_name + return None + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if f".{weight_name}." not in name: + continue + mapped_name = name.replace(weight_name, param_name) + resolved_name = resolve_param_name(mapped_name) + if resolved_name is None: + continue + param = params_dict[resolved_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + resolved_name = resolve_param_name(name) + if resolved_name is None: + # Ignore unexpected weights (e.g., HF rotary caches). + continue + param = params_dict[resolved_name] + if resolved_name.endswith("fc.weight") and tuple( + loaded_weight.shape + ) != tuple(param.shape): + raise ValueError( + "DFLASH fc.weight shape mismatch. 
This usually means the draft checkpoint's " + "number of context features (K) does not match this config. " + f"Expected fc.weight.shape={tuple(param.shape)} " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got {tuple(loaded_weight.shape)} for weight '{name}'." + ) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = DFlashDraftModel diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index f955ac750d34..b8ad74015c6e 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -794,6 +794,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." 
+ ) + + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 86de1b32713c..9ec274d2a75c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -499,6 +499,8 @@ class ServerArgs: speculative_num_steps: Optional[int] = None speculative_eagle_topk: Optional[int] = None speculative_num_draft_tokens: Optional[int] = None + speculative_dflash_block_size: Optional[int] = None + speculative_dflash_draft_window_size: Optional[int] = None speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None @@ -3027,6 +3029,134 @@ def _handle_speculative_decoding(self): if self.speculative_algorithm == "NEXTN": self.speculative_algorithm = "EAGLE" + if self.speculative_algorithm == "DFLASH": + if self.enable_dp_attention: + raise ValueError( + "Currently DFLASH speculative decoding does not support dp attention." + ) + + if self.pp_size != 1: + raise ValueError( + "Currently DFLASH speculative decoding only supports pp_size == 1." + ) + + if self.speculative_draft_model_path is None: + raise ValueError( + "DFLASH speculative decoding requires setting --speculative-draft-model-path." + ) + + # DFLASH does not use EAGLE-style `num_steps`/`topk`, but those fields still + # affect generic scheduler/KV-cache accounting (buffer sizing, KV freeing, + # RoPE reservation). Force them to 1 to avoid surprising memory behavior. + # + # For DFlash, the natural unit is `block_size` (verify window length). 
+ if self.speculative_num_steps is None: + self.speculative_num_steps = 1 + elif int(self.speculative_num_steps) != 1: + logger.warning( + "DFLASH only supports speculative_num_steps == 1; overriding speculative_num_steps=%s to 1.", + self.speculative_num_steps, + ) + self.speculative_num_steps = 1 + + if self.speculative_eagle_topk is None: + self.speculative_eagle_topk = 1 + elif int(self.speculative_eagle_topk) != 1: + logger.warning( + "DFLASH only supports speculative_eagle_topk == 1; overriding speculative_eagle_topk=%s to 1.", + self.speculative_eagle_topk, + ) + self.speculative_eagle_topk = 1 + + if self.speculative_dflash_block_size is not None: + if int(self.speculative_dflash_block_size) <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-block-size to be positive, " + f"got {self.speculative_dflash_block_size}." + ) + if self.speculative_num_draft_tokens is not None and int( + self.speculative_num_draft_tokens + ) != int(self.speculative_dflash_block_size): + raise ValueError( + "Both --speculative-num-draft-tokens and --speculative-dflash-block-size are set " + "but they differ. For DFLASH they must match. " + f"speculative_num_draft_tokens={self.speculative_num_draft_tokens}, " + f"speculative_dflash_block_size={self.speculative_dflash_block_size}." + ) + self.speculative_num_draft_tokens = int( + self.speculative_dflash_block_size + ) + + window_size = None + if self.speculative_dflash_draft_window_size is not None: + window_size = int(self.speculative_dflash_draft_window_size) + if window_size <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-draft-window-size " + f"to be positive, got {window_size}." 
+ ) + self.speculative_dflash_draft_window_size = window_size + + if self.speculative_num_draft_tokens is None: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, + ) + + model_override_args = json.loads(self.json_model_override_args) + inferred_block_size = None + try: + from sglang.srt.utils.hf_transformers_utils import get_config + + draft_hf_config = get_config( + self.speculative_draft_model_path, + trust_remote_code=self.trust_remote_code, + revision=self.speculative_draft_model_revision, + model_override_args=model_override_args, + ) + inferred_block_size = parse_dflash_draft_config( + draft_hf_config=draft_hf_config + ).resolve_block_size(default=None) + except Exception as e: + logger.warning( + "Failed to infer DFLASH block_size from draft model config; " + "defaulting speculative_num_draft_tokens to 16. Error: %s", + e, + ) + + if inferred_block_size is None: + inferred_block_size = 16 + logger.warning( + "speculative_num_draft_tokens is not set; defaulting to %d for DFLASH.", + inferred_block_size, + ) + self.speculative_num_draft_tokens = inferred_block_size + + if window_size is not None: + draft_tokens = int(self.speculative_num_draft_tokens) + if window_size < draft_tokens: + raise ValueError( + "DFLASH --speculative-dflash-draft-window-size must be >= " + "--speculative-num-draft-tokens (block_size). " + f"window_size={window_size}, block_size={draft_tokens}." + ) + + if self.max_running_requests is None: + self.max_running_requests = 48 + logger.warning( + "Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests." + ) + + self.disable_overlap_schedule = True + logger.warning( + "Overlap scheduler is disabled when using DFLASH speculative decoding (spec v2 is not supported yet)." 
+ ) + + if self.enable_mixed_chunk: + self.enable_mixed_chunk = False + logger.warning( + "Mixed chunked prefill is disabled because of using dflash speculative decoding." + ) + if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"): if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention: # TODO: support dp attention for standalone speculative decoding @@ -4832,7 +4962,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], + choices=["DFLASH", "EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], help="Speculative algorithm.", ) parser.add_argument( @@ -4876,6 +5006,21 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The number of tokens sampled from the draft model in Speculative Decoding.", default=ServerArgs.speculative_num_draft_tokens, ) + parser.add_argument( + "--speculative-dflash-block-size", + type=int, + help="DFLASH only. Block size (verify window length). Alias of --speculative-num-draft-tokens for DFLASH.", + default=ServerArgs.speculative_dflash_block_size, + ) + parser.add_argument( + "--speculative-dflash-draft-window-size", + type=int, + help="DFLASH only. Sliding window size for the draft-model KV cache. " + "When set, the draft worker keeps a recent target-token window in its " + "local cache (paged backends may retain up to one extra page on the left " + "for alignment). 
Default is full context.", + default=ServerArgs.speculative_dflash_draft_window_size, + ) parser.add_argument( "--speculative-accept-threshold-single", type=float, diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py new file mode 100644 index 000000000000..fbb06cc70ee1 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_info.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Tuple + +import torch + +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import apply_custom_logit_processor +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.speculative.dflash_utils import ( + compute_dflash_accept_len_and_bonus, + compute_dflash_sampling_accept_len_and_bonus, + is_dflash_sampling_verify_available, +) +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func + + +def _compute_paged_keep_slots( + *, + prefix_lens: torch.Tensor, + commit_lens: torch.Tensor, + draft_token_num: int, + page_size: int, +) -> torch.Tensor: + """Compute how many draft slots per request must remain allocated. + + The allocator frees at page granularity for paged mode, so we can only release + full pages from the tail after verify. 
+ """ + + if page_size <= 1: + raise ValueError(f"Expected page_size > 1, got {page_size}.") + + seq_dtype = prefix_lens.dtype + extended_lens = prefix_lens + int(draft_token_num) + new_lens = prefix_lens + commit_lens.to(seq_dtype) + aligned_new_lens = ((new_lens + page_size - 1) // page_size) * page_size + keep_lens = torch.minimum(aligned_new_lens, extended_lens) + keep_slots = (keep_lens - prefix_lens).to(torch.int64) + keep_slots.clamp_(min=0, max=int(draft_token_num)) + return keep_slots + + +@dataclass +class DFlashDraftInput(SpecInput): + """Per-batch DFlash draft state for spec-v1 (non-overlap) scheduling. + + This object is stored on `ScheduleBatch.spec_info` between decode iterations. + It is NOT sent to model attention backends; the DFlash worker uses it to run + the draft model and to track draft-side cache progress. + + When draft windowing is disabled, `draft_seq_lens` matches the committed target + prefix length already materialized in the draft KV cache. When windowing is + enabled, `draft_seq_lens` is the logical resident length in the draft worker's + compact req-to-token mapping. In paged mode this may exceed the requested + window by up to `page_size - 1` so the local page table remains valid. `ctx_lens` + tracks newly committed target tokens that still need draft KV materialization. + """ + + # Current token to start the next DFlash block (one per request). + verified_id: torch.Tensor + + # Flattened context features for tokens that need to be appended into the draft cache. + # Shape: [sum(ctx_lens), K * hidden_size], where K is the number of target-layer + # hidden-state features concatenated per token (len(dflash_config.target_layer_ids), + # or default K == draft_num_layers for existing checkpoints). + target_hidden: torch.Tensor + + # Context lengths per request, used to slice `target_hidden`. Device tensor (int32). + ctx_lens: torch.Tensor + + # How many committed tokens are visible to the draft worker per request. 
+ draft_seq_lens: torch.Tensor + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + # Draft state does not change token accounting. + return (1, 1) + + def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): + old_ctx_lens = self.ctx_lens + old_target_hidden = self.target_hidden + + self.verified_id = self.verified_id[new_indices] + self.ctx_lens = old_ctx_lens[new_indices] + self.draft_seq_lens = self.draft_seq_lens[new_indices] + + if old_target_hidden is None or old_target_hidden.numel() == 0: + self.target_hidden = old_target_hidden + return + + # Rebuild target_hidden for the filtered batch using vectorized indexing. + old_bs = int(old_ctx_lens.shape[0]) + offsets = torch.zeros( + (old_bs + 1,), dtype=torch.int64, device=old_ctx_lens.device + ) + offsets[1:].copy_(old_ctx_lens.to(torch.int64).cumsum(0)) + + start = offsets[:-1] + seg_start = start[new_indices] + seg_lens = old_ctx_lens[new_indices].to(torch.int64) + + max_len = int(seg_lens.max().item()) if seg_lens.numel() > 0 else 0 + if max_len <= 0: + self.target_hidden = old_target_hidden[:0] + return + + r = torch.arange(max_len, device=old_ctx_lens.device, dtype=torch.int64)[ + None, : + ] + pos2d = seg_start[:, None] + r + mask = r < seg_lens[:, None] + flat_pos = pos2d[mask] + self.target_hidden = ( + old_target_hidden.index_select(0, flat_pos) + if flat_pos.numel() > 0 + else old_target_hidden[:0] + ) + + def merge_batch(self, spec_info: "DFlashDraftInput"): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) + self.ctx_lens = torch.cat([self.ctx_lens, spec_info.ctx_lens], dim=0) + self.draft_seq_lens = torch.cat( + [self.draft_seq_lens, spec_info.draft_seq_lens], dim=0 + ) + if self.target_hidden is None or self.target_hidden.numel() == 0: + self.target_hidden = spec_info.target_hidden + elif ( + spec_info.target_hidden is not None and 
spec_info.target_hidden.numel() > 0 + ): + self.target_hidden = torch.cat( + [self.target_hidden, spec_info.target_hidden], dim=0 + ) + + +@dataclass +class DFlashVerifyInput(SpecInput): + """Inputs for a target-model verify forward in DFlash (spec-v1). + + The verify forward is run with `ForwardMode.TARGET_VERIFY` so that the target + model returns logits for all tokens in the block, enabling accept-length + computation. + """ + + draft_token: torch.Tensor + positions: torch.Tensor + draft_token_num: int + # Kept for compatibility with attention backends that gate tree metadata by `topk > 1`. + # DFLASH verify is linear (non-tree), so this is always 1. + topk: int = 1 + # Custom attention "allow mask" for TARGET_VERIFY in backends that require it (e.g. triton). + # Semantics follow SGLang speculative conventions: True means the (q, k) pair is allowed. + custom_mask: torch.Tensor | None = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + + # Shape info for padding (e.g., DP attention / CUDA graph). 
+ num_tokens_per_batch: int = -1 + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_VERIFY) + if self.num_tokens_per_batch == -1: + self.num_tokens_per_batch = int(self.draft_token_num) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.draft_token_num, self.draft_token_num + + def prepare_for_verify( + self, + batch: ScheduleBatch, + page_size: int, + *, + build_custom_mask: bool = True, + ): + if batch.forward_mode.is_idle(): + return + + batch.input_ids = self.draft_token + + if page_size == 1: + batch.out_cache_loc = alloc_token_slots( + batch.tree_cache, len(batch.input_ids) + ) + end_offset = batch.seq_lens + self.draft_token_num + else: + prefix_lens = batch.seq_lens + prefix_lens_cpu = batch.seq_lens_cpu + end_offset = prefix_lens + self.draft_token_num + end_offset_cpu = prefix_lens_cpu + self.draft_token_num + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + prefix_lens, + ) + batch.out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, + prefix_lens, + prefix_lens_cpu, + end_offset, + end_offset_cpu, + last_loc, + len(batch.input_ids), + ) + self.last_loc = last_loc + + bs = batch.batch_size() + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + if not build_custom_mask: + self.custom_mask = None + return + + if self.draft_token_num <= 0: + raise ValueError( + f"DFLASH draft_token_num must be positive, got {self.draft_token_num}." 
+ ) + mask_chunks: List[torch.Tensor] = [] + q_len = int(self.draft_token_num) + q_idx = torch.arange(q_len, device=batch.device, dtype=torch.int32).unsqueeze(1) + for prefix_len in batch.seq_lens_cpu.tolist(): + prefix_len_i = int(prefix_len) + kv_len = prefix_len_i + q_len + k_idx = torch.arange( + kv_len, device=batch.device, dtype=torch.int32 + ).unsqueeze(0) + # Allow attending to the full prefix and to tokens up to (and including) the + # current query position within the verify block (standard causal masking). + allow = k_idx <= (prefix_len_i + q_idx) + mask_chunks.append(allow.flatten()) + self.custom_mask = ( + torch.cat(mask_chunks, dim=0) + if mask_chunks + else torch.empty((0,), dtype=torch.bool, device=batch.device) + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + req_to_token: torch.Tensor, + ): + device = req_pool_indices.device + bs = len(req_pool_indices) + + qo_indptr = torch.arange( + 0, + (bs + 1) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device=device, + ) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty( + paged_kernel_lens_sum + self.draft_token_num * bs, + dtype=torch.int32, + device=device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + mask = self.custom_mask + if mask is not None: + mask_numel = ( + paged_kernel_lens_sum * self.draft_token_num + + (self.draft_token_num**2) * bs + ) + if mask.numel() < mask_numel: + # FIXME(attn): temporary fix for custom mask padding with cuda graph + mask = torch.cat( + [ + mask, + torch.full( + (mask_numel - mask.numel(),), + True, + dtype=torch.bool, + 
device=device, + ), + ], + dim=0, + ) + self.custom_mask = mask + return kv_indices, cum_kv_seq_len, qo_indptr, mask + + def verify( + self, + *, + batch: ScheduleBatch, + logits_output: LogitsProcessorOutput, + page_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """DFlash verification for greedy and non-greedy sampling. + + Returns: + new_verified_id: int64 tensor [bs] (the new current token per request) + commit_lens: int32 tensor [bs] (how many verify-input tokens are committed) + next_target_hidden: tensor [sum(commit_lens), feature_dim] + accept_length_per_req_cpu: list[int] (accepted draft tokens per request) + """ + if batch.forward_mode.is_idle(): + empty = torch.empty((0,), dtype=torch.int64, device=batch.device) + return empty, empty.to(torch.int32), empty, [] + + bs = batch.batch_size() + device = logits_output.next_token_logits.device + + sampling_info = batch.sampling_info + if sampling_info is not None: + if len(sampling_info) != bs: + raise RuntimeError( + "DFLASH verify sampling_info size mismatch: " + f"len(sampling_info)={len(sampling_info)}, bs={bs}." + ) + + # Keep speculative verify semantics consistent with normal sampling path. 
+ if sampling_info.has_custom_logit_processor: + apply_custom_logit_processor( + logits_output.next_token_logits, + sampling_info, + num_tokens_in_batch=self.draft_token_num, + ) + + if ( + sampling_info.penalizer_orchestrator.is_required + or sampling_info.logit_bias is not None + ): + linear_penalty = torch.zeros( + (bs, logits_output.next_token_logits.shape[1]), + dtype=torch.float32, + device=device, + ) + sampling_info.apply_logits_bias(linear_penalty) + logits_output.next_token_logits.add_( + torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) + ) + + candidates = self.draft_token.view(bs, self.draft_token_num) + if ( + sampling_info is not None + and not sampling_info.is_all_greedy + and is_dflash_sampling_verify_available() + ): + accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( + candidates=candidates, + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + ) + else: + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( + bs, self.draft_token_num + ) + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + + # Single D2H transfer: candidates[1:] + accept_len + bonus + packed = torch.cat( + [candidates[:, 1:], accept_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1 + ).cpu() + + max_acc = self.draft_token_num - 1 + accept_length_per_req_cpu: List[int] = [] + commit_lens_cpu: List[int] = [] + new_verified_list: List[int] = [] + + for i, req in enumerate(batch.reqs): + acc_len = int(packed[i, max_acc].item()) + proposed = packed[i, :acc_len].tolist() + [ + int(packed[i, max_acc + 1].item()) + ] + + appended = 0 + for token_id in proposed: + token_id = int(token_id) + req.output_ids.append(token_id) + appended += 1 + req.check_finished() + if req.finished(): + break + if req.grammar is not None: + req.grammar.accept_token(token_id) + + if req.output_ids: + new_verified_token = int(req.output_ids[-1]) + elif 
req.origin_input_ids: + # If no token was appended in this verify step, keep the current token unchanged. + new_verified_token = int(req.origin_input_ids[-1]) + else: + raise RuntimeError( + "DFLASH verify cannot determine current token: both output_ids and origin_input_ids are empty." + ) + + commit_lens_cpu.append(appended) + new_verified_list.append(new_verified_token) + accept_length_per_req_cpu.append(max(0, appended - 1)) + req.spec_verify_ct += 1 + req.spec_accepted_tokens += accept_length_per_req_cpu[-1] + + commit_lens = torch.tensor(commit_lens_cpu, dtype=torch.int32, device=device) + new_verified_id = torch.tensor( + new_verified_list, dtype=torch.int64, device=device + ) + + # Free uncommitted KV cache slots and compact out_cache_loc. + if page_size == 1: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + keep_mask = ( + torch.arange(self.draft_token_num, device=device)[None, :] + < commit_lens[:, None] + ) + batch.token_to_kv_pool_allocator.free(out_cache_loc[~keep_mask]) + batch.out_cache_loc = out_cache_loc[keep_mask] + else: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + row_offsets = torch.arange(self.draft_token_num, device=device)[None, :] + keep_slots = _compute_paged_keep_slots( + prefix_lens=batch.seq_lens, + commit_lens=commit_lens, + draft_token_num=self.draft_token_num, + page_size=page_size, + ) + free_mask = row_offsets >= keep_slots[:, None] + batch.token_to_kv_pool_allocator.free(out_cache_loc[free_mask]) + + keep_mask = row_offsets < commit_lens[:, None] + batch.out_cache_loc = out_cache_loc[keep_mask] + + # Update req-level KV cache accounting. + for req, commit_len in zip(batch.reqs, commit_lens_cpu, strict=True): + req.kv_committed_len += commit_len + req.kv_allocated_len = req.kv_committed_len + + # Update req_to_token pool mapping for newly committed tokens. 
+ end_offset = batch.seq_lens + commit_lens.to(batch.seq_lens.dtype) + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + # Update batch seq lens. + batch.seq_lens.add_(commit_lens.to(batch.seq_lens.dtype)) + batch.seq_lens_cpu.add_( + torch.tensor(commit_lens_cpu, dtype=batch.seq_lens_cpu.dtype) + ) + # Keep seq_lens_sum in sync; flashinfer indices updaters rely on this for buffer sizing. + batch.seq_lens_sum += sum(commit_lens_cpu) + + # Build next-step context features from the committed verify-input tokens. + hidden = logits_output.hidden_states + if hidden is None: + raise RuntimeError( + "DFLASH verify requires target hidden states, but got None." + ) + hidden = hidden.view(bs, self.draft_token_num, -1) + segments: List[torch.Tensor] = [] + for i, ln in enumerate(commit_lens_cpu): + if ln > 0: + segments.append(hidden[i, :ln, :]) + next_target_hidden = torch.cat(segments, dim=0) if segments else hidden[:0] + + # Avoid confusing downstream consumers (spec-v1 decode doesn't use this). 
+ logits_output.hidden_states = None + + return ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py new file mode 100644 index 000000000000..ddec049e0a24 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -0,0 +1,637 @@ +from __future__ import annotations + +from dataclasses import dataclass +from numbers import Integral +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.utils import is_cuda + +DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" + +_DFLASH_SAMPLING_VERIFY_AVAILABLE = False +_DFLASH_CHAIN_VERIFY_BUFFERS: dict[tuple[Optional[int], int], dict[str, Any]] = {} +_DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS = frozenset( + { + "FlashInferAttnBackend", + "FlashInferMLAAttnBackend", + "FlashAttentionBackend", + "TRTLLMHAAttnBackend", + "TRTLLMMLABackend", + } +) + + +if is_cuda(): + try: + from sgl_kernel import ( + top_k_renorm_prob, + top_p_renorm_prob, + tree_speculative_sampling_target_only, + ) + + _DFLASH_SAMPLING_VERIFY_AVAILABLE = True + except Exception: + top_k_renorm_prob = None + top_p_renorm_prob = None + tree_speculative_sampling_target_only = None +else: + top_k_renorm_prob = None + top_p_renorm_prob = None + tree_speculative_sampling_target_only = None + + +def is_dflash_sampling_verify_available() -> bool: + return _DFLASH_SAMPLING_VERIFY_AVAILABLE + + +def scale_kv_cell_size_per_token_for_dflash( + *, + target_cell_size_per_token: int, + target_num_layers: int, + draft_num_layers: int, + draft_cell_size_per_token: Optional[int] = None, +) -> int: + """Compute bytes/token budget for combined target+draft KV pools (DFLASH). + + DFLASH runs a separate draft runner with its own KV pool. The target runner's + token capacity must fit both pools in aggregate. 
+ + Returns: + Approximate per-token bytes for (target KV + draft KV), expressed as a + scaled version of `target_cell_size_per_token`, unless an explicit + `draft_cell_size_per_token` is provided (in which case we sum them). + """ + if target_cell_size_per_token <= 0: + raise ValueError( + "target_cell_size_per_token must be positive, " + f"got {target_cell_size_per_token}." + ) + + if draft_cell_size_per_token is not None: + draft_cell_size_per_token = int(draft_cell_size_per_token) + if draft_cell_size_per_token <= 0: + raise ValueError( + "draft_cell_size_per_token must be positive when provided, " + f"got {draft_cell_size_per_token}." + ) + return int(target_cell_size_per_token) + int(draft_cell_size_per_token) + + if target_num_layers <= 0 or draft_num_layers <= 0: + return int(target_cell_size_per_token) + + total_layers = int(target_num_layers) + int(draft_num_layers) + return ( + int(target_cell_size_per_token) * int(total_layers) + int(target_num_layers) - 1 + ) // int(target_num_layers) + + +def resolve_dflash_verify_mask_policy(attn_backend: Any) -> tuple[str, bool]: + backend = attn_backend + for _ in range(4): + full_backend = getattr(backend, "full_attn_backend", None) + if full_backend is None: + break + backend = full_backend + backend_name = type(backend).__name__ + return backend_name, (backend_name not in _DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS) + + +def _get_or_create_chain_verify_buffers( + *, + bs: int, + draft_token_num: int, + device: torch.device, +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: + key = (device.index, int(draft_token_num)) + cached = _DFLASH_CHAIN_VERIFY_BUFFERS.get(key) + cap_bs = 0 if cached is None else int(cached["cap_bs"]) + if cap_bs < bs: + new_cap = max(int(bs), cap_bs * 2 if cap_bs > 0 else int(bs)) + retrieve_index = torch.arange( + new_cap * draft_token_num, dtype=torch.int64, device=device + ).view(new_cap, draft_token_num) + row_next = torch.arange( + 1, 
draft_token_num + 1, dtype=torch.int64, device=device + ) + row_next[-1] = -1 + retrieve_next_token = row_next.unsqueeze(0).expand(new_cap, -1).clone() + retrieve_next_sibling = torch.full( + (new_cap, draft_token_num), -1, dtype=torch.int64, device=device + ) + predicts = torch.empty( + (new_cap * draft_token_num,), dtype=torch.int32, device=device + ) + accept_index = torch.empty( + (new_cap, draft_token_num), dtype=torch.int32, device=device + ) + accept_token_num = torch.empty((new_cap,), dtype=torch.int32, device=device) + cached = { + "cap_bs": int(new_cap), + "retrieve_index": retrieve_index, + "retrieve_next_token": retrieve_next_token, + "retrieve_next_sibling": retrieve_next_sibling, + "predicts": predicts, + "accept_index": accept_index, + "accept_token_num": accept_token_num, + } + _DFLASH_CHAIN_VERIFY_BUFFERS[key] = cached + + assert cached is not None + retrieve_index = cached["retrieve_index"][:bs] + retrieve_next_token = cached["retrieve_next_token"][:bs] + retrieve_next_sibling = cached["retrieve_next_sibling"][:bs] + predicts = cached["predicts"][: bs * draft_token_num] + accept_index = cached["accept_index"][:bs] + accept_token_num = cached["accept_token_num"][:bs] + return ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: + """Select target layer indices used to build DFlash context features. + + Args: + num_target_layers: Number of transformer layers in the runtime target model. + num_draft_layers: Number of layers in the DFlash draft model. + + Returns: + A list of 0-based target layer indices of length `num_draft_layers`. + + Notes: + - DFlash uses hidden states after each selected target layer (HF-style). + - SGLang captures "before layer i", so the model hook will typically add +1 + when mapping to capture points. 
+ """ + if num_target_layers <= 0: + raise ValueError( + f"num_target_layers must be positive, got {num_target_layers}." + ) + if num_draft_layers <= 0: + raise ValueError(f"num_draft_layers must be positive, got {num_draft_layers}.") + + if num_draft_layers == 1: + return [num_target_layers // 2] + + start = 1 + end = num_target_layers - 3 + if end < start: + raise ValueError( + "DFlash layer selection requires num_target_layers >= 4. " + f"Got num_target_layers={num_target_layers}." + ) + + span = end - start + return [ + int(round(start + (i * span) / (num_draft_layers - 1))) + for i in range(num_draft_layers) + ] + + +def _cfg_get(config: Any, key: str, default: Any = None) -> Any: + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + +def _get_text_config(config: Any) -> Any: + if config is None: + return None + if isinstance(config, dict): + return config.get("text_config", config) + text_config = getattr(config, "text_config", None) + if text_config is not None: + return text_config + get_text_config = getattr(config, "get_text_config", None) + if callable(get_text_config): + try: + resolved = get_text_config() + if resolved is not None: + return resolved + except TypeError: + pass + return config + + +def _get_dflash_config(config: Any) -> dict: + if isinstance(config, dict): + cfg = config.get("dflash_config", None) + else: + cfg = getattr(config, "dflash_config", None) + if cfg is None: + return {} + if isinstance(cfg, dict): + return cfg + + try: + return dict(cfg) + except Exception: + return {} + + +def _parse_optional_int( + value: Any, + *, + field_name: str, + min_value: Optional[int] = None, +) -> Optional[int]: + if value is None: + return None + try: + parsed = int(value) + except Exception as e: + raise ValueError(f"Invalid {field_name}={value!r}.") from e + if min_value is not None and parsed < int(min_value): + comparator = "positive" if int(min_value) == 1 else f">= {int(min_value)}" + 
raise ValueError(f"{field_name} must be {comparator}, got {parsed}.") + return parsed + + +@dataclass(frozen=True) +class DFlashDraftConfig: + num_hidden_layers: Optional[int] + num_target_layers: Optional[int] + block_size: Optional[int] + target_layer_ids: Optional[List[int]] + mask_token: str + mask_token_id: Optional[int] + + def require_num_layers(self) -> int: + if self.num_hidden_layers is None: + raise ValueError( + "DFLASH requires draft num_hidden_layers in config. " + "Got config without num_hidden_layers." + ) + return int(self.num_hidden_layers) + + def resolve_block_size(self, *, default: Optional[int] = None) -> Optional[int]: + return self.block_size if self.block_size is not None else default + + def resolve_target_layer_ids( + self, + *, + target_num_layers: int, + draft_num_layers: Optional[int] = None, + ) -> List[int]: + target_num_layers = int(target_num_layers) + if target_num_layers <= 0: + raise ValueError( + f"target_num_layers must be positive, got {target_num_layers}." + ) + + if self.target_layer_ids is None: + if draft_num_layers is None: + draft_num_layers = self.require_num_layers() + return build_target_layer_ids(target_num_layers, int(draft_num_layers)) + + resolved = list(self.target_layer_ids) + if len(resolved) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(resolved)}." + ) + for idx, val in enumerate(resolved): + if val < 0 or val >= target_num_layers: + raise ValueError( + "DFLASH target_layer_ids contains an out-of-range layer id. " + f"target_layer_ids[{idx}]={val}, target_num_layers={target_num_layers}." 
+ ) + return resolved + + +def parse_dflash_draft_config(*, draft_hf_config: Any) -> DFlashDraftConfig: + """Parse and validate DFLASH draft config fields from HF config/dict.""" + dflash_cfg = _get_dflash_config(draft_hf_config) + draft_text_config = _get_text_config(draft_hf_config) + + num_hidden_layers = _parse_optional_int( + _cfg_get(draft_text_config, "num_hidden_layers", None), + field_name="DFLASH draft num_hidden_layers", + min_value=1, + ) + raw_num_target_layers = dflash_cfg.get( + "num_target_layers", + _cfg_get(draft_hf_config, "num_target_layers", None), + ) + num_target_layers = _parse_optional_int( + raw_num_target_layers, + field_name="DFLASH draft num_target_layers", + min_value=1, + ) + + # Keep support for current checkpoints where block_size is top-level. + raw_block_size = dflash_cfg.get( + "block_size", + _cfg_get(draft_hf_config, "block_size", None), + ) + block_size = _parse_optional_int( + raw_block_size, + field_name="DFLASH block_size", + min_value=1, + ) + + layer_ids = dflash_cfg.get( + "target_layer_ids", + _cfg_get(draft_hf_config, "target_layer_ids", None), + ) + parsed_target_layer_ids: Optional[List[int]] + if layer_ids is None: + parsed_target_layer_ids = None + else: + if not isinstance(layer_ids, (list, tuple)): + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be a list of ints, " + f"got type={type(layer_ids).__name__}." + ) + parsed_target_layer_ids = [int(x) for x in layer_ids] + if len(parsed_target_layer_ids) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(parsed_target_layer_ids)}." + ) + + mask_token = dflash_cfg.get("mask_token", None) + if mask_token is None: + mask_token = DEFAULT_DFLASH_MASK_TOKEN + if not isinstance(mask_token, str) or not mask_token: + raise ValueError( + "DFLASH dflash_config.mask_token must be a non-empty string, " + f"got {mask_token!r}." 
+ ) + + mask_token_id = dflash_cfg.get("mask_token_id", None) + if mask_token_id is not None: + if not isinstance(mask_token_id, Integral) or isinstance(mask_token_id, bool): + raise ValueError( + "DFLASH dflash_config.mask_token_id must be an integer, " + f"got {mask_token_id!r} (type={type(mask_token_id).__name__})." + ) + mask_token_id = int(mask_token_id) + if mask_token_id < 0: + raise ValueError( + "DFLASH dflash_config.mask_token_id must be non-negative, " + f"got {mask_token_id}." + ) + + return DFlashDraftConfig( + num_hidden_layers=num_hidden_layers, + num_target_layers=num_target_layers, + block_size=block_size, + target_layer_ids=parsed_target_layer_ids, + mask_token=mask_token, + mask_token_id=mask_token_id, + ) + + +def can_dflash_slice_qkv_weight(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether DFlash can slice KV weights from a fused QKV linear layer.""" + quant_method = getattr(qkv_proj, "quant_method", None) + if not isinstance(quant_method, UnquantizedLinearMethod): + return ( + False, + "quantized qkv_proj is not supported for this path " + f"(quant_method={type(quant_method).__name__})", + ) + if not hasattr(qkv_proj, "weight"): + return False, "qkv weight tensor is missing" + return True, "" + + +def can_dflash_use_fused_qkv_proj(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether a QKV layer is eligible for DFlash fused KV materialization.""" + eligible, reason = can_dflash_slice_qkv_weight(qkv_proj) + if not eligible: + return False, reason + if getattr(qkv_proj, "bias", None) is not None: + return False, "qkv bias is not supported for fused KV path" + return True, "" + + +def compute_dflash_accept_len_and_bonus( + *, + candidates: torch.Tensor, + target_predict: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens (greedy verify rule). + + Args: + candidates: Token ids proposed by the DFlash draft, including the current token. + Shape: [bs, block_size]. 
candidates[:, 0] is the current token. + target_predict: Token ids predicted by the target model for each position in the block. + Shape: [bs, block_size]. target_predict[:, t] corresponds to argmax at position t. + + Returns: + accept_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token). + bonus: int64 tensor [bs], the target-predicted token at index accept_len (the "bonus" token to append). + + Notes: + Matches the reference implementation rule: + accept while candidates[:, 1:] == target_predict[:, :-1] consecutively. + """ + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if target_predict.shape != candidates.shape: + raise ValueError( + "target_predict must have the same shape as candidates. " + f"candidates.shape={tuple(candidates.shape)}, target_predict.shape={tuple(target_predict.shape)}" + ) + + bs, block_size = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}.") + + matches = candidates[:, 1:] == target_predict[:, :-1] + accept_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1) + bonus = target_predict[torch.arange(bs, device=target_predict.device), accept_len] + return accept_len, bonus.to(torch.int64) + + +def compute_dflash_sampling_accept_len_and_bonus( + *, + candidates: torch.Tensor, + next_token_logits: torch.Tensor, + sampling_info: Any, + threshold_single: Optional[float] = None, + threshold_acc: Optional[float] = None, + uniform_samples: Optional[torch.Tensor] = None, + uniform_samples_for_final_sampling: Optional[torch.Tensor] = None, + use_sparse_topk: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens for non-greedy sampling. 
+ + This is a chain-specialized variant of speculative target-only verification: + - DFlash proposals are linear (topk == 1), so each verify level has at most one candidate. + - When a candidate is rejected at a level, the final token is sampled from + `relu(q - p)` where `p` has only the rejected candidate mass. + """ + if not _DFLASH_SAMPLING_VERIFY_AVAILABLE: + raise RuntimeError( + "DFLASH non-greedy verification is unavailable on this build/device." + ) + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if next_token_logits.ndim != 2: + raise ValueError( + "next_token_logits must be 2D, " + f"got shape={tuple(next_token_logits.shape)}." + ) + + bs, draft_token_num = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if draft_token_num <= 0: + raise ValueError(f"draft_token_num must be positive, got {draft_token_num}.") + if next_token_logits.shape[0] != bs * draft_token_num: + raise ValueError( + "next_token_logits row count mismatch. " + f"Expected {bs * draft_token_num}, got {next_token_logits.shape[0]}." + ) + if candidates.device != next_token_logits.device: + raise ValueError( + "candidates and next_token_logits must be on the same device, " + f"got {candidates.device} and {next_token_logits.device}." 
+ ) + + if threshold_single is None: + from sglang.srt.server_args import get_global_server_args + + threshold_single = get_global_server_args().speculative_accept_threshold_single + if threshold_acc is None: + from sglang.srt.server_args import get_global_server_args + + threshold_acc = get_global_server_args().speculative_accept_threshold_acc + threshold_single = float(threshold_single) + threshold_acc = max(float(threshold_acc), 1e-9) + + device = next_token_logits.device + + if uniform_samples is None: + uniform_samples = torch.rand( + (bs, draft_token_num), dtype=torch.float32, device=device + ) + else: + if uniform_samples.shape != (bs, draft_token_num): + raise ValueError( + "uniform_samples shape mismatch. " + f"Expected {(bs, draft_token_num)}, got {tuple(uniform_samples.shape)}." + ) + uniform_samples = uniform_samples.to(device=device, dtype=torch.float32) + + if uniform_samples_for_final_sampling is None: + uniform_samples_for_final_sampling = torch.rand( + (bs,), dtype=torch.float32, device=device + ) + else: + if uniform_samples_for_final_sampling.shape != (bs,): + raise ValueError( + "uniform_samples_for_final_sampling shape mismatch. " + f"Expected {(bs,)}, got {tuple(uniform_samples_for_final_sampling.shape)}." + ) + uniform_samples_for_final_sampling = uniform_samples_for_final_sampling.to( + device=device, + dtype=torch.float32, + ) + + need_top_k = bool(getattr(sampling_info, "need_top_k_sampling", True)) + need_top_p = bool(getattr(sampling_info, "need_top_p_sampling", False)) + # Build target distribution once over all verify rows. 
+ expanded_temperature = torch.repeat_interleave( + sampling_info.temperatures, draft_token_num, dim=0 + ) + scaled_logits = next_token_logits / expanded_temperature + sparse_topk_applied = False + + if use_sparse_topk and need_top_k: + repeated_top_ks = torch.repeat_interleave( + sampling_info.top_ks, draft_token_num, dim=0 + ).to(dtype=torch.int64) + vocab_size = int(scaled_logits.shape[-1]) + repeated_top_ks.clamp_(min=1, max=vocab_size) + max_top_k = int(repeated_top_ks.max().item()) + + # Sparse exact path for top-k/top-p (top-k-first semantics), then scatter to dense. + if 0 < max_top_k < vocab_size: + topk_logits, topk_indices = torch.topk(scaled_logits, k=max_top_k, dim=-1) + if not torch.all(repeated_top_ks == max_top_k): + ranks = torch.arange(max_top_k, device=device, dtype=torch.int64)[ + None, : + ] + valid = ranks < repeated_top_ks.unsqueeze(1) + topk_logits = topk_logits.masked_fill(~valid, float("-inf")) + + topk_probs = F.softmax(topk_logits, dim=-1) + if need_top_p: + repeated_top_ps = torch.repeat_interleave( + sampling_info.top_ps, draft_token_num, dim=0 + ) + topk_probs = top_p_renorm_prob(topk_probs, repeated_top_ps) + + target_probs = torch.zeros_like(scaled_logits, dtype=topk_probs.dtype) + target_probs.scatter_(1, topk_indices, topk_probs) + sparse_topk_applied = True + + if not sparse_topk_applied: + target_probs = F.softmax(scaled_logits, dim=-1) + if need_top_k: + target_probs = top_k_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ks, draft_token_num, dim=0), + ) + if need_top_p: + target_probs = top_p_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ps, draft_token_num, dim=0), + ) + target_probs = target_probs.view(bs, draft_token_num, -1).contiguous() + draft_probs = torch.zeros_like(target_probs) + + ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) = _get_or_create_chain_verify_buffers( + bs=bs, + 
draft_token_num=draft_token_num, + device=device, + ) + candidates_i64 = ( + candidates if candidates.dtype == torch.int64 else candidates.to(torch.int64) + ) + tree_speculative_sampling_target_only( + predicts=predicts, + accept_index=accept_index, + accept_token_num=accept_token_num, + candidates=candidates_i64, + retrive_index=retrieve_index, + retrive_next_token=retrieve_next_token, + retrive_next_sibling=retrieve_next_sibling, + uniform_samples=uniform_samples, + uniform_samples_for_final_sampling=uniform_samples_for_final_sampling, + target_probs=target_probs, + draft_probs=draft_probs, + threshold_single=threshold_single, + threshold_acc=threshold_acc, + deterministic=True, + ) + + accept_len = accept_token_num + row_ids = torch.arange(bs, dtype=torch.long, device=device) + accept_pos = accept_index[row_ids, accept_len.to(torch.long)].to(torch.long) + bonus = predicts[accept_pos].to(torch.int64) + return accept_len, bonus diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py new file mode 100644 index 000000000000..030aa21e5b35 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -0,0 +1,1245 @@ +import logging +import math +from copy import deepcopy +from typing import Optional, Union + +import torch + +from sglang.srt.distributed import get_tp_group +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.mem_cache.common import get_last_loc +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.server_args import ( + ServerArgs, + get_global_server_args, + set_global_server_args_for_scheduler, +) +from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput +from sglang.srt.speculative.dflash_utils import ( + 
can_dflash_use_fused_qkv_proj, + is_dflash_sampling_verify_available, + parse_dflash_draft_config, + resolve_dflash_verify_mask_policy, +) +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +from sglang.srt.utils import is_cuda + +logger = logging.getLogger(__name__) + +_FusedKVMaterializeHelper = None + + +def _get_fused_kv_materialize_helper(): + global _FusedKVMaterializeHelper + if _FusedKVMaterializeHelper is None: + from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, + ) + + _FusedKVMaterializeHelper = FusedKVMaterializeHelper + return _FusedKVMaterializeHelper + + +class DFlashWorker: + """DFlash speculative decoding worker (spec-v1, tp>=1/pp=1).""" + + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + self.server_args = server_args + self.gpu_id = gpu_id + self.tp_rank = tp_rank + self.dp_rank = dp_rank + self.moe_ep_rank = moe_ep_rank + self.attn_cp_rank = attn_cp_rank + self.moe_dp_rank = moe_dp_rank + self.nccl_port = nccl_port + self.target_worker = target_worker + self.model_runner = target_worker.model_runner + self.page_size = server_args.page_size + self.draft_window_size: Optional[int] = ( + int(server_args.speculative_dflash_draft_window_size) + if server_args.speculative_dflash_draft_window_size is not None + else None + ) + self.use_compact_draft_cache = self.draft_window_size is not None + self.device = target_worker.device + + self._warned_sampling_fallback = False + self._logged_first_verify = False + + # Draft runner (separate KV cache + attention backend). + # Without draft windowing, the draft worker aliases the target request->token + # mapping and allocation state. 
With draft windowing enabled, the draft worker + # keeps a private compact req->token table over the same global KV index space, + # so radix-cache/prefix-hit KV remains reusable while draft attention sees only + # the recent window. + target_req_to_token_pool, target_token_to_kv_pool_allocator = ( + target_worker.get_memory_pool() + ) + shared_req_to_token_pool = ( + None if self.use_compact_draft_cache else target_req_to_token_pool + ) + draft_server_args = deepcopy(server_args) + draft_server_args.skip_tokenizer_init = True + draft_backend = draft_server_args.speculative_draft_attention_backend + supported_draft_backends = ("flashinfer", "fa3", "fa4") + if draft_backend is None: + draft_backend, _ = draft_server_args.get_attention_backends() + if draft_backend is None: + draft_backend = "flashinfer" + elif draft_backend == "trtllm_mha": + logger.warning( + "DFLASH draft worker does not support 'trtllm_mha' because the " + "draft path requires non-causal attention. Falling back to " + "'flashinfer'." + ) + draft_backend = "flashinfer" + elif draft_backend not in supported_draft_backends: + logger.warning( + "DFLASH draft worker only supports attention_backend in %s for now, " + "but got %r. Falling back to 'flashinfer'.", + supported_draft_backends, + draft_backend, + ) + draft_backend = "flashinfer" + # Make the draft worker backend explicit and self-contained (no further overrides). + draft_server_args.speculative_draft_attention_backend = None + draft_server_args.prefill_attention_backend = None + draft_server_args.decode_attention_backend = None + draft_server_args.attention_backend = draft_backend + # Keep draft context length aligned with the target. 
+ draft_server_args.context_length = ( + target_worker.model_runner.model_config.context_len + ) + saved_server_args = get_global_server_args() + self.draft_worker = TpModelWorker( + server_args=draft_server_args, + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + pp_rank=0, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, + dp_rank=dp_rank, + nccl_port=nccl_port, + is_draft_worker=True, + req_to_token_pool=shared_req_to_token_pool, + token_to_kv_pool_allocator=target_token_to_kv_pool_allocator, + memory_pool_config=target_worker.model_runner.memory_pool_config, + ) + set_global_server_args_for_scheduler(saved_server_args) + self.draft_model_runner = self.draft_worker.model_runner + self.draft_model = self.draft_model_runner.model + draft_config = parse_dflash_draft_config( + draft_hf_config=self.draft_model_runner.model_config.hf_config + ) + if server_args.speculative_num_draft_tokens is None: + # Should not happen (ServerArgs should have inferred it), but keep a fallback. + self.block_size = int(draft_config.resolve_block_size(default=16)) + else: + self.block_size = int(server_args.speculative_num_draft_tokens) + model_block_size = draft_config.block_size + if model_block_size is None: + model_block_size = getattr(self.draft_model, "block_size", None) + if model_block_size is not None and int(model_block_size) != int( + self.block_size + ): + logger.warning( + "DFLASH block size mismatch: using speculative_num_draft_tokens=%s but draft config block_size=%s.", + self.block_size, + model_block_size, + ) + + self._mask_token = draft_config.mask_token + self._mask_token_id_override = draft_config.mask_token_id + self._mask_token_id = self._resolve_mask_token_id( + mask_token=self._mask_token, + mask_token_id=self._mask_token_id_override, + ) + if self.tp_rank == 0: + logger.info( + "Initialized DFLASH draft runner. 
attention_backend=%s, model=%s, block_size=%s, draft_window_size=%s, compact_cache=%s", + getattr(draft_server_args, "attention_backend", None), + self.draft_model.__class__.__name__, + self.block_size, + self.draft_window_size, + self.use_compact_draft_cache, + ) + logger.info( + "DFLASH draft runner ready. mask_token=%s, mask_token_id=%s, mask_token_id_override=%s", + self._mask_token, + self._mask_token_id, + self._mask_token_id_override, + ) + + self._block_pos_offsets = torch.arange( + self.block_size, device=self.device, dtype=torch.int64 + ) + self._draft_block_ids_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] + self._draft_block_positions_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) + self._draft_block_tokens_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) + self._draft_block_end_buf: Optional[torch.Tensor] = None # [cap_bs] + self._draft_seq_lens_cpu_buf: Optional[torch.Tensor] = None # [cap_bs] on CPU + self._draft_block_spec_info = DFlashVerifyInput( + draft_token=torch.empty((0,), dtype=torch.long, device=self.device), + positions=torch.empty((0,), dtype=torch.int64, device=self.device), + draft_token_num=int(self.block_size), + custom_mask=None, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + self._draft_greedy_gathered_max_buf: Optional[torch.Tensor] = None + self._draft_greedy_gathered_ids_buf: Optional[torch.Tensor] = None + self._draft_greedy_gather_cap: int = 0 + self._draft_greedy_best_rank_buf: Optional[torch.Tensor] = None + self._draft_greedy_rank_index_buf: Optional[torch.Tensor] = None + self._draft_greedy_selected_ids_buf: Optional[torch.Tensor] = None + self._draft_greedy_index_cap: int = 0 + + self._use_fused_kv_materialize = is_cuda() + self._fused_kv_helper: Optional[object] = None + if self._use_fused_kv_materialize: + self._init_fused_kv_helper() + + def _init_fused_kv_helper(self) -> None: + """Initialize the fused KV materialization helper with pre-stacked weights.""" + try: + 
layers = self.draft_model.layers + fused_disable_reason: Optional[str] = None + + if len(layers) == 0: + fused_disable_reason = "no layers found" + + for layer_idx, layer in enumerate(layers): + attn = layer.self_attn + eligible, reason = can_dflash_use_fused_qkv_proj(attn.qkv_proj) + if not eligible: + fused_disable_reason = f"{reason}: layer={layer_idx}" + break + + # Keep semantics aligned with set_kv_buffer scaling behavior. + k_scale = getattr(attn.attn, "k_scale", None) + v_scale = getattr(attn.attn, "v_scale", None) + if k_scale is not None and not math.isclose(float(k_scale), 1.0): + fused_disable_reason = ( + "non-unit k_scale is not supported for fused KV path: " + f"layer={layer_idx}, k_scale={k_scale}" + ) + break + if v_scale is not None and not math.isclose(float(v_scale), 1.0): + fused_disable_reason = ( + "non-unit v_scale is not supported for fused KV path: " + f"layer={layer_idx}, v_scale={v_scale}" + ) + break + + rope_is_neox_style = bool( + getattr(attn.rotary_emb, "is_neox_style", True) + ) + if not rope_is_neox_style: + fused_disable_reason = ( + "non-neox RoPE is not supported for fused KV path: " + f"layer={layer_idx}, rope_is_neox_style={rope_is_neox_style}" + ) + break + + if fused_disable_reason is not None: + if self.tp_rank == 0: + logger.info( + "DFLASH fused KV materialization disabled: %s", + fused_disable_reason, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + return + + FusedKVMaterializeHelper = _get_fused_kv_materialize_helper() + first_attn = layers[0].self_attn + rotary_emb = first_attn.rotary_emb + + self._fused_kv_helper = FusedKVMaterializeHelper( + layers=layers, + rotary_emb=rotary_emb, + num_kv_heads=first_attn.num_kv_heads, + head_dim=first_attn.head_dim, + device=self.device, + ) + if self.tp_rank == 0: + logger.info( + "DFLASH fused KV materialization enabled. 
" + "n_layers=%d, num_kv_heads=%d, head_dim=%d", + len(layers), + first_attn.num_kv_heads, + first_attn.head_dim, + ) + except Exception as e: + logger.warning( + "DFLASH fused KV initialization failed, falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + + def _ensure_draft_block_buffers(self, bs: int) -> None: + cap = ( + 0 + if self._draft_block_ids_buf is None + else int(self._draft_block_ids_buf.shape[0]) + ) + if cap >= int(bs): + return + + new_cap = max(int(bs), cap * 2 if cap > 0 else int(bs)) + device = self.device + block_size = int(self.block_size) + self._draft_block_ids_buf = torch.empty( + (new_cap, block_size), dtype=torch.long, device=device + ) + self._draft_block_positions_buf = torch.empty( + (new_cap, block_size), dtype=torch.int64, device=device + ) + self._draft_block_tokens_buf = torch.empty( + (new_cap, block_size), dtype=torch.long, device=device + ) + self._draft_block_end_buf = torch.empty( + (new_cap,), dtype=torch.int32, device=device + ) + self._draft_seq_lens_cpu_buf = torch.empty( + (new_cap,), dtype=torch.int32, device="cpu" + ) + + def __getattr__(self, name): + # Delegate anything not implemented yet to the target worker. + return getattr(self.target_worker, name) + + def clear_cache_pool(self): + # The target worker owns the shared KV allocator/cache. For the compact + # sliding-window path, the draft req->token view is rebuilt from committed + # target state before each draft forward, so there is nothing persistent + # to flush here. + pass + + def _gather_req_to_token_masked( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + pos2d: torch.Tensor, + mask: torch.Tensor, + context: str, + ) -> torch.Tensor: + if pos2d.ndim != 2: + raise RuntimeError( + f"{context} expected 2D positions, got shape={tuple(pos2d.shape)}." 
+ ) + if mask.shape != pos2d.shape: + raise RuntimeError( + f"{context} mask/position shape mismatch: {tuple(mask.shape)} vs {tuple(pos2d.shape)}." + ) + + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + if mask.dtype != torch.bool: + mask = mask.to(torch.bool) + + table_width = int(req_to_token.shape[1]) + if table_width <= 0: + if bool(mask.any().item()): + raise RuntimeError( + f"{context} req_to_token table is empty but gather mask is non-empty." + ) + return torch.empty((0,), dtype=torch.int64, device=self.device) + + # Only the masked-off rectangular padding can be out of range in the normal + # ragged-batch case. Replace those don't-care columns with a valid in-range + # position before the gather so the kernel only sees real positions. + safe_pos2d = pos2d.masked_fill(~mask, 0) + return req_to_token[req_pool_indices[:, None], safe_pos2d][mask].to(torch.int64) + + def _gather_req_to_token_segments( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + start: torch.Tensor | None, + lengths: torch.Tensor, + ) -> torch.Tensor: + lengths = lengths.to(torch.int64) + if lengths.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=self.device) + max_len = int(lengths.max().item()) + if max_len <= 0: + return torch.empty((0,), dtype=torch.int64, device=self.device) + + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + offsets = torch.arange( + max_len, device=self.device, dtype=torch.int64 + ).unsqueeze(0) + if start is None: + pos2d = offsets.expand(req_pool_indices.shape[0], -1) + else: + pos2d = start.to(torch.int64).unsqueeze(1) + offsets + mask = offsets < lengths.unsqueeze(1) + return self._gather_req_to_token_masked( + req_to_token=req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH req_to_token segment gather", + ) + + def _compute_compact_draft_seq_lens(self, seq_lens: 
torch.Tensor) -> torch.Tensor: + assert self.draft_window_size is not None + visible_lens = torch.clamp( + seq_lens.to(dtype=torch.int32, device=self.device), + max=int(self.draft_window_size), + ) + if self.page_size <= 1: + return visible_lens + + # Paged FA backends derive the page table from local token positions, so the + # compact suffix must start on a page boundary. Keep up to page_size - 1 extra + # tokens on the left to preserve valid local page structure. + seq_lens_i64 = seq_lens.to(torch.int64) + visible_lens_i64 = visible_lens.to(torch.int64) + visible_start = seq_lens_i64 - visible_lens_i64 + aligned_start = visible_start - torch.remainder(visible_start, self.page_size) + return (seq_lens_i64 - aligned_start).to(torch.int32) + + def _resolve_mask_token_id( + self, *, mask_token: str, mask_token_id: Optional[int] = None + ) -> int: + if not isinstance(mask_token, str) or not mask_token: + raise ValueError( + f"DFLASH mask_token must be a non-empty string, got {mask_token!r}." + ) + + vocab_size = int(self.target_worker.model_runner.model_config.vocab_size) + if mask_token_id is not None: + resolved_id = int(mask_token_id) + if resolved_id >= vocab_size: + raise ValueError( + "DFLASH mask_token_id is outside the target vocab size. " + f"mask_token_id={resolved_id}, vocab_size={vocab_size}. " + f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " + "SGLang does not support resizing target embeddings for DFLASH yet." + ) + + tokenizer = getattr(self.target_worker, "tokenizer", None) + if tokenizer is not None: + token_id_from_vocab = tokenizer.get_vocab().get(mask_token, None) + if ( + token_id_from_vocab is not None + and int(token_id_from_vocab) != resolved_id + ): + raise ValueError( + "DFLASH config mismatch: dflash_config.mask_token_id conflicts with tokenizer vocab id " + f"for dflash_config.mask_token. 
mask_token={mask_token!r}, " + f"mask_token_id={resolved_id}, tokenizer_vocab_id={int(token_id_from_vocab)}." + ) + return resolved_id + + tokenizer = getattr(self.target_worker, "tokenizer", None) + if tokenizer is None: + raise RuntimeError( + "DFLASH requires tokenizer initialization when dflash_config.mask_token_id is not set " + "(skip_tokenizer_init is not supported in this mode)." + ) + + resolved_id = None + if getattr(tokenizer, "mask_token", None) == mask_token: + resolved_id = getattr(tokenizer, "mask_token_id", None) + + if resolved_id is None: + # Prefer checking the explicit vocab mapping first. + vocab = tokenizer.get_vocab() + resolved_id = vocab.get(mask_token, None) + + if resolved_id is None: + # Mirror the reference DFlash HF demo by adding the mask token to the tokenizer. + # This is safe only when the resulting id stays within the target model vocab size. + added = tokenizer.add_special_tokens({"mask_token": mask_token}) + resolved_id = getattr(tokenizer, "mask_token_id", None) + if resolved_id is None: + resolved_id = tokenizer.convert_tokens_to_ids(mask_token) + + if added and self.tp_rank == 0: + logger.info( + "Added DFLASH mask token to tokenizer. token=%s, mask_token_id=%s, tokenizer_len=%s, model_vocab_size=%s", + mask_token, + resolved_id, + len(tokenizer), + vocab_size, + ) + + if resolved_id is None or int(resolved_id) < 0: + raise ValueError( + "DFLASH requires resolving a mask token id, but it could not be resolved. " + f"mask_token={mask_token!r}." + ) + + if resolved_id >= vocab_size: + raise ValueError( + "DFLASH mask_token_id is outside the target vocab size. " + f"mask_token_id={resolved_id}, vocab_size={vocab_size}. " + f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " + "SGLang does not support resizing target embeddings for DFLASH yet." 
+ ) + + return int(resolved_id) + + def _prepare_for_speculative_decoding( + self, batch: ScheduleBatch, draft_input: DFlashDraftInput + ): + if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): + return + + if batch.has_grammar: + raise RuntimeError( + "Invariant broken: DFLASH batch has grammar constraints, but scheduler should have rejected this request." + ) + if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: + if ( + not is_dflash_sampling_verify_available() + and not self._warned_sampling_fallback + and self.tp_rank == 0 + ): + logger.warning( + "DFLASH non-greedy verification is unavailable on this build/device; " + "falling back to greedy argmax verification." + ) + self._warned_sampling_fallback = True + + bs = batch.batch_size() + + # --- 1) Append any newly committed tokens into the draft KV cache. + self._append_target_hidden_to_draft_kv(batch, draft_input) + + target_model = self.target_worker.model_runner.model + embed_module = target_model.get_input_embeddings() + lm_head = getattr(target_model, "lm_head", None) + if ( + lm_head is None + or not hasattr(lm_head, "weight") + or not hasattr(lm_head, "shard_indices") + ): + raise RuntimeError( + "DFLASH requires the target model to expose a vocab-parallel `lm_head` with `weight` and " + "`shard_indices` attributes." + ) + + # --- 2) Draft a non-causal block with the draft model. 
+ self._ensure_draft_block_buffers(bs) + assert self._draft_block_ids_buf is not None + assert self._draft_block_positions_buf is not None + assert self._draft_block_tokens_buf is not None + assert self._draft_block_end_buf is not None + assert self._draft_seq_lens_cpu_buf is not None + + block_ids = self._draft_block_ids_buf[:bs] + block_ids.fill_(int(self._mask_token_id)) + block_ids[:, 0].copy_(draft_input.verified_id.to(torch.long)) + + noise_embedding = embed_module(block_ids) + input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) + + # For spec-v1, the draft KV cache is always materialized before drafting the + # next block. `target_prefix_lens` stay absolute for RoPE; `draft_prefix_lens` + # are the logical resident lengths in the draft-local cache. + target_prefix_lens = batch.seq_lens # int32, device + draft_prefix_lens = draft_input.draft_seq_lens + if draft_prefix_lens.dtype != torch.int32: + draft_prefix_lens = draft_prefix_lens.to(torch.int32) + if draft_prefix_lens.device != self.device: + draft_prefix_lens = draft_prefix_lens.to(self.device, non_blocking=True) + + positions_2d = self._draft_block_positions_buf[:bs] + torch.add( + target_prefix_lens.unsqueeze(1), self._block_pos_offsets, out=positions_2d + ) + positions = positions_2d.reshape(-1) + + block_start = draft_prefix_lens + block_end = self._draft_block_end_buf[:bs] + torch.add(block_start, int(self.block_size), out=block_end) + + seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] + seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32)) + allocator = self.draft_model_runner.token_to_kv_pool_allocator + token_to_kv_pool_state_backup = allocator.backup_state() + try: + if self.page_size == 1: + block_cache_loc = allocator.alloc(bs * self.block_size) + else: + block_end_cpu = seq_lens_cpu + int(self.block_size) + last_loc = get_last_loc( + self.draft_model_runner.req_to_token_pool.req_to_token, + batch.req_pool_indices, + block_start, + ) + block_cache_loc = 
allocator.alloc_extend( + block_start, + seq_lens_cpu, + block_end, + block_end_cpu, + last_loc, + bs * self.block_size, + ) + if block_cache_loc is None: + raise RuntimeError( + f"DFLASH draft OOM when allocating {bs * self.block_size} block tokens." + ) + + assign_req_to_token_pool_func( + batch.req_pool_indices, + self.draft_model_runner.req_to_token_pool.req_to_token, + block_start, + block_end, + block_cache_loc, + bs, + ) + + # Use TARGET_VERIFY mode (cuda-graphable) to run a fixed-size draft block. + # In this mode, `seq_lens` stores the prefix lengths; attention backends + # derive kv_len by adding `draft_token_num`. + draft_spec_info = self._draft_block_spec_info + seq_lens = draft_prefix_lens + seq_lens_sum = int(draft_prefix_lens.sum().item()) + forward_batch = ForwardBatch( + forward_mode=ForwardMode.TARGET_VERIFY, + batch_size=bs, + input_ids=block_ids.flatten(), + req_pool_indices=batch.req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=block_cache_loc, + seq_lens_sum=seq_lens_sum, + seq_lens_cpu=seq_lens_cpu, + positions=positions, + req_to_token_pool=self.draft_model_runner.req_to_token_pool, + token_to_kv_pool=self.draft_model_runner.token_to_kv_pool, + attn_backend=self.draft_model_runner.attn_backend, + input_embeds=input_embeds, + spec_algorithm=SpeculativeAlgorithm.DFLASH, + spec_info=draft_spec_info, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + + with torch.inference_mode(): + draft_logits_output = self.draft_model_runner.forward( + forward_batch + ).logits_output + finally: + # Drop the speculative block from the shared allocator (EAGLE3-style). 
+ allocator.restore_state(token_to_kv_pool_state_backup) + + draft_hidden = draft_logits_output.hidden_states + if draft_hidden is None: + raise RuntimeError("DFLASH draft model returned no hidden states.") + draft_hidden = draft_hidden.view(bs, self.block_size, -1) + draft_next = self._greedy_sample_from_vocab_parallel_head( + hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]), + lm_head=lm_head, + ).view(bs, self.block_size - 1) + draft_tokens = self._draft_block_tokens_buf[:bs] + draft_tokens[:, 0].copy_(block_ids[:, 0]) + draft_tokens[:, 1:].copy_(draft_next) + positions = positions_2d.reshape(-1) + + verify_input = DFlashVerifyInput( + draft_token=draft_tokens.reshape(-1), + positions=positions, + draft_token_num=self.block_size, + ) + _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) + verify_input.prepare_for_verify( + batch, + self.page_size, + build_custom_mask=build_custom_mask, + ) + + batch.forward_mode = ( + ForwardMode.TARGET_VERIFY + if not batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) + batch.spec_info = verify_input + batch.return_hidden_states = False + + def _greedy_sample_from_vocab_parallel_head( + self, + *, + hidden_states: torch.Tensor, + lm_head, + chunk_size: int = 256, + ) -> torch.Tensor: + """Greedy argmax over the target LM head in a TP-safe way. + + We cannot materialize full logits for large vocabularies efficiently, and with + TP>1 each rank only owns a shard of the LM head weight. This computes the + per-rank max, gathers candidates across TP ranks, and selects the global max. + """ + + if hidden_states.numel() == 0: + return torch.empty((0,), dtype=torch.long, device=hidden_states.device) + + tp_group = get_tp_group() + tp_size = int(tp_group.world_size) + + if not hasattr(lm_head, "weight") or not hasattr(lm_head, "shard_indices"): + raise RuntimeError( + "DFLASH greedy sampling requires a vocab-parallel head with `weight` and `shard_indices`." 
+ ) + + shard = lm_head.shard_indices + weight = lm_head.weight # [local_vocab_padded, hidden] + weight_dtype = weight.dtype + + # Valid ranges in the local shard (excluding padding): + # base vocab: [0, num_org) + # added vocab: [num_org_padded, num_org_padded + num_added) + num_org = int(shard.num_org_elements) + num_org_padded = int(shard.num_org_elements_padded) + num_added = int(shard.num_added_elements) + org_vocab_start = int(shard.org_vocab_start_index) + added_vocab_start = int(shard.added_vocab_start_index) + + num_tokens = int(hidden_states.shape[0]) + out_token_ids = torch.empty( + (num_tokens,), dtype=torch.long, device=hidden_states.device + ) + + def _cast_hs(x: torch.Tensor) -> torch.Tensor: + return x if x.dtype == weight_dtype else x.to(weight_dtype) + + # Fast path (common): single-rank greedy sampling over the base vocab shard. + # Avoids extra max/id bookkeeping that is only needed for TP sync or added vocab. + if tp_size == 1 and num_added == 0: + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = _cast_hs(hidden_states[start:end]) + if num_org > 0: + base_logits = torch.matmul(hs, weight[:num_org].T) + out_token_ids[start:end] = ( + torch.argmax(base_logits, dim=-1).to(torch.long) + + org_vocab_start + ) + else: + out_token_ids[start:end] = 0 + return out_token_ids + + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = _cast_hs(hidden_states[start:end]) + chunk_len = int(hs.shape[0]) + + # Base vocab logits. + if num_org > 0: + base_logits = torch.matmul(hs, weight[:num_org].T) + local_max, local_arg = torch.max(base_logits, dim=-1) + else: + local_max = torch.full( + (chunk_len,), + torch.finfo(weight_dtype).min, + dtype=weight_dtype, + device=hs.device, + ) + local_arg = torch.zeros( + (chunk_len,), dtype=torch.int64, device=hs.device + ) + + # Added vocab logits (e.g., LoRA-added embeddings), if present. 
+ if num_added > 0: + added_slice_start = num_org_padded + added_slice_end = num_org_padded + num_added + added_logits = torch.matmul( + hs, weight[added_slice_start:added_slice_end].T + ) + added_max, added_arg = torch.max(added_logits, dim=-1) + use_added = added_max > local_max + local_max = torch.where(use_added, added_max, local_max) + # For base/added conversion below, keep local_arg expressed in the full local + # weight index space (base + padding + added), matching `lm_head.weight`. + local_arg = torch.where( + use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg + ) + + # Convert local argmax indices to global token ids. + if num_added == 0: + local_arg.add_(org_vocab_start) + global_ids = local_arg + else: + global_ids = torch.empty( + (chunk_len,), dtype=torch.int64, device=hs.device + ) + is_base = local_arg < num_org + global_ids[is_base] = org_vocab_start + local_arg[is_base] + global_ids[~is_base] = added_vocab_start + ( + local_arg[~is_base] - num_org_padded + ) + + if tp_size == 1: + out_token_ids[start:end] = global_ids.to(torch.long) + continue + + # Gather per-rank maxima and associated global ids, then select the global max. + needed = tp_size * chunk_len + chunk_cap = int(chunk_size) + if ( + self._draft_greedy_gather_cap < needed + or self._draft_greedy_gathered_max_buf is None + or self._draft_greedy_gathered_ids_buf is None + or self._draft_greedy_gathered_max_buf.dtype != local_max.dtype + or self._draft_greedy_gathered_max_buf.device != hs.device + ): + # Allocate enough space for the max chunk size to avoid reallocations. 
+ cap = tp_size * chunk_cap + self._draft_greedy_gathered_max_buf = torch.empty( + (cap,), dtype=local_max.dtype, device=hs.device + ) + self._draft_greedy_gathered_ids_buf = torch.empty( + (cap,), dtype=global_ids.dtype, device=hs.device + ) + self._draft_greedy_gather_cap = cap + + if ( + self._draft_greedy_index_cap < chunk_len + or self._draft_greedy_best_rank_buf is None + or self._draft_greedy_rank_index_buf is None + or self._draft_greedy_selected_ids_buf is None + or self._draft_greedy_best_rank_buf.device != hs.device + or self._draft_greedy_selected_ids_buf.device != hs.device + ): + self._draft_greedy_best_rank_buf = torch.empty( + (chunk_cap,), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_rank_index_buf = torch.empty( + (1, chunk_cap), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_selected_ids_buf = torch.empty( + (1, chunk_cap), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_index_cap = chunk_cap + + gathered_max = self._draft_greedy_gathered_max_buf[:needed] + gathered_ids = self._draft_greedy_gathered_ids_buf[:needed] + + tp_group.all_gather_into_tensor(gathered_max, local_max.contiguous()) + tp_group.all_gather_into_tensor(gathered_ids, global_ids.contiguous()) + gathered_max = gathered_max.view(tp_size, chunk_len) + gathered_ids = gathered_ids.view(tp_size, chunk_len) + + best_rank = self._draft_greedy_best_rank_buf[:chunk_len] + torch.argmax(gathered_max, dim=0, out=best_rank) + + rank_index = self._draft_greedy_rank_index_buf[:, :chunk_len] + rank_index[0].copy_(best_rank) + selected_ids = self._draft_greedy_selected_ids_buf[:, :chunk_len] + torch.gather(gathered_ids, 0, rank_index, out=selected_ids) + out_token_ids[start:end].copy_(selected_ids.view(-1)) + + return out_token_ids + + def _append_target_hidden_to_draft_kv( + self, + batch: ScheduleBatch, + draft_input: DFlashDraftInput, + ) -> None: + """Materialize the target hidden-state features into the draft KV cache. 
+ + This must be run before exposing new tokens to radix cache (prefix hits), otherwise + another request could reuse target KV indices without having draft KV values. + """ + + bs = batch.batch_size() + device = self.model_runner.device + + if draft_input.target_hidden is None: + raise RuntimeError( + "DFLASH draft state missing target_hidden context features." + ) + if draft_input.ctx_lens.numel() != bs: + raise RuntimeError( + f"DFLASH ctx_lens length mismatch: got {draft_input.ctx_lens.numel()} for bs={bs}." + ) + if draft_input.draft_seq_lens.numel() != bs: + raise RuntimeError( + f"DFLASH draft_seq_lens length mismatch: got {draft_input.draft_seq_lens.numel()} for bs={bs}." + ) + + total_ctx = int(draft_input.target_hidden.shape[0]) + if total_ctx <= 0: + draft_input.ctx_lens = torch.zeros_like(draft_input.ctx_lens) + draft_input.target_hidden = draft_input.target_hidden[:0] + return + + target_req_to_token = batch.req_to_token_pool.req_to_token + draft_req_to_token = self.draft_model_runner.req_to_token_pool.req_to_token + + req_pool_indices = batch.req_pool_indices + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + + ctx_lens = draft_input.ctx_lens + if ctx_lens.dtype != torch.int32: + ctx_lens = ctx_lens.to(torch.int32) + if ctx_lens.device != device: + ctx_lens = ctx_lens.to(device, non_blocking=True) + ctx_start = batch.seq_lens.to(torch.int64) - ctx_lens.to(torch.int64) + + if bs == 1: + # Fast path for single request. 
+ max_ctx = int(total_ctx) + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + pos2d = ctx_start[:, None] + r[None, :] # [1, ctx] + cache2d = target_req_to_token[req_pool_indices[:, None], pos2d] # [1, ctx] + ctx_cache_loc = cache2d.reshape(-1).to(torch.int64) # [ctx] + ctx_positions = pos2d.reshape(-1) # [ctx] + else: + # In decode mode, ctx_lens <= block_size so we can skip the .item() sync. + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + max_ctx = int(ctx_lens.max().item()) + else: + max_ctx = int(self.block_size) + if max_ctx <= 0: + raise RuntimeError(f"DFLASH invalid max_ctx={max_ctx} for KV append.") + + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + r = r[None, :] # [1, max_ctx] + pos2d = ctx_start[:, None] + r # [bs, max_ctx] + mask = r < ctx_lens[:, None] + + # Batched gather of cache locations and positions. + ctx_cache_loc = self._gather_req_to_token_masked( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH target hidden KV append", + ) # [sum(ctx_lens)] + ctx_positions = pos2d[mask] # [sum(ctx_lens)] + + with torch.inference_mode(): + ctx_hidden = self.draft_model.project_target_hidden( + draft_input.target_hidden + ) # [sum(ctx), hidden] + if ctx_hidden.shape[0] != ctx_cache_loc.numel(): + raise RuntimeError( + f"DFLASH ctx_hidden/cache_loc mismatch: {ctx_hidden.shape[0]} vs {ctx_cache_loc.numel()}." 
+ ) + + if self._use_fused_kv_materialize and self._fused_kv_helper is not None: + try: + self._append_target_hidden_fused( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + except Exception as e: + logger.warning( + "DFLASH fused KV append failed; falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + self._append_target_hidden_sequential( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + else: + self._append_target_hidden_sequential( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + + if self.use_compact_draft_cache: + new_draft_seq_lens = self._compute_compact_draft_seq_lens(batch.seq_lens) + suffix_start = batch.seq_lens.to(torch.int64) - new_draft_seq_lens.to( + torch.int64 + ) + suffix_cache_loc = self._gather_req_to_token_segments( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + start=suffix_start, + lengths=new_draft_seq_lens, + ) + assign_req_to_token_pool_func( + batch.req_pool_indices, + draft_req_to_token, + torch.zeros_like(new_draft_seq_lens), + new_draft_seq_lens, + suffix_cache_loc, + bs, + ) + draft_input.draft_seq_lens = new_draft_seq_lens + else: + draft_input.draft_seq_lens = batch.seq_lens.to(dtype=torch.int32) + draft_input.ctx_lens = torch.zeros_like(ctx_lens) + draft_input.target_hidden = draft_input.target_hidden[:0] + + def _append_target_hidden_sequential( + self, + ctx_hidden: torch.Tensor, + ctx_positions: torch.Tensor, + ctx_cache_loc: torch.Tensor, + ) -> None: + for layer in self.draft_model.layers: + attn = layer.self_attn + k, v = attn.kv_proj_only(ctx_hidden) + k = attn.apply_k_norm(k) + k = attn.apply_k_rope(ctx_positions, k) + k = k.view(-1, attn.num_kv_heads, attn.head_dim) + v = v.view(-1, attn.num_kv_heads, attn.head_dim) + self.draft_model_runner.token_to_kv_pool.set_kv_buffer( + attn.attn, + ctx_cache_loc, + k, + v, + attn.attn.k_scale, + attn.attn.v_scale, + ) + + def _append_target_hidden_fused( + self, + ctx_hidden: torch.Tensor, + 
ctx_positions: torch.Tensor, + ctx_cache_loc: torch.Tensor, + ) -> None: + """Fused KV materialization using batched projection + Triton kernel.""" + token_to_kv_pool = self.draft_model_runner.token_to_kv_pool + layers = self.draft_model.layers + + def _write_layer_kv( + layer_idx: int, cache_k: torch.Tensor, cache_v: torch.Tensor + ) -> None: + attn = layers[layer_idx].self_attn.attn + token_to_kv_pool.set_kv_buffer( + attn, + ctx_cache_loc, + cache_k, + cache_v, + attn.k_scale, + attn.v_scale, + ) + + self._fused_kv_helper.materialize( + ctx_hidden=ctx_hidden, + positions=ctx_positions, + write_layer_kv=_write_layer_kv, + ) + + def _update_target_mamba_state_after_verify( + self, + *, + batch: ScheduleBatch, + seq_lens_pre_verify: torch.Tensor, + commit_lens: torch.Tensor, + ) -> None: + """Commit Mamba intermediate states for accepted verify steps. + + During TARGET_VERIFY, Mamba kernels run with `disable_state_update=True` and + cache per-step intermediate states. After acceptance, we need to commit the + state corresponding to each request's last accepted step. 
+ """ + attn_backend = self.target_worker.model_runner.attn_backend + if not hasattr(attn_backend, "update_mamba_state_after_mtp_verify"): + return + + accepted_steps = commit_lens.to(torch.int64) - 1 + mamba_steps_to_track = None + + if batch.mamba_track_indices is not None: + mamba_track_interval = self.server_args.mamba_track_interval + to_track_mask = ( + seq_lens_pre_verify // mamba_track_interval + != batch.seq_lens // mamba_track_interval + ) + tracking_point = ( + batch.seq_lens // mamba_track_interval * mamba_track_interval + ) + to_track_ith = torch.clamp(tracking_point - seq_lens_pre_verify - 1, min=0) + can_track_mask = to_track_mask & ( + to_track_ith < commit_lens.to(to_track_ith.dtype) + ) + mamba_steps_to_track = torch.where( + can_track_mask, + to_track_ith.to(torch.int64), + torch.full_like(to_track_ith, -1, dtype=torch.int64), + ) + + attn_backend.update_mamba_state_after_mtp_verify( + accepted_steps=accepted_steps, + mamba_track_indices=batch.mamba_track_indices, + mamba_steps_to_track=mamba_steps_to_track, + model=self.target_worker.model_runner.model, + ) + + def forward_batch_generation( + self, + batch: Union[ScheduleBatch, ModelWorkerBatch], + **kwargs, + ) -> GenerationBatchResult: + if getattr(batch, "return_logprob", False): + raise RuntimeError( + "Invariant broken: DFLASH batch requested return_logprob, but scheduler should have rejected this request." + ) + + if isinstance(batch, ModelWorkerBatch): + # Should not happen for spec-v1 (non-overlap) scheduling, but keep a sane fallback. 
+ return self.target_worker.forward_batch_generation(batch, **kwargs) + + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, **kwargs + ) + logits_output, next_token_ids = ( + batch_result.logits_output, + batch_result.next_token_ids, + ) + if logits_output.hidden_states is None: + raise RuntimeError( + "DFLASH requires target aux hidden capture for prefill, but got None. " + "Make sure the target model has DFlash layers-to-capture configured." + ) + + if ( + model_worker_batch.extend_seq_lens is None + or model_worker_batch.extend_prefix_lens is None + ): + raise RuntimeError( + "DFLASH expected extend_seq_lens / extend_prefix_lens to be populated in extend mode, but got None." + ) + + # Materialize the prompt tokens into the draft KV cache immediately. This is required + # for radix cache support, since the scheduler may update radix after prefill returns. 
+ device = next_token_ids.device + + def _to_int32_device_tensor(x, *, device=device): + if isinstance(x, torch.Tensor): + if x.device != device: + x = x.to(device, non_blocking=True) + return x if x.dtype == torch.int32 else x.to(torch.int32) + return torch.tensor(x, dtype=torch.int32, device=device) + + extend_seq_lens = _to_int32_device_tensor( + model_worker_batch.extend_seq_lens + ) + draft_input = DFlashDraftInput( + verified_id=next_token_ids.to(torch.int64), + target_hidden=logits_output.hidden_states, + ctx_lens=extend_seq_lens, + draft_seq_lens=( + torch.zeros_like(extend_seq_lens) + if self.use_compact_draft_cache + else _to_int32_device_tensor(model_worker_batch.extend_prefix_lens) + ), + ) + self._append_target_hidden_to_draft_kv(batch, draft_input) + batch.spec_info = draft_input + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + num_accepted_tokens=0, + can_run_cuda_graph=batch_result.can_run_cuda_graph, + ) + + # Decode / target-verify stage. + draft_input = batch.spec_info + if not isinstance(draft_input, DFlashDraftInput): + raise RuntimeError( + "DFLASH decode requires DFlashDraftInput state on the running batch. " + "This usually means the request did not complete the prefill stage." 
+ ) + + self._prepare_for_speculative_decoding(batch, draft_input) + + model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.forward_mode.is_target_verify() + verify_input = model_worker_batch.spec_info + assert isinstance(verify_input, DFlashVerifyInput) + need_mamba_verify_commit = hasattr( + self.target_worker.model_runner.attn_backend, + "update_mamba_state_after_mtp_verify", + ) + seq_lens_pre_verify = ( + batch.seq_lens.clone() if need_mamba_verify_commit else None + ) + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, is_verify=True, **kwargs + ) + logits_output, can_run_cuda_graph = ( + batch_result.logits_output, + batch_result.can_run_cuda_graph, + ) + + ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) = verify_input.verify( + batch=batch, + logits_output=logits_output, + page_size=self.page_size, + ) + if need_mamba_verify_commit: + assert seq_lens_pre_verify is not None + self._update_target_mamba_state_after_verify( + batch=batch, + seq_lens_pre_verify=seq_lens_pre_verify, + commit_lens=commit_lens, + ) + + # Update draft state for the next iteration. Also materialize the committed verify tokens + # into the draft KV cache immediately so radix cache entries are safe to reuse. + draft_input.verified_id = new_verified_id + draft_input.target_hidden = next_target_hidden + draft_input.ctx_lens = commit_lens + self._append_target_hidden_to_draft_kv(batch, draft_input) + batch.spec_info = draft_input + batch.forward_mode = ForwardMode.DECODE + + num_accepted_tokens = sum(accept_length_per_req_cpu) + if not self._logged_first_verify and self.tp_rank == 0: + logger.info( + "DFLASH verify completed. 
accept_length_per_req=%s", + accept_length_per_req_cpu, + ) + self._logged_first_verify = True + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=new_verified_id, + num_accepted_tokens=num_accepted_tokens, + accept_length_per_req_cpu=accept_length_per_req_cpu, + can_run_cuda_graph=can_run_cuda_graph, + ) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index a40a8aa0dc33..3e5727187572 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -15,6 +15,7 @@ class SpeculativeAlgorithm(Enum): """Enumeration of speculative decoding algorithms.""" + DFLASH = auto() EAGLE = auto() EAGLE3 = auto() STANDALONE = auto() @@ -33,6 +34,9 @@ def from_string(cls, name: Optional[str]) -> SpeculativeAlgorithm: def is_none(self) -> bool: return self == SpeculativeAlgorithm.NONE + def is_speculative(self) -> bool: + return self != SpeculativeAlgorithm.NONE + def is_eagle(self) -> bool: # NOTE: EAGLE3 is a variant of EAGLE return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3 @@ -40,6 +44,9 @@ def is_eagle(self) -> bool: def is_eagle3(self) -> bool: return self == SpeculativeAlgorithm.EAGLE3 + def is_dflash(self) -> bool: + return self == SpeculativeAlgorithm.DFLASH + def is_standalone(self) -> bool: return self == SpeculativeAlgorithm.STANDALONE @@ -57,6 +64,16 @@ def create_worker( ), "Cannot create worker for NONE speculative algorithm." enable_overlap = not server_args.disable_overlap_schedule + + if self.is_dflash(): + if enable_overlap: + raise ValueError( + "DFLASH does not support overlap scheduling (spec v2)." 
+ ) + from sglang.srt.speculative.dflash_worker import DFlashWorker + + return DFlashWorker + if self.is_eagle() and server_args.enable_multi_layer_eagle: # FIXME: migrate to EagleWorker if enable_overlap: @@ -110,6 +127,8 @@ class SpecInputType(IntEnum): # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it EAGLE_DRAFT = auto() EAGLE_VERIFY = auto() + DFLASH_DRAFT = auto() + DFLASH_VERIFY = auto() NGRAM_VERIFY = auto() @@ -120,11 +139,15 @@ def __init__(self, spec_input_type: SpecInputType): def is_draft_input(self) -> bool: # FIXME: remove this function which is only used for assertion # or use another variable name like `draft_input` to substitute `spec_info` - return self.spec_input_type == SpecInputType.EAGLE_DRAFT + return self.spec_input_type in { + SpecInputType.EAGLE_DRAFT, + SpecInputType.DFLASH_DRAFT, + } def is_verify_input(self) -> bool: return self.spec_input_type in { SpecInputType.EAGLE_VERIFY, + SpecInputType.DFLASH_VERIFY, SpecInputType.NGRAM_VERIFY, } diff --git a/python/sglang/srt/speculative/triton_ops/__init__.py b/python/sglang/srt/speculative/triton_ops/__init__.py new file mode 100644 index 000000000000..a8ea8f4c704b --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Triton kernels for speculative decoding.""" + +from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, +) + +__all__ = ["FusedKVMaterializeHelper"] diff --git a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py new file mode 100644 index 000000000000..e7dc4c05ddfc --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py @@ -0,0 +1,303 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fused Triton kernel for DFlash KV materialization. + +Combines: KV projection (cuBLAS) + RMSNorm + RoPE (Triton), then pool-managed KV writes. 
+""" + +from typing import Callable, List + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fused_norm_rope_kernel( + kv_ptr, # [total_ctx, kv_size * 2] + k_norm_weight_ptr, # [head_dim] + cos_sin_cache_ptr, # [max_pos, rotary_dim] + positions_ptr, # [total_ctx] + k_out_ptr, # [total_ctx, num_kv_heads, head_dim] + v_out_ptr, # [total_ctx, num_kv_heads, head_dim] + kv_stride_ctx, + cos_sin_stride_pos, + k_out_stride_ctx, + k_out_stride_head, + v_out_stride_ctx, + v_out_stride_head, + total_ctx, + num_kv_heads: tl.constexpr, + head_dim: tl.constexpr, + kv_size: tl.constexpr, + rotary_dim: tl.constexpr, + half_rotary_dim: tl.constexpr, + eps: tl.constexpr, + BLOCK_HD: tl.constexpr, +): + """Fused RMSNorm(K) + RoPE(K) materialization. Grid: (total_ctx, num_kv_heads).""" + ctx_id = tl.program_id(0) + head_id = tl.program_id(1) + if ctx_id >= total_ctx: + return + + # Load metadata + position = tl.load(positions_ptr + ctx_id) + + # Compute base pointers + kv_base = kv_ptr + ctx_id * kv_stride_ctx + k_base = kv_base + head_id * head_dim + v_base = kv_base + kv_size + head_id * head_dim + k_write = k_out_ptr + ctx_id * k_out_stride_ctx + head_id * k_out_stride_head + v_write = v_out_ptr + ctx_id * v_out_stride_ctx + head_id * v_out_stride_head + + # Load K and V + offs = tl.arange(0, BLOCK_HD) + mask_hd = offs < head_dim + mask_half = offs < half_rotary_dim + + k_raw = tl.load(k_base + offs, mask=mask_hd, other=0.0).to(tl.float32) + v_raw = tl.load(v_base + offs, mask=mask_hd, other=0.0) + + # RMSNorm on K + inv_rms = tl.rsqrt(tl.sum(k_raw * k_raw) / head_dim + eps) + norm_w = tl.load(k_norm_weight_ptr + offs, mask=mask_hd, other=1.0).to(tl.float32) + k_normed = k_raw * inv_rms * norm_w + + # RoPE (neox style): k_first, k_second -> rotated + cos_sin_base = cos_sin_cache_ptr + position * cos_sin_stride_pos + cos_v = tl.load(cos_sin_base + offs, mask=mask_half, other=1.0).to(tl.float32) + sin_v = tl.load( + cos_sin_base + half_rotary_dim + 
offs, mask=mask_half, other=0.0 + ).to(tl.float32) + + # Extract first/second halves of K for rotation + k_first = tl.where(mask_half, k_normed, 0.0) + k_second_raw = tl.load( + k_base + half_rotary_dim + offs, mask=mask_half, other=0.0 + ).to(tl.float32) + norm_w_second = tl.load( + k_norm_weight_ptr + half_rotary_dim + offs, mask=mask_half, other=1.0 + ).to(tl.float32) + k_second = k_second_raw * inv_rms * norm_w_second + + # Apply rotation + k_rot_first = k_first * cos_v - k_second * sin_v + k_rot_second = k_second * cos_v + k_first * sin_v + + # Store V (no transform) + tl.store(v_write + offs, v_raw, mask=mask_hd) + + # Store K: rotated halves + pass-through + tl.store(k_write + offs, k_rot_first.to(v_raw.dtype), mask=mask_half) + tl.store( + k_write + half_rotary_dim + offs, k_rot_second.to(v_raw.dtype), mask=mask_half + ) + mask_pass = (offs >= rotary_dim) & (offs < head_dim) + tl.store(k_write + offs, k_normed.to(v_raw.dtype), mask=mask_pass) + + +def _fused_norm_rope( + kv: torch.Tensor, # [total_ctx, kv_size*2] + k_norm_weight: torch.Tensor, # [head_dim] + cos_sin_cache: torch.Tensor, # [max_pos, rotary_dim] + positions: torch.Tensor, # [total_ctx] + num_kv_heads: int, + head_dim: int, + rotary_dim: int, + eps: float = 1e-6, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused RMSNorm + RoPE materialization for a single layer.""" + total_ctx = kv.shape[0] + if total_ctx == 0: + empty = torch.empty( + (0, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + ) + return empty, empty + + kv_size = num_kv_heads * head_dim + if kv.shape[1] != kv_size * 2: + raise ValueError( + "Invalid fused KV projection shape: " + f"got {tuple(kv.shape)}, expected second dim {kv_size * 2}." + ) + if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0: + raise ValueError( + "Invalid fused KV rotary/head dim pair: " + f"rotary_dim={rotary_dim}, head_dim={head_dim}." 
+ ) + + half_rotary_dim = rotary_dim // 2 + BLOCK_HD = triton.next_power_of_2(head_dim) + + # Ensure int64 for indexing + if positions.device != kv.device: + positions = positions.to(device=kv.device, dtype=torch.int64) + elif positions.dtype != torch.int64: + positions = positions.to(torch.int64) + + k_out = torch.empty( + (total_ctx, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + ) + v_out = torch.empty_like(k_out) + + _fused_norm_rope_kernel[(total_ctx, num_kv_heads)]( + kv, + k_norm_weight, + cos_sin_cache, + positions, + k_out, + v_out, + kv.stride(0), + cos_sin_cache.stride(0), + k_out.stride(0), + k_out.stride(1), + v_out.stride(0), + v_out.stride(1), + total_ctx, + num_kv_heads, + head_dim, + kv_size, + rotary_dim, + half_rotary_dim, + eps, + BLOCK_HD, + ) + return k_out, v_out + + +class FusedKVMaterializeHelper: + """Fused KV materialization helper using batched projection. + + Uses torch.einsum for batched KV projection across all layers, + then a Triton kernel for fused RMSNorm + RoPE materialization per layer. + """ + + def __init__( + self, + layers: List, + rotary_emb, + num_kv_heads: int, + head_dim: int, + device: torch.device, + ): + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.rotary_emb = rotary_emb + self.n_layers = len(layers) + self.device = device + + self.rotary_dim = int(getattr(rotary_emb, "rotary_dim", head_dim)) + self.is_neox_style = bool(getattr(rotary_emb, "is_neox_style", True)) + + if not self.is_neox_style: + raise NotImplementedError("Only neox-style RoPE is supported.") + if self.rotary_dim <= 0 or self.rotary_dim > self.head_dim: + raise ValueError( + "Invalid fused KV rotary/head dim pair: " + f"rotary_dim={self.rotary_dim}, head_dim={self.head_dim}." + ) + + # Pre-extract and stack weights for batched projection. 
+ kv_weights = [] + self.k_norm_weights = [] + self.eps_values = [] + + for layer_id, layer in enumerate(layers): + attn = layer.self_attn + if int(attn.num_kv_heads) != self.num_kv_heads: + raise ValueError( + "num_kv_heads mismatch across layers for fused KV path: " + f"expected {self.num_kv_heads}, got {int(attn.num_kv_heads)} at layer {layer_id}." + ) + if int(attn.head_dim) != self.head_dim: + raise ValueError( + "head_dim mismatch across layers for fused KV path: " + f"expected {self.head_dim}, got {int(attn.head_dim)} at layer {layer_id}." + ) + layer_rotary_dim = int( + getattr(attn.rotary_emb, "rotary_dim", self.head_dim) + ) + layer_is_neox = bool(getattr(attn.rotary_emb, "is_neox_style", True)) + if ( + layer_rotary_dim != self.rotary_dim + or layer_is_neox != self.is_neox_style + ): + raise ValueError( + "RoPE config mismatch across layers for fused KV path: " + f"expected (rotary_dim={self.rotary_dim}, neox={self.is_neox_style}), " + f"got (rotary_dim={layer_rotary_dim}, neox={layer_is_neox}) at layer {layer_id}." + ) + + # Extract KV portion of QKV weight + qkv_w = attn.qkv_proj.weight + kv_weight = qkv_w[attn.q_size : attn.q_size + 2 * attn.kv_size] + kv_weights.append(kv_weight) + self.k_norm_weights.append(attn.k_norm.weight) + self.eps_values.append(attn.k_norm.variance_epsilon) + + # Stack for batched einsum: [n_layers, kv_size*2, hidden_size] + self.batched_kv_weight = torch.stack(kv_weights) + + def materialize( + self, + ctx_hidden: torch.Tensor, + positions: torch.Tensor, + write_layer_kv: Callable[[int, torch.Tensor, torch.Tensor], None], + ) -> None: + """Materialize KV cache for all layers using batched projection.""" + total_ctx = ctx_hidden.shape[0] + if total_ctx == 0: + return + + if positions.ndim != 1: + positions = positions.reshape(-1) + if positions.numel() != total_ctx: + raise ValueError( + "positions must match ctx_hidden token count for fused KV materialization: " + f"positions={positions.numel()}, total_ctx={total_ctx}." 
+ ) + + max_position = int(positions.max().item()) + ensure_cos_sin_cache_length = getattr( + self.rotary_emb, "_ensure_cos_sin_cache_length", None + ) + if callable(ensure_cos_sin_cache_length): + ensure_cos_sin_cache_length(max_position) + + cos_sin_cache = self.rotary_emb.cos_sin_cache + if max_position >= int(cos_sin_cache.shape[0]): + raise RuntimeError( + "RoPE cos/sin cache is too short for fused KV materialization: " + f"max_position={max_position}, cache_len={int(cos_sin_cache.shape[0])}." + ) + if cos_sin_cache.device != ctx_hidden.device: + cos_sin_cache = cos_sin_cache.to(ctx_hidden.device) + + # Batched KV projection: [n_layers, total_ctx, kv_size*2] + kv_all = torch.einsum("th,loh->lto", ctx_hidden, self.batched_kv_weight) + + # Per-layer fused norm/RoPE/materialize, then delegate writes to the KV pool. + for layer_id in range(self.n_layers): + cache_k, cache_v = _fused_norm_rope( + kv_all[layer_id], + self.k_norm_weights[layer_id], + cos_sin_cache, + positions, + self.num_kv_heads, + self.head_dim, + self.rotary_dim, + self.eps_values[layer_id], + ) + write_layer_kv(layer_id, cache_k, cache_v) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 9cbd2e59dc90..6022f602c3f4 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -107,6 +107,10 @@ DEFAULT_TARGET_MODEL_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct" DEFAULT_DRAFT_MODEL_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B" +# DFLASH model +DEFAULT_TARGET_MODEL_DFLASH = "meta-llama/Llama-3.1-8B-Instruct" +DEFAULT_DRAFT_MODEL_DFLASH = "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat" + # EAGLE2 with DP-Attention models DEFAULT_TARGET_MODEL_EAGLE_DP_ATTN = "Qwen/Qwen3-30B-A3B" DEFAULT_DRAFT_MODEL_EAGLE_DP_ATTN = "Tengyunw/qwen3_30b_moe_eagle3" diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py new file mode 100644 index 000000000000..aa9ee2327d21 --- /dev/null +++ 
b/test/registered/spec/dflash/test_dflash.py @@ -0,0 +1,152 @@ +import os +import unittest + +import openai + +from sglang.srt.environ import envs +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kits.eval_accuracy_kit import GSM8KMixin +from sglang.test.kits.matched_stop_kit import MatchedStopMixin +from sglang.test.kits.radix_cache_server_kit import gen_radix_tree +from sglang.test.test_utils import ( + DEFAULT_DRAFT_MODEL_DFLASH, + DEFAULT_TARGET_MODEL_DFLASH, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=300, suite="stage-b-test-1-gpu-small") + + +class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): + max_running_requests = 64 + attention_backend = "flashinfer" + page_size = 1 + other_launch_args = [] + model = DEFAULT_TARGET_MODEL_DFLASH + draft_model = DEFAULT_DRAFT_MODEL_DFLASH + gsm8k_accuracy_thres = 0.75 + gsm8k_accept_length_thres = 2.8 + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + launch_args = [ + "--trust-remote-code", + "--attention-backend", + cls.attention_backend, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + cls.draft_model, + "--page-size", + str(cls.page_size), + "--max-running-requests", + str(cls.max_running_requests), + "--cuda-graph-bs", + *[str(i) for i in range(1, cls.max_running_requests + 1)], + ] + launch_args.extend(cls.other_launch_args) + old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" + try: + with envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( + 1 + ), envs.SGLANG_SPEC_NAN_DETECTION.override( + True + ), envs.SGLANG_SPEC_OOB_DETECTION.override( + True + ): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + 
other_args=launch_args, + ) + finally: + if old_value is None: + del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] + else: + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_early_stop(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + for i in range(8): + max_tokens = (i % 3) + 1 + response = client.completions.create( + model=self.model, + prompt=f"There are {i} apples on the table. How to divide them equally?", + max_tokens=max_tokens, + temperature=0, + ) + text = response.choices[0].text + print(f"early_stop: max_tokens={max_tokens}, text={text!r}") + assert self.process.poll() is None + + def test_eos_handling(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + response = client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "Today is a sunny day and I like"}], + max_tokens=256, + temperature=0.1, + ) + text = response.choices[0].message.content + print(f"eos_handling: text={text!r}") + self.assertNotIn("<|eot_id|>", text) + self.assertNotIn("<|end_of_text|>", text) + assert self.process.poll() is None + + def test_greedy_determinism(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + prompt = "The capital of France is" + outputs = [] + for _ in range(2): + response = client.completions.create( + model=self.model, + prompt=prompt, + max_tokens=32, + temperature=0, + ) + outputs.append(response.choices[0].text) + print(f"determinism: {outputs=}") + self.assertEqual(outputs[0], outputs[1]) + assert self.process.poll() is None + + +class TestDFlashServerPage256(TestDFlashServerBase): + page_size = 256 + + def test_radix_attention(self): + import requests + + nodes = gen_radix_tree(num_nodes=50) + data = { + "input_ids": [node["input_ids"] for node in nodes], + "sampling_params": [ + {"max_new_tokens": 
node["decode_len"], "temperature": 0} + for node in nodes + ], + } + res = requests.post(self.base_url + "/generate", json=data) + assert res.status_code == 200 + assert self.process.poll() is None + + +class TestDFlashServerChunkedPrefill(TestDFlashServerBase): + other_launch_args = ["--chunked-prefill-size", "4"] + + +class TestDFlashServerNoCudaGraph(TestDFlashServerBase): + other_launch_args = ["--disable-cuda-graph"] + + +if __name__ == "__main__": + unittest.main() From 671fe73961b7b535fea0415616791506d3f1e0b3 Mon Sep 17 00:00:00 2001 From: Thomas Wang <1am9trash@gmail.com> Date: Wed, 8 Apr 2026 06:37:08 +0800 Subject: [PATCH 36/42] Reduce unnecessary kernels and copies in the NSA indexer (#22232) --- .../srt/layers/attention/nsa/nsa_indexer.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 02ef4e2440cd..6bfcb3f66852 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -257,13 +257,13 @@ def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor: weights, _ = self.weights_proj(x) return weights.float() - @torch.compile(dynamic=True) if not _is_hip else lambda f: f + @torch.compile(dynamic=True) def _project_and_scale_head_gates(self, x: torch.Tensor): weights = self._weights_proj_bf16_in_fp32_out(x) weights = weights * self.n_heads**-0.5 return weights - @torch.compile(dynamic=True) if not _is_hip else lambda f: f + @torch.compile(dynamic=True) def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): weights = self._weights_proj_bf16_in_fp32_out(x) weights = weights * self.n_heads**-0.5 @@ -318,8 +318,8 @@ def _get_q_k_bf16( q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope) - query[..., : self.rope_head_dim] = q_rope.clone() - key[..., : self.rope_head_dim] = k_rope.clone() + 
self._update_rope_guarded(query[..., : self.rope_head_dim], q_rope) + self._update_rope_guarded(key[..., : self.rope_head_dim], k_rope) if enable_dual_stream: current_stream = torch.cuda.current_stream() @@ -376,11 +376,19 @@ def _get_k_bf16( ) _, k_rope = self.rotary_emb(positions, k_rope, k_rope) - key[..., : self.rope_head_dim] = k_rope.clone() + self._update_rope_guarded(key[..., : self.rope_head_dim], k_rope) key = rotate_activation(key) return key + @staticmethod + def _update_rope_guarded(dst: torch.Tensor, src: torch.Tensor) -> None: + # On AMD with in-place RoPE kernels, self-aliasing can occur; + # skip write-back when src/dst tensors point to a single memory. + if src.data_ptr() == dst.data_ptr(): + return + dst.copy_(src) + def _get_topk_paged( self, forward_batch: ForwardBatch, From e6652309c48b64053409924464b50d7a84f8addf Mon Sep 17 00:00:00 2001 From: Kangyan-Zhou Date: Tue, 7 Apr 2026 15:44:52 -0700 Subject: [PATCH 37/42] [CI] Update nightly test models for H200/B200 (#22288) Co-authored-by: Claude Opus 4.6 (1M context) --- .../8-gpu-models/test_deepseek_v31.py | 5 +- test/registered/8-gpu-models/test_glm_46.py | 52 ------------------- .../8-gpu-models/test_glm_46_fp8.py | 5 +- test/registered/8-gpu-models/test_qwen35.py | 13 +++-- .../8-gpu-models/test_qwen3_235b.py | 5 +- 5 files changed, 13 insertions(+), 67 deletions(-) delete mode 100644 test/registered/8-gpu-models/test_glm_46.py diff --git a/test/registered/8-gpu-models/test_deepseek_v31.py b/test/registered/8-gpu-models/test_deepseek_v31.py index b6cc7e5c471e..bacca09f5136 100644 --- a/test/registered/8-gpu-models/test_deepseek_v31.py +++ b/test/registered/8-gpu-models/test_deepseek_v31.py @@ -1,14 +1,11 @@ import unittest from sglang.test.accuracy_test_runner import AccuracyTestParams -from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.performance_test_runner import PerformanceTestParams from sglang.test.run_combined_tests import run_combined_tests from 
sglang.test.test_utils import ModelLaunchSettings -# Runs on both H200 and B200 via nightly-8-gpu-common suite -register_cuda_ci(est_time=5400, suite="nightly-8-gpu-common", nightly=True) - +# Manual-only: not registered in any CI suite DEEPSEEK_V31_MODEL_PATH = "deepseek-ai/DeepSeek-V3.1" diff --git a/test/registered/8-gpu-models/test_glm_46.py b/test/registered/8-gpu-models/test_glm_46.py deleted file mode 100644 index dc22744b46a5..000000000000 --- a/test/registered/8-gpu-models/test_glm_46.py +++ /dev/null @@ -1,52 +0,0 @@ -import unittest - -from sglang.test.accuracy_test_runner import AccuracyTestParams -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.performance_test_runner import PerformanceTestParams -from sglang.test.run_combined_tests import run_combined_tests -from sglang.test.test_utils import ModelLaunchSettings - -# Runs on both H200 and B200 via nightly-8-gpu-common suite -register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) - -GLM_4_6_MODEL_PATH = "zai-org/GLM-4.6" - - -class TestGLM46(unittest.TestCase): - """Unified test class for GLM-4.6 performance and accuracy. - - Single variant with simple TP=8 configuration. - GLM-4.6 is a 357B MoE model. 
- Runs BOTH: - - Performance test (using NightlyBenchmarkRunner) - - Accuracy test (using run_eval with mgsm_en) - """ - - def test_glm_46(self): - """Run performance and accuracy for GLM-4.6.""" - base_args = [ - "--tp=8", - "--trust-remote-code", - ] - - variants = [ - ModelLaunchSettings( - GLM_4_6_MODEL_PATH, - tp_size=8, - extra_args=base_args, - variant="TP8", - ), - ] - - run_combined_tests( - models=variants, - test_name="GLM-4.6", - accuracy_params=AccuracyTestParams(dataset="gsm8k", baseline_accuracy=0.80), - performance_params=PerformanceTestParams( - profile_dir="performance_profiles_glm_4_6", - ), - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/8-gpu-models/test_glm_46_fp8.py b/test/registered/8-gpu-models/test_glm_46_fp8.py index 435763e01447..fd564fbb8a70 100644 --- a/test/registered/8-gpu-models/test_glm_46_fp8.py +++ b/test/registered/8-gpu-models/test_glm_46_fp8.py @@ -1,14 +1,11 @@ import unittest from sglang.test.accuracy_test_runner import AccuracyTestParams -from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.performance_test_runner import PerformanceTestParams from sglang.test.run_combined_tests import run_combined_tests from sglang.test.test_utils import ModelLaunchSettings -# Runs on both H200 and B200 via nightly-8-gpu-common suite -register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) - +# Manual-only: not registered in any CI suite GLM_4_6_FP8_MODEL_PATH = "zai-org/GLM-4.6-FP8" diff --git a/test/registered/8-gpu-models/test_qwen35.py b/test/registered/8-gpu-models/test_qwen35.py index 7386b4a8cf46..813552b83421 100644 --- a/test/registered/8-gpu-models/test_qwen35.py +++ b/test/registered/8-gpu-models/test_qwen35.py @@ -9,7 +9,7 @@ # Runs on both H200 and B200 via nightly-8-gpu-common suite register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) -QWEN35_MODEL_PATH = "Qwen/Qwen3.5-397B-A17B" +QWEN35_MODEL_PATH = "Qwen/Qwen3.5-397B-A17B-FP8" 
class TestQwen35(unittest.TestCase): @@ -30,6 +30,7 @@ def test_qwen35(self): "--tool-call-parser=qwen3_coder", "--mem-fraction-static=0.8", ] + dp_args = ["--dp=8", "--enable-dp-attention"] mtp_args = [ "--speculative-algorithm=EAGLE", "--speculative-num-steps=3", @@ -48,8 +49,14 @@ def test_qwen35(self): ModelLaunchSettings( QWEN35_MODEL_PATH, tp_size=8, - extra_args=base_args + mtp_args, - variant="TP8+MTP", + extra_args=base_args + dp_args, + variant="TP8+DP8", + ), + ModelLaunchSettings( + QWEN35_MODEL_PATH, + tp_size=8, + extra_args=base_args + dp_args + mtp_args, + variant="TP8+DP8+MTP", env={"SGLANG_ENABLE_SPEC_V2": "1"}, ), ] diff --git a/test/registered/8-gpu-models/test_qwen3_235b.py b/test/registered/8-gpu-models/test_qwen3_235b.py index 70420bbed64a..b3f521bcd752 100644 --- a/test/registered/8-gpu-models/test_qwen3_235b.py +++ b/test/registered/8-gpu-models/test_qwen3_235b.py @@ -1,14 +1,11 @@ import unittest from sglang.test.accuracy_test_runner import AccuracyTestParams -from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.performance_test_runner import PerformanceTestParams from sglang.test.run_combined_tests import run_combined_tests from sglang.test.test_utils import ModelLaunchSettings, is_blackwell_system -# Runs on both H200 and B200 via nightly-8-gpu-common suite -register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) - +# Manual-only: not registered in any CI suite QWEN3_235B_FP8_MODEL_PATH = "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8" QWEN3_235B_EAGLE3_MODEL_PATH = ( "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge-Meituan" From 0e2a0260a150aff50b05bbac3c88fa935e264235 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 7 Apr 2026 15:56:12 -0700 Subject: [PATCH 38/42] Add fast-fail to multimodal-gen CI (#22284) --- .github/workflows/pr-test-multimodal-gen.yml | 1 + python/sglang/multimodal_gen/test/run_suite.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git 
a/.github/workflows/pr-test-multimodal-gen.yml b/.github/workflows/pr-test-multimodal-gen.yml index a91b6c2e927a..4b560ed5da39 100644 --- a/.github/workflows/pr-test-multimodal-gen.yml +++ b/.github/workflows/pr-test-multimodal-gen.yml @@ -175,6 +175,7 @@ jobs: with: ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }} + - uses: ./.github/actions/check-stage-health - uses: ./.github/actions/check-maintenance diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py index 700d4d6b8b18..a6fef42e8a69 100644 --- a/python/sglang/multimodal_gen/test/run_suite.py +++ b/python/sglang/multimodal_gen/test/run_suite.py @@ -174,12 +174,14 @@ def collect_test_items(files, filter_expr=None): return test_items -def run_pytest(files, filter_expr=None): +def run_pytest(files, filter_expr=None, exitfirst=False): if not files: print("No files to run.") return 0 base_cmd = [sys.executable, "-m", "pytest", "-s", "-v"] + if exitfirst: + base_cmd.append("-x") # Add pytest -k filter if provided if filter_expr: @@ -349,7 +351,8 @@ def main(): print(f"Running {len(my_items)} items in this shard: {', '.join(my_items)}") # 4. 
execute with the specific test items - exit_code = run_pytest(my_items) + # Fast-fail: stop on first failure unless --continue-on-error is set + exit_code = run_pytest(my_items, exitfirst=not args.continue_on_error) # Print tests again at the end for visibility msg = "\n" + tabulate.tabulate(rows, headers=headers, tablefmt="psql") + "\n" From 7546d04c81e3ad7370e0e621e3c6376610f6cfb4 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 7 Apr 2026 16:16:29 -0700 Subject: [PATCH 39/42] [NVIDIA] Enable FP4 flashinfer trtllm routed moe (#21240) --- .../moe/moe_runner/flashinfer_trtllm.py | 165 ++++++++++++------ .../srt/layers/quantization/modelopt_quant.py | 5 + 2 files changed, 114 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index e568579bb7e8..f4add35b391d 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -616,13 +616,17 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( dispatch_output: StandardDispatchOutput, quant_info: FlashInferTrtllmFp4MoeQuantInfo, runner_config: MoeRunnerConfig, + use_routed_topk: bool = False, ) -> StandardCombineInput: """FlashInfer TRTLLM FP4 MoE forward pass. This function handles the FP4 TRTLLM MoE path that was previously in ModelOptNvFp4FusedMoEMethod.apply. 
""" - from flashinfer.fused_moe import trtllm_fp4_block_scale_moe + from flashinfer.fused_moe import ( + trtllm_fp4_block_scale_moe, + trtllm_fp4_block_scale_routed_moe, + ) from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput from sglang.srt.layers.moe.topk import TopKOutputChecker @@ -633,25 +637,13 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( hidden_states = dispatch_output.hidden_states topk_output = dispatch_output.topk_output - assert TopKOutputChecker.format_is_bypassed(topk_output) - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - routing_method_type = quant_info.routing_method_type # Quantize hidden states to FP4 hs_fp4, hs_scale_linear = quantize_hidden_states_fp4( hidden_states, quant_info.w13_input_scale_quant ) - - # DeepSeekV3 style routing requires float32 router logits - if routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) - - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(hidden_states.dtype) + hs_scale = hs_scale_linear.view(torch.float8_e4m3fn).reshape( + *hs_scale_linear.shape[:-1], -1 ) with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()): @@ -660,49 +652,103 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( hs_fp4.shape[-1] * 2 if hs_fp4.dtype == torch.uint8 else hs_fp4.shape[-1] ) symm_output = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device + num_tokens, hidden_size, dtype=hidden_states.dtype, device=hs_fp4.device ) - result = trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hs_fp4, - hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape( - *hs_scale_linear.shape[:-1], -1 - ), - gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), 
- gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), - gemm2_bias=None, - output1_scale_scalar=quant_info.g1_scale_c, - output1_scale_gate_scalar=quant_info.g1_alphas, - output2_scale_scalar=quant_info.g2_alphas, - num_experts=quant_info.global_num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=quant_info.intermediate_size_per_partition, - local_expert_offset=quant_info.local_expert_offset, - local_num_experts=quant_info.local_num_experts, - routed_scaling_factor=runner_config.routed_scaling_factor, - routing_method_type=( - routing_method_type - if routing_method_type is not None - else RoutingMethodType.Default - ), - do_finalize=True, - tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), - output=symm_output, - )[0] + if use_routed_topk: + assert TopKOutputChecker.format_is_standard(topk_output) + + packed_topk_ids = _pack_topk_for_flashinfer_routed( + topk_output.topk_ids, topk_output.topk_weights + ) + result = trtllm_fp4_block_scale_routed_moe( + topk_ids=packed_topk_ids, + routing_bias=None, + hidden_states=hs_fp4, + hidden_states_scale=hs_scale, + gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, + gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, + gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm2_bias=None, + output1_scale_scalar=quant_info.g1_scale_c, + output1_scale_gate_scalar=quant_info.g1_alphas, + output2_scale_scalar=quant_info.g2_alphas, + num_experts=quant_info.global_num_experts, + top_k=topk_output.topk_ids.shape[1], + n_group=0, + topk_group=0, + 
intermediate_size=quant_info.intermediate_size_per_partition, + local_expert_offset=quant_info.local_expert_offset, + local_num_experts=quant_info.local_num_experts, + routed_scaling_factor=None, + routing_method_type=1, # Unused, but must be 1 to pass validation. + do_finalize=True, + tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), + output=symm_output, + )[0] + else: + assert TopKOutputChecker.format_is_bypassed(topk_output) + + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + routing_method_type = quant_info.routing_method_type + + # DeepSeekV3 style routing requires float32 router logits + if routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(hidden_states.dtype) + ) + result = trtllm_fp4_block_scale_moe( + routing_logits=router_logits, + routing_bias=correction_bias, + hidden_states=hs_fp4, + hidden_states_scale=hs_scale, + gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, + gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, + gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm2_bias=None, + output1_scale_scalar=quant_info.g1_scale_c, + output1_scale_gate_scalar=quant_info.g1_alphas, + output2_scale_scalar=quant_info.g2_alphas, + num_experts=quant_info.global_num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=quant_info.intermediate_size_per_partition, + local_expert_offset=quant_info.local_expert_offset, + local_num_experts=quant_info.local_num_experts, + routed_scaling_factor=runner_config.routed_scaling_factor, + routing_method_type=( + 
routing_method_type + if routing_method_type is not None + else RoutingMethodType.Default + ), + do_finalize=True, + tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), + output=symm_output, + )[0] return StandardCombineInput(hidden_states=result) @@ -858,6 +904,13 @@ def fused_experts_none_to_flashinfer_trtllm_routed( quant_info: MoeQuantInfo, runner_config: MoeRunnerConfig, ) -> StandardCombineInput: + if isinstance(quant_info, FlashInferTrtllmFp4MoeQuantInfo): + return fused_experts_none_to_flashinfer_trtllm_fp4( + dispatch_output, + quant_info, + runner_config, + use_routed_topk=True, + ) if isinstance(quant_info, FlashInferTrtllmFp8MoeQuantInfo): return fused_experts_none_to_flashinfer_trtllm_fp8( dispatch_output, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index c0d9958e45ee..d3b278c16375 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1534,6 +1534,7 @@ def __init__(self, quant_config: ModelOptFp4Config): ) self.enable_flashinfer_trtllm_moe = ( get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() ) self._cache_permute_indices = {} @@ -1904,6 +1905,10 @@ def create_moe_runner( self.runner = MoeRunner( MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config ) + elif get_moe_runner_backend().is_flashinfer_trtllm_routed(): + self.runner = MoeRunner( + MoeRunnerBackend.FLASHINFER_TRTLLM_ROUTED, moe_runner_config + ) def apply( self, From f6fc39569a8f4b35b9bc9b5fb60bbd1210790f22 Mon Sep 17 00:00:00 2001 From: Douglas Yang Date: Tue, 7 Apr 2026 16:29:20 -0700 Subject: [PATCH 40/42] [CI] Migrate mgsm_en eval to gsm8k to remove openaipublic dependency (#21931) Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Kangyan-Zhou --- .../amd/accuracy/mi30x/test_gsm8k_eval_amd.py | 53 ++++++++++--------- .../distributed/test_dp_attention_large.py | 4 +- 
.../distributed/test_pp_single_node.py | 4 +- .../eval/test_text_models_gsm8k_eval.py | 43 +++++++-------- ...test_piecewise_cuda_graph_support_1_gpu.py | 32 +++++------ test/registered/quant/test_quantization.py | 13 +++-- .../scheduler/test_prefill_delayer.py | 10 ++-- 7 files changed, 82 insertions(+), 77 deletions(-) diff --git a/test/registered/amd/accuracy/mi30x/test_gsm8k_eval_amd.py b/test/registered/amd/accuracy/mi30x/test_gsm8k_eval_amd.py index aa7813ee543a..9a37ed6d5315 100644 --- a/test/registered/amd/accuracy/mi30x/test_gsm8k_eval_amd.py +++ b/test/registered/amd/accuracy/mi30x/test_gsm8k_eval_amd.py @@ -1,7 +1,7 @@ """ AMD GSM8K Evaluation Test (Migrated from test/srt/nightly/) -This test evaluates instruction-tuned models on the mgsm_en benchmark using chat completions. +This test evaluates instruction-tuned models on the gsm8k benchmark using chat completions. Models are tested with various TP configurations on AMD GPUs. Registry: nightly-amd suite (2-GPU tests) @@ -35,34 +35,35 @@ register_amd_ci(est_time=3600, suite="nightly-amd", nightly=True) MODEL_SCORE_THRESHOLDS = { + # Thresholds set at 5% below reported GSM8K (5-shot/CoT) scores # Llama 3.1 series - "meta-llama/Llama-3.1-8B-Instruct": 0.82, - "meta-llama/Llama-3.1-70B-Instruct": 0.95, + "meta-llama/Llama-3.1-8B-Instruct": 0.80, # 84.5% - 5% + "meta-llama/Llama-3.1-70B-Instruct": 0.89, # 94.1% - 5% # Llama 3.2 series (smaller models) - "meta-llama/Llama-3.2-3B-Instruct": 0.55, + "meta-llama/Llama-3.2-3B-Instruct": 0.43, # 48.2% - 5% # Mistral series - "mistralai/Mistral-7B-Instruct-v0.3": 0.55, - "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.58, + "mistralai/Mistral-7B-Instruct-v0.3": 0.47, # 52.1% - 5% + "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.69, # 74.4% - 5% (lower if AMD scores differently) # DeepSeek series - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.85, + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.81, # 86.4% - 5% # Qwen2 series - "Qwen/Qwen2-57B-A14B-Instruct": 0.86, - 
"Qwen/Qwen2.5-7B-Instruct": 0.85, + "Qwen/Qwen2-57B-A14B-Instruct": 0.76, # 80.7% - 5% (official A14B score; 88.2% was the 72B) + "Qwen/Qwen2.5-7B-Instruct": 0.82, # 86.3% - 5% # Qwen3 series - "Qwen/Qwen3-30B-A3B-Thinking-2507": 0.84, # MoE model verified on MI300X - "Qwen/Qwen3-8B": 0.77, + "Qwen/Qwen3-30B-A3B-Thinking-2507": 0.86, # 91.4% - 5% (full attention mode; ensure sufficient max_tokens) + "Qwen/Qwen3-8B": 0.76, # ~81% - 5% # Google Gemma - "google/gemma-2-27b-it": 0.91, - "google/gemma-2-9b-it": 0.72, + "google/gemma-2-27b-it": 0.86, # 90.7% - 5% + "google/gemma-2-9b-it": 0.74, # 78.5% - 5% # "neuralmagic/gemma-2-2b-it-FP8": 0.4, # Small 2B model - OOM on single GPU # FP8 quantized models - "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.8, - "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.54, - "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.94, - "neuralmagic/Qwen2-72B-Instruct-FP8": 0.92, - "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.81, - "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.57, - "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.84, + "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.80, # 84.5% - 5% + "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.46, # ~51% - 5% + "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.89, # 94.1% - 5% + "neuralmagic/Qwen2-72B-Instruct-FP8": 0.86, # 91.1% - 5% + "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.76, # 80.7% - 5% (official A14B score) + "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.69, # 74.4% - 5% + "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.81, # 86.4% - 5% } failing_models = { @@ -185,7 +186,7 @@ def check_model_scores(results): summary += line print(f"\n{'='*60}") - print("SUMMARY - TP=2 Instruction Models (mgsm_en)") + print("SUMMARY - TP=2 Instruction Models (gsm8k)") print(f"{'='*60}") print(summary) print(f"\nšŸ“Š Final Statistics:") @@ -200,7 +201,7 @@ def check_model_scores(results): raise AssertionError(f"The following models failed:\n{failure_msg}") -# Do not use 
`CustomTestCase` since `test_mgsm_en_all_models` does not want retry +# Do not use `CustomTestCase` since `test_gsm8k_all_models` does not want retry class TestNightlyGsm8KEval(unittest.TestCase): @classmethod def setUpClass(cls): @@ -215,7 +216,7 @@ def setUpClass(cls): ] cls.base_url = DEFAULT_URL_FOR_TEST - def test_mgsm_en_all_models(self): + def test_gsm8k_all_models(self): warnings.filterwarnings( "ignore", category=ResourceWarning, message="unclosed.*socket" ) @@ -226,7 +227,7 @@ def test_mgsm_en_all_models(self): print(f"\n{'='*60}") print("AMD GSM8K Evaluation Test (TP=2 Instruction Models)") print(f"{'='*60}") - print(f"Benchmark: mgsm_en (chat completions)") + print(f"Benchmark: gsm8k (chat completions)") print(f"{'='*60}\n") for model_group, is_fp8, is_tp2 in self.model_groups: @@ -261,13 +262,13 @@ def test_mgsm_en_all_models(self): args = SimpleNamespace( base_url=self.base_url, model=model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) # Run eval with timing and retries - print(f"šŸ“Š Running mgsm_en evaluation...") + print(f"šŸ“Š Running gsm8k evaluation...") eval_start = time.time() threshold = MODEL_SCORE_THRESHOLDS.get(model) metrics = None diff --git a/test/registered/distributed/test_dp_attention_large.py b/test/registered/distributed/test_dp_attention_large.py index 48cdee862f8a..3e1d65f747e3 100644 --- a/test/registered/distributed/test_dp_attention_large.py +++ b/test/registered/distributed/test_dp_attention_large.py @@ -56,11 +56,11 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_en(self): + def test_gsm8k(self): args = SimpleNamespace( base_url=self.base_url, model=self.model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) diff --git a/test/registered/distributed/test_pp_single_node.py b/test/registered/distributed/test_pp_single_node.py index 76e1c068d7f1..0dd5d4fe8277 100644 --- 
a/test/registered/distributed/test_pp_single_node.py +++ b/test/registered/distributed/test_pp_single_node.py @@ -128,11 +128,11 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_en(self): + def test_gsm8k(self): args = SimpleNamespace( base_url=self.base_url, model=self.model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) diff --git a/test/registered/eval/test_text_models_gsm8k_eval.py b/test/registered/eval/test_text_models_gsm8k_eval.py index 9436895422b7..c2974439c797 100644 --- a/test/registered/eval/test_text_models_gsm8k_eval.py +++ b/test/registered/eval/test_text_models_gsm8k_eval.py @@ -26,28 +26,29 @@ register_cuda_ci(est_time=3600, suite="nightly-eval-text-2-gpu", nightly=True) MODEL_SCORE_THRESHOLDS = { - "meta-llama/Llama-3.1-8B-Instruct": 0.82, - "mistralai/Mistral-7B-Instruct-v0.3": 0.58, - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.85, - "google/gemma-2-27b-it": 0.91, - "meta-llama/Llama-3.1-70B-Instruct": 0.95, - "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.616, - "Qwen/Qwen2-57B-A14B-Instruct": 0.86, - "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.83, - "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.54, - "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.835, - "zai-org/GLM-4.5-Air-FP8": 0.75, - # The threshold of neuralmagic/gemma-2-2b-it-FP8 should be 0.6, but this model has some accuracy regression. - # The fix is tracked at https://github.com/sgl-project/sglang/issues/4324, we set it to 0.50, for now, to make CI green. 
- "neuralmagic/gemma-2-2b-it-FP8": 0.50, - "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.94, - "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.65, - "neuralmagic/Qwen2-72B-Instruct-FP8": 0.94, - "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.82, + # Thresholds set at 5% below reported GSM8K (5-shot/CoT) scores + "meta-llama/Llama-3.1-8B-Instruct": 0.80, # 84.5% - 5% + "mistralai/Mistral-7B-Instruct-v0.3": 0.47, # 52.1% - 5% + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.81, # 86.4% - 5% + "google/gemma-2-27b-it": 0.86, # 90.7% - 5% + "meta-llama/Llama-3.1-70B-Instruct": 0.89, # 94.1% - 5% + "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.69, # 74.4% - 5% + "Qwen/Qwen2-57B-A14B-Instruct": 0.76, # 80.7% - 5% (official A14B score; 88.2% was the 72B) + "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.80, # 84.5% - 5% + "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.47, # 52.1% - 5% + "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.81, # 86.4% - 5% + "zai-org/GLM-4.5-Air-FP8": 0.80, # ~85% - 5% + # GSM8K baseline for gemma-2-2b is ~40-45%; threshold set at 5% below. 
+ # (Previously 0.50 based on MGSM-EN; tracked regression: https://github.com/sgl-project/sglang/issues/4324) + "neuralmagic/gemma-2-2b-it-FP8": 0.38, # ~43% - 5% + "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.89, # 94.1% - 5% + "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.69, # 74.4% - 5% + "neuralmagic/Qwen2-72B-Instruct-FP8": 0.86, # 91.1% - 5% + "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.76, # 80.7% - 5% (official A14B score) } -# Do not use `CustomTestCase` since `test_mgsm_en_all_models` does not want retry +# Do not use `CustomTestCase` since `test_gsm8k_all_models` does not want retry class TestNightlyGsm8KEval(unittest.TestCase): @classmethod def setUpClass(cls): @@ -66,7 +67,7 @@ def setUpClass(cls): cls.base_url = DEFAULT_URL_FOR_TEST - def test_mgsm_en_all_models(self): + def test_gsm8k_all_models(self): warnings.filterwarnings( "ignore", category=ResourceWarning, message="unclosed.*socket" ) @@ -91,7 +92,7 @@ def test_mgsm_en_all_models(self): args = SimpleNamespace( base_url=self.base_url, model=model_setup.model_path, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) diff --git a/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py b/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py index e38b59f5b86b..ce6fe2291828 100644 --- a/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py +++ b/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py @@ -41,21 +41,19 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_accuracy(self): - num_examples = 2000 - + def test_gsm8k_accuracy(self): args = SimpleNamespace( base_url=self.base_url, model=self.model, - eval_name="mgsm_en", - num_examples=num_examples, - num_threads=min(num_examples, 1024), + eval_name="gsm8k", + num_examples=None, + num_threads=1024, ) metrics = run_eval(args) - print(f"MGSM Accuracy: 
{metrics['score']:.3f}") + print(f"GSM8K Accuracy: {metrics['score']:.3f}") - self.assertGreaterEqual(metrics["score"], 0.70) + self.assertGreaterEqual(metrics["score"], 0.82) class TestPiecewiseCudaGraphInternVL25(CustomTestCase): @@ -79,21 +77,23 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_accuracy(self): - num_examples = 2000 - + def test_gsm8k_accuracy(self): args = SimpleNamespace( base_url=self.base_url, model=self.model, - eval_name="mgsm_en", - num_examples=num_examples, - num_threads=min(num_examples, 1024), + eval_name="gsm8k", + num_examples=None, + num_threads=1024, ) metrics = run_eval(args) - print(f"MGSM Accuracy: {metrics['score']:.3f}") + print(f"GSM8K Accuracy: {metrics['score']:.3f}") - self.assertGreaterEqual(metrics["score"], 0.70) + # Baseline (no piecewise CUDA graph): 0.571 — this eval uses 5-shot + # concatenated text via chat API, which scores lower than reported + # benchmarks (~77.8%) that use proper CoT chat format. The threshold + # is set 5% below observed to catch catastrophic regressions. + self.assertGreaterEqual(metrics["score"], 0.54) class TestPiecewiseCudaGraphQwen25VLEmbedding(CustomTestCase): diff --git a/test/registered/quant/test_quantization.py b/test/registered/quant/test_quantization.py index cdb1f0970ef4..cdf0b0e619bb 100644 --- a/test/registered/quant/test_quantization.py +++ b/test/registered/quant/test_quantization.py @@ -19,9 +19,12 @@ register_cuda_ci(est_time=370, suite="stage-b-test-1-gpu-large") MODEL_SCORE_THRESHOLDS = { - "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.825, - "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.825, - "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.615, + # Baselines observed with gsm8k 5-shot concatenated format via chat API, + # which scores lower than reported benchmarks using proper CoT format. + # Thresholds set 5% below observed to catch catastrophic regressions. 
+ "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.74, # observed: 0.781 + "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.74, # observed: 0.785 + "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.36, # observed: 0.380 } @@ -93,7 +96,7 @@ def setUpClass(cls): ] cls.base_url = DEFAULT_URL_FOR_TEST - def test_mgsm_en_all_models(self): + def test_gsm8k_all_models(self): warnings.filterwarnings( "ignore", category=ResourceWarning, message="unclosed.*socket" ) @@ -110,7 +113,7 @@ def test_mgsm_en_all_models(self): args = SimpleNamespace( base_url=self.base_url, model=model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) diff --git a/test/registered/scheduler/test_prefill_delayer.py b/test/registered/scheduler/test_prefill_delayer.py index 66ea497bf3c0..493346fda930 100644 --- a/test/registered/scheduler/test_prefill_delayer.py +++ b/test/registered/scheduler/test_prefill_delayer.py @@ -428,10 +428,10 @@ async def send_normal_request(dp_rank, req_idx): class TestPrefillDelayerAccuracy(CustomTestCase): - def test_1_mgsm_en_has_prefill_delayer(self): + def test_1_gsm8k_has_prefill_delayer(self): self._run_accuracy_test(prefill_delayer=True) - def test_2_mgsm_en_no_prefill_delayer(self): + def test_2_gsm8k_no_prefill_delayer(self): self._run_accuracy_test(prefill_delayer=False) def _run_accuracy_test(self, prefill_delayer: bool): @@ -454,14 +454,14 @@ def _run_accuracy_test(self, prefill_delayer: bool): args = SimpleNamespace( base_url=base_url, model=model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) metrics = run_eval(args) - print(f"=== mgsm_en ({prefill_delayer=}) ===") + print(f"=== gsm8k ({prefill_delayer=}) ===") print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.87) + self.assertGreater(metrics["score"], 0.57) finally: kill_process_tree(process.pid) From dd73e9a62ea65c7272dae4cf8a6b8e994814a80d Mon Sep 17 00:00:00 2001 From: Kangyan-Zhou Date: Tue, 7 Apr 2026 
17:04:06 -0700 Subject: [PATCH 41/42] Revert "[CI] Update nightly test models for H200/B200 (#22288)" (#22297) --- .../8-gpu-models/test_deepseek_v31.py | 5 +- test/registered/8-gpu-models/test_glm_46.py | 52 +++++++++++++++++++ .../8-gpu-models/test_glm_46_fp8.py | 5 +- test/registered/8-gpu-models/test_qwen35.py | 13 ++--- .../8-gpu-models/test_qwen3_235b.py | 5 +- 5 files changed, 67 insertions(+), 13 deletions(-) create mode 100644 test/registered/8-gpu-models/test_glm_46.py diff --git a/test/registered/8-gpu-models/test_deepseek_v31.py b/test/registered/8-gpu-models/test_deepseek_v31.py index bacca09f5136..b6cc7e5c471e 100644 --- a/test/registered/8-gpu-models/test_deepseek_v31.py +++ b/test/registered/8-gpu-models/test_deepseek_v31.py @@ -1,11 +1,14 @@ import unittest from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.performance_test_runner import PerformanceTestParams from sglang.test.run_combined_tests import run_combined_tests from sglang.test.test_utils import ModelLaunchSettings -# Manual-only: not registered in any CI suite +# Runs on both H200 and B200 via nightly-8-gpu-common suite +register_cuda_ci(est_time=5400, suite="nightly-8-gpu-common", nightly=True) + DEEPSEEK_V31_MODEL_PATH = "deepseek-ai/DeepSeek-V3.1" diff --git a/test/registered/8-gpu-models/test_glm_46.py b/test/registered/8-gpu-models/test_glm_46.py new file mode 100644 index 000000000000..dc22744b46a5 --- /dev/null +++ b/test/registered/8-gpu-models/test_glm_46.py @@ -0,0 +1,52 @@ +import unittest + +from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.performance_test_runner import PerformanceTestParams +from sglang.test.run_combined_tests import run_combined_tests +from sglang.test.test_utils import ModelLaunchSettings + +# Runs on both H200 and B200 via nightly-8-gpu-common suite +register_cuda_ci(est_time=1800, 
suite="nightly-8-gpu-common", nightly=True) + +GLM_4_6_MODEL_PATH = "zai-org/GLM-4.6" + + +class TestGLM46(unittest.TestCase): + """Unified test class for GLM-4.6 performance and accuracy. + + Single variant with simple TP=8 configuration. + GLM-4.6 is a 357B MoE model. + Runs BOTH: + - Performance test (using NightlyBenchmarkRunner) + - Accuracy test (using run_eval with gsm8k) + """ + + def test_glm_46(self): + """Run performance and accuracy for GLM-4.6.""" + base_args = [ + "--tp=8", + "--trust-remote-code", + ] + + variants = [ + ModelLaunchSettings( + GLM_4_6_MODEL_PATH, + tp_size=8, + extra_args=base_args, + variant="TP8", + ), + ] + + run_combined_tests( + models=variants, + test_name="GLM-4.6", + accuracy_params=AccuracyTestParams(dataset="gsm8k", baseline_accuracy=0.80), + performance_params=PerformanceTestParams( + profile_dir="performance_profiles_glm_4_6", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/8-gpu-models/test_glm_46_fp8.py b/test/registered/8-gpu-models/test_glm_46_fp8.py index fd564fbb8a70..435763e01447 100644 --- a/test/registered/8-gpu-models/test_glm_46_fp8.py +++ b/test/registered/8-gpu-models/test_glm_46_fp8.py @@ -1,11 +1,14 @@ import unittest from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.performance_test_runner import PerformanceTestParams from sglang.test.run_combined_tests import run_combined_tests from sglang.test.test_utils import ModelLaunchSettings -# Manual-only: not registered in any CI suite +# Runs on both H200 and B200 via nightly-8-gpu-common suite +register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) + GLM_4_6_FP8_MODEL_PATH = "zai-org/GLM-4.6-FP8" diff --git a/test/registered/8-gpu-models/test_qwen35.py b/test/registered/8-gpu-models/test_qwen35.py index 813552b83421..7386b4a8cf46 100644 --- a/test/registered/8-gpu-models/test_qwen35.py +++ 
b/test/registered/8-gpu-models/test_qwen35.py @@ -9,7 +9,7 @@ # Runs on both H200 and B200 via nightly-8-gpu-common suite register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) -QWEN35_MODEL_PATH = "Qwen/Qwen3.5-397B-A17B-FP8" +QWEN35_MODEL_PATH = "Qwen/Qwen3.5-397B-A17B" class TestQwen35(unittest.TestCase): @@ -30,7 +30,6 @@ def test_qwen35(self): "--tool-call-parser=qwen3_coder", "--mem-fraction-static=0.8", ] - dp_args = ["--dp=8", "--enable-dp-attention"] mtp_args = [ "--speculative-algorithm=EAGLE", "--speculative-num-steps=3", @@ -49,14 +48,8 @@ def test_qwen35(self): ModelLaunchSettings( QWEN35_MODEL_PATH, tp_size=8, - extra_args=base_args + dp_args, - variant="TP8+DP8", - ), - ModelLaunchSettings( - QWEN35_MODEL_PATH, - tp_size=8, - extra_args=base_args + dp_args + mtp_args, - variant="TP8+DP8+MTP", + extra_args=base_args + mtp_args, + variant="TP8+MTP", env={"SGLANG_ENABLE_SPEC_V2": "1"}, ), ] diff --git a/test/registered/8-gpu-models/test_qwen3_235b.py b/test/registered/8-gpu-models/test_qwen3_235b.py index b3f521bcd752..70420bbed64a 100644 --- a/test/registered/8-gpu-models/test_qwen3_235b.py +++ b/test/registered/8-gpu-models/test_qwen3_235b.py @@ -1,11 +1,14 @@ import unittest from sglang.test.accuracy_test_runner import AccuracyTestParams +from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.performance_test_runner import PerformanceTestParams from sglang.test.run_combined_tests import run_combined_tests from sglang.test.test_utils import ModelLaunchSettings, is_blackwell_system -# Manual-only: not registered in any CI suite +# Runs on both H200 and B200 via nightly-8-gpu-common suite +register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) + QWEN3_235B_FP8_MODEL_PATH = "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8" QWEN3_235B_EAGLE3_MODEL_PATH = ( "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge-Meituan" From 8c3d80eabec2d2a41ccff4d15b5fc81e59d24fbe Mon Sep 17 00:00:00 2001 From: 
Liangsheng Yin Date: Tue, 7 Apr 2026 18:07:28 -0700 Subject: [PATCH 42/42] Only upload CUDA coredumps on test failure (#22301) --- .github/workflows/nightly-test-nvidia.yml | 32 ++++++++++---------- .github/workflows/pr-test-multimodal-gen.yml | 6 ++-- .github/workflows/pr-test.yml | 24 +++++++-------- .github/workflows/rerun-test.yml | 2 +- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/.github/workflows/nightly-test-nvidia.yml b/.github/workflows/nightly-test-nvidia.yml index f3b33b8cbc9f..f523f4c8dbbe 100644 --- a/.github/workflows/nightly-test-nvidia.yml +++ b/.github/workflows/nightly-test-nvidia.yml @@ -76,7 +76,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-1-gpu --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # JIT kernel full unit tests (expanded parameter ranges via SGLANG_JIT_KERNEL_RUN_FULL_TESTS) nightly-test-kernel-1-gpu-h100: @@ -110,7 +110,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-kernel-1-gpu --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() nightly-test-kernel-8-gpu-h200: if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-kernel-8-gpu-h200') @@ -140,7 +140,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-kernel-8-gpu-h200 --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # General tests - 4 GPU H100 nightly-test-general-4-gpu-h100: @@ -165,7 +165,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-4-gpu --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # General tests - 8 GPU H200 nightly-test-general-8-gpu-h200: @@ -249,7 +249,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: 
${{ matrix.partition }} @@ -280,7 +280,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-8-gpu-h20 --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # General tests - 8 GPU B200 nightly-test-general-8-gpu-b200: @@ -353,7 +353,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -380,7 +380,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-eval-text-2-gpu --nightly --continue-on-error --timeout-per-file 4500 - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # Text model performance tests nightly-test-text-perf-2-gpu-h100: @@ -418,7 +418,7 @@ jobs: python3 scripts/ci/utils/publish_traces.py --traces-dir test/performance_profiles_text_models - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # VLM accuracy tests nightly-test-vlm-accuracy-2-gpu-h100: @@ -443,7 +443,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-eval-vlm-2-gpu --nightly --continue-on-error --timeout-per-file 9000 - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # VLM performance tests nightly-test-vlm-perf-2-gpu-h100: @@ -481,7 +481,7 @@ jobs: python3 scripts/ci/utils/publish_traces.py --traces-dir test/performance_profiles_vlms - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # diffusion performance tests nightly-test-multimodal-server-1-gpu: @@ -538,7 +538,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -596,7 +596,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -623,7 +623,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-4-gpu-b200 --nightly 
--continue-on-error --timeout-per-file 12000 - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # Specialized B200 tests - 8 GPU, for specific backends and configs nightly-test-specialized-8-gpu-b200: @@ -652,7 +652,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-8-gpu-b200 --nightly --continue-on-error --timeout-per-file 2400 - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # Diffusion cross-framework comparison nightly-test-diffusion-comparison: @@ -716,7 +716,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # Consolidate performance metrics from all jobs consolidate-metrics: diff --git a/.github/workflows/pr-test-multimodal-gen.yml b/.github/workflows/pr-test-multimodal-gen.yml index 4b560ed5da39..1fd8ed24eb0e 100644 --- a/.github/workflows/pr-test-multimodal-gen.yml +++ b/.github/workflows/pr-test-multimodal-gen.yml @@ -100,7 +100,7 @@ jobs: $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -155,7 +155,7 @@ jobs: $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -204,7 +204,7 @@ jobs: $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() multimodal-gen-unit-test: if: | diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 4be332e5928e..ff64a9c3d4d2 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -602,7 +602,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-a-test-1-gpu-small $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() stage-a-test-cpu: needs: [check-changes, call-gate] @@ -711,7 +711,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-b-test-1-gpu-small 
--auto-partition-id ${{ matrix.partition }} --auto-partition-size 8 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -767,7 +767,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-b-test-1-gpu-large --auto-partition-id ${{ matrix.partition }} --auto-partition-size 14 --timeout-per-file 1800 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -822,7 +822,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-b-test-2-gpu-large --auto-partition-id ${{ matrix.partition }} --auto-partition-size 4 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -880,7 +880,7 @@ jobs: python3 -m pytest -q python/sglang/jit_kernel/tests/test_flash_attention_4.py - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() call-multimodal-gen-tests: needs: [check-changes, call-gate, sgl-kernel-build-wheels] @@ -962,7 +962,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-h100 --auto-partition-id ${{ matrix.part }} --auto-partition-size 3 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -1030,7 +1030,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-8-gpu-h200 --auto-partition-id ${{ matrix.part }} --auto-partition-size 4 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -1086,7 +1086,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-8-gpu-h20 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: 
artifact-suffix: ${{ matrix.part }} @@ -1148,7 +1148,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-deepep-4-gpu-h100 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() stage-c-test-deepep-8-gpu-h200: needs: [check-changes, call-gate, wait-for-stage-b] @@ -1209,7 +1209,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-deepep-8-gpu-h200 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() stage-c-test-4-gpu-b200: needs: [check-changes, call-gate, wait-for-stage-b] @@ -1262,7 +1262,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-b200 --auto-partition-id ${{ matrix.part }} --auto-partition-size 4 --timeout-per-file 1800 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -1316,7 +1316,7 @@ jobs: # python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-gb200 --timeout-per-file 3600 $CONTINUE_ON_ERROR_FLAG # # - uses: ./.github/actions/upload-cuda-coredumps - # if: always() + # if: failure() pr-test-finish: needs: diff --git a/.github/workflows/rerun-test.yml b/.github/workflows/rerun-test.yml index 431b69474c1c..e5c8c98c5ba4 100644 --- a/.github/workflows/rerun-test.yml +++ b/.github/workflows/rerun-test.yml @@ -111,7 +111,7 @@ jobs: echo "All $total test(s) passed in ${total_elapsed}s" - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() rerun-test-cpu: if: inputs.is_cpu == 'true'