diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py
index 82fbb76828fe..78ed8099d893 100644
--- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py
+++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py
@@ -434,8 +434,7 @@ def apply_qk_norm(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Apply QK normalization for query and key tensors.
 
-    Minimal multimodal_gen-only implementation: only the JIT fused inplace
-    QK-norm kernel path is supported (no fallback).
+    Uses the JIT fused inplace kernel when available; falls back to standard RMSNorm.
     """
     batch_size = q.size(0)
 
@@ -458,7 +457,15 @@
         )
         return q, k
 
-    raise RuntimeError(
-        "apply_qk_norm: fused inplace QK-norm is not applicable "
-        "(expected CUDA, contiguous q/k, matching eps, and supported head_dim)"
-    )
+    # Fallback for AMD/ROCm: apply RMSNorm separately to q and k
+    import warnings
+
+    warnings.warn(
+        "Fused QK-norm not available, using RMSNorm fallback",
+        stacklevel=2,
+    )
+    q_shape = q.shape
+    k_shape = k.shape
+    q_out = q_norm(q.reshape(-1, head_dim)).view(q_shape)
+    k_out = k_norm(k.reshape(-1, head_dim)).view(k_shape)
+    return q_out, k_out
diff --git a/python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py b/python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py
index ac5e8ed0e091..2ef943229c6b 100644
--- a/python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py
+++ b/python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py
@@ -69,11 +69,29 @@
 
     try:
         from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
-    except Exception as e:
-        raise RuntimeError(
-            "flashinfer is required for apply_flashinfer_rope_qk_inplace. "
-            "Please install flashinfer or disable this optimization."
-        ) from e
+    except ImportError:
+        # Triton fallback for AMD/ROCm, where FlashInfer is not available
+        import warnings
+
+        warnings.warn(
+            "FlashInfer not available, using Triton fallback for RoPE",
+            stacklevel=2,
+        )
+        half_size = cos_sin_cache.shape[-1] // 2
+        if positions is None:
+            cos = cos_sin_cache[:seqlen, :half_size].to(q.dtype)
+            sin = cos_sin_cache[:seqlen, half_size:].to(q.dtype)
+            cos = cos.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1)
+            sin = sin.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1)
+        else:
+            positions = positions.to(cos_sin_cache.device).view(-1)
+            cos = cos_sin_cache[positions, :half_size].to(q.dtype)
+            sin = cos_sin_cache[positions, half_size:].to(q.dtype)
+        q_flat = q.reshape(bsz * seqlen, nheads, d)
+        k_flat = k.reshape(bsz * seqlen, nheads, d)
+        q_rot = apply_rotary_embedding(q_flat, cos, sin, interleaved=not is_neox)
+        k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox)
+        return q_rot.view(bsz, seqlen, nheads, d), k_rot.view(bsz, seqlen, nheads, d)
 
     if positions is None:
         pos_1d = torch.arange(seqlen, device="cpu", dtype=torch.long)
diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/clip.py b/python/sglang/multimodal_gen/runtime/models/encoders/clip.py
index 99db53a75ad4..9dd279d8fc7d 100644
--- a/python/sglang/multimodal_gen/runtime/models/encoders/clip.py
+++ b/python/sglang/multimodal_gen/runtime/models/encoders/clip.py
@@ -33,7 +33,10 @@
 from sglang.multimodal_gen.runtime.models.encoders.vision import (
     resolve_visual_encoder_outputs,
 )
-from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
+from sglang.multimodal_gen.runtime.platforms import (
+    AttentionBackendEnum,
+    current_platform,
+)
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
 logger = init_logger(__name__)
@@ -227,26 +230,41 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        if attention_mask is not None:
-            # SDPA requires [B, 1, 1, S] or [B, S, S] format mask
-            if attention_mask.dim() == 2:
-                attn_mask = attention_mask[:, None, None, :].to(
-                    dtype=query_states.dtype
-                )
-                attn_mask = (1.0 - attn_mask) * torch.finfo(query_states.dtype).min
-            else:
-                attn_mask = attention_mask
+        if current_platform.is_rocm():
+            # ROCm: passing both is_causal=True and attn_mask causes NaNs.
+            # Use is_causal=True alone (a padding mask is not needed for CLIP,
+            # since pooler_output comes from the EOS token before any padding).
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=None,
+                is_causal=True,
+                scale=self.scale,
+            )
         else:
-            attn_mask = None
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attn_mask,
-            is_causal=True,
-            scale=self.scale,
-        )
+            if attention_mask is not None:
+                # SDPA requires [B, 1, 1, S] or [B, S, S] format mask
+                if attention_mask.dim() == 2:
+                    attn_mask = attention_mask[:, None, None, :].to(
+                        dtype=query_states.dtype
+                    )
+                    attn_mask = (1.0 - attn_mask) * torch.finfo(
+                        query_states.dtype
+                    ).min
+                else:
+                    attn_mask = attention_mask
+            else:
+                attn_mask = None
+
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attn_mask,
+                is_causal=True,
+                scale=self.scale,
+            )
         attn_output = attn_output.transpose(1, 2)
     else:
         # Use LocalAttention (doesn't support attention_mask, but maintains compatibility)
diff --git a/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py b/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py
index 63e6d6e730b1..a4b0dc25f133 100644
--- a/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py
+++ b/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py
@@ -41,6 +41,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from sglang.multimodal_gen.runtime.loader.weight_utils import get_lock
+from sglang.multimodal_gen.runtime.platforms import current_platform
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
 logger = init_logger(__name__)
@@ -230,6 +231,12 @@ def maybe_download_lora(
         return local_path
 
     weight_name = _best_guess_weight_name(local_path, file_extension=".safetensors")
+    # AMD workaround: PR 15813 changed this lookup from model_name_or_path to
+    # local_path, where it can return None; fall back to the original lookup on ROCm.
+    if weight_name is None and current_platform.is_rocm():
+        weight_name = _best_guess_weight_name(
+            model_name_or_path, file_extension=".safetensors"
+        )
 
     return os.path.join(local_path, weight_name)
 
diff --git a/scripts/ci/amd_ci_install_dependency.sh b/scripts/ci/amd_ci_install_dependency.sh
index f5c11bc13fca..f8c1d5fc138f 100755
--- a/scripts/ci/amd_ci_install_dependency.sh
+++ b/scripts/ci/amd_ci_install_dependency.sh
@@ -92,6 +92,13 @@
 docker cp ./dummy-grok ci_sglang:/
 docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache huggingface_hub[hf_xet]
 docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache pytest
+# Install tvm-ffi for JIT kernel support (QK-norm, etc.)
+echo "Installing tvm-ffi for JIT kernel support..."
+docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache git+https://github.com/apache/tvm-ffi.git || echo "tvm-ffi installation failed, JIT kernels will use fallback"
+
+# Install cache-dit for the qwen_image_t2i_cache_dit_enabled test (added in PR 16204)
+docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache cache-dit || echo "cache-dit installation failed"
+
 # Detect AITER version
 #############################################
 # Detect correct AITER_COMMIT for this runner
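
The RMSNorm fallback in layernorm.py flattens all leading dimensions so the norm sees a [N, head_dim] tensor, then restores the original shape. A minimal standalone sketch of that pattern, assuming a plain RMSNorm over the last axis (rms_norm below is a hypothetical stand-in for the module's q_norm/k_norm layers, not the actual implementation):

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last axis (head_dim), then apply the learned scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


bsz, seqlen, nheads, head_dim = 2, 4, 8, 64
q = torch.randn(bsz, seqlen, nheads, head_dim)
weight = torch.ones(head_dim)

# Flatten to [N, head_dim], normalize, restore the shape. reshape (rather
# than view) also covers non-contiguous inputs, which is one of the reasons
# the fused kernel path can be skipped in the first place.
q_out = rms_norm(q.reshape(-1, head_dim), weight).view(q.shape)
assert q_out.shape == q.shape
```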
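
The RoPE fallback in rotary_embedding.py assumes cos_sin_cache packs cos values in the first half and sin values in the second half of its last dimension, indexed by position. A sketch of that layout together with a pure-PyTorch NeoX-style rotation; apply_rotary_neox is a hypothetical reference for what the Triton apply_rotary_embedding kernel computes in the non-interleaved case, not its actual implementation:

```python
import torch


def apply_rotary_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [tokens, heads, head_dim]; cos/sin: [tokens, head_dim // 2].
    # NeoX style rotates the two halves of head_dim as (x1, x2) pairs.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    c = cos.unsqueeze(1)  # broadcast over the heads axis
    s = sin.unsqueeze(1)
    return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)


max_pos, head_dim = 1024, 128
half = head_dim // 2
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half).float() / half))
t = torch.arange(max_pos).float()
freqs = torch.outer(t, inv_freq)
# Assumed cache layout: [cos | sin] along the last dimension.
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # [max_pos, head_dim]

positions = torch.tensor([0, 1, 5, 7])
cos = cos_sin_cache[positions, :half]
sin = cos_sin_cache[positions, half:]

q = torch.randn(4, 8, head_dim)  # [tokens, heads, head_dim]
q_rot = apply_rotary_neox(q, cos, sin)
assert q_rot.shape == q.shape
```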
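
The non-ROCm SDPA branch in clip.py converts a 2-D 0/1 padding mask into a 4-D additive float mask: kept positions become 0 and padded positions become the dtype's most negative value, so they contribute nothing after softmax. A small worked example of that conversion:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # [B, S], 1 = keep
dtype = torch.float16

# [B, S] -> [B, 1, 1, S], then 1 -> 0.0 (keep) and 0 -> finfo.min (mask out).
attn_mask = attention_mask[:, None, None, :].to(dtype=dtype)
attn_mask = (1.0 - attn_mask) * torch.finfo(dtype).min

assert attn_mask.shape == (2, 1, 1, 4)
assert attn_mask[0, 0, 0, -1] == torch.finfo(dtype).min  # padded slot masked out
```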