Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions python/sglang/multimodal_gen/runtime/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,7 @@ def apply_qk_norm(
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply QK normalization for query and key tensors.

Minimal multimodal_gen-only implementation: only the JIT fused inplace
QK-norm kernel path is supported (no fallback).
Uses JIT fused inplace kernel when available, falls back to standard RMSNorm.
"""

batch_size = q.size(0)
Expand All @@ -458,7 +457,15 @@ def apply_qk_norm(
)
return q, k

raise RuntimeError(
"apply_qk_norm: fused inplace QK-norm is not applicable "
"(expected CUDA, contiguous q/k, matching eps, and supported head_dim)"
# Fallback for AMD/ROCm: apply RMSNorm separately to q and k
import warnings

warnings.warn(
"Fused QK-norm not available, using RMSNorm fallback",
stacklevel=2,
)
q_shape = q.shape
k_shape = k.shape
q_out = q_norm(q.view(-1, head_dim)).view(q_shape)
k_out = k_norm(k.view(-1, head_dim)).view(k_shape)
return q_out, k_out
28 changes: 23 additions & 5 deletions python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,29 @@ def apply_flashinfer_rope_qk_inplace(

try:
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except Exception as e:
raise RuntimeError(
"flashinfer is required for apply_flashinfer_rope_qk_inplace. "
"Please install flashinfer or disable this optimization."
) from e
except ImportError:
# Triton fallback for AMD/ROCm where FlashInfer is not available
import warnings

warnings.warn(
"FlashInfer not available, using Triton fallback for RoPE",
stacklevel=2,
)
half_size = cos_sin_cache.shape[-1] // 2
if positions is None:
cos = cos_sin_cache[:seqlen, :half_size].to(q.dtype)
sin = cos_sin_cache[:seqlen, half_size:].to(q.dtype)
cos = cos.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1)
sin = sin.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1)
else:
positions = positions.to(cos_sin_cache.device).view(-1)
cos = cos_sin_cache[positions, :half_size].to(q.dtype)
sin = cos_sin_cache[positions, half_size:].to(q.dtype)
q_flat = q.reshape(bsz * seqlen, nheads, d)
k_flat = k.reshape(bsz * seqlen, nheads, d)
q_rot = apply_rotary_embedding(q_flat, cos, sin, interleaved=not is_neox)
k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox)
return q_rot.view(bsz, seqlen, nheads, d), k_rot.view(bsz, seqlen, nheads, d)

if positions is None:
pos_1d = torch.arange(seqlen, device="cpu", dtype=torch.long)
Expand Down
58 changes: 38 additions & 20 deletions python/sglang/multimodal_gen/runtime/models/encoders/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
from sglang.multimodal_gen.runtime.models.encoders.vision import (
resolve_visual_encoder_outputs,
)
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.platforms import (
AttentionBackendEnum,
current_platform,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)
Expand Down Expand Up @@ -227,26 +230,41 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

if attention_mask is not None:
# SDPA requires [B, 1, 1, S] or [B, S, S] format mask
if attention_mask.dim() == 2:
attn_mask = attention_mask[:, None, None, :].to(
dtype=query_states.dtype
)
attn_mask = (1.0 - attn_mask) * torch.finfo(query_states.dtype).min
else:
attn_mask = attention_mask
if current_platform.is_rocm():
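# ROCm: Using both is_causal=True and attn_mask causes NaN.
# Use is_causal=True alone. The padding mask is not needed for CLIP's
# pooler_output: it is taken from the EOS token, which precedes the
# padding, and causal attention never lets a token attend to later positions.
# NOTE(review): hidden states at the padded positions themselves may differ
# from the masked path - confirm no caller relies on them.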
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=None,
is_causal=True,
scale=self.scale,
)
else:
attn_mask = None

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
is_causal=True,
scale=self.scale,
)
if attention_mask is not None:
# SDPA requires [B, 1, 1, S] or [B, S, S] format mask
if attention_mask.dim() == 2:
attn_mask = attention_mask[:, None, None, :].to(
dtype=query_states.dtype
)
attn_mask = (1.0 - attn_mask) * torch.finfo(
query_states.dtype
).min
else:
attn_mask = attention_mask
else:
attn_mask = None

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
is_causal=True,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2)
else:
# Use LocalAttention (doesn't support attention_mask, but maintains compatibility)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from sglang.multimodal_gen.runtime.loader.weight_utils import get_lock
from sglang.multimodal_gen.runtime.platforms import current_platform
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)
Expand Down Expand Up @@ -230,6 +231,12 @@ def maybe_download_lora(
return local_path

weight_name = _best_guess_weight_name(local_path, file_extension=".safetensors")
# AMD workaround: PR 15813 switched this lookup from model_name_or_path to
# local_path, and _best_guess_weight_name can return None for local_path.
# On ROCm, retry with model_name_or_path (the pre-PR behavior).
if weight_name is None and current_platform.is_rocm():
weight_name = _best_guess_weight_name(
model_name_or_path, file_extension=".safetensors"
)
return os.path.join(local_path, weight_name)


Expand Down
7 changes: 7 additions & 0 deletions scripts/ci/amd_ci_install_dependency.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ docker cp ./dummy-grok ci_sglang:/
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache huggingface_hub[hf_xet]
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache pytest

# Install tvm-ffi for JIT kernel support (QK-norm, etc.)
echo "Installing tvm-ffi for JIT kernel support..."
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache git+https://github.com/apache/tvm-ffi.git || echo "tvm-ffi installation failed, JIT kernels will use fallback"

# Install cache-dit for qwen_image_t2i_cache_dit_enabled test (added in PR 16204)
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache cache-dit || echo "cache-dit installation failed"

# Detect AITER version
#############################################
# Detect correct AITER_COMMIT for this runner
Expand Down
Loading