diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index 2e9ec9d8f50..199df024090 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -30,10 +30,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
+from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
@@ -52,7 +54,15 @@
 logger = init_logger(__name__)
 
+
 # === Vision Inputs === #
+class OriginalQuickGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return input * torch.sigmoid(1.702 * input)
 
 
 class Qwen2VLImageInputs(TypedDict):
@@ -91,7 +101,7 @@ def __init__(
         self,
         in_features: int,
         hidden_features: int = None,
-        act_layer: Type[nn.Module] = QuickGELU,
+        act_layer: Type[nn.Module] = OriginalQuickGELUActivation,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -201,21 +211,30 @@ def forward(
         q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
         batch_size = q.shape[1]
 
-        q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
+        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
         if rotary_pos_emb is not None:
             q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
             k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
 
-        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
-        max_seqlen = (seq_lens).max().item()
-        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
-        output = torch.empty_like(q)
-        context_attention_fwd(
-            q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
+        seq_length = q.size(1)
+        q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
+        attention_mask = torch.zeros(
+            [1, seq_length, seq_length], device=q.device, dtype=torch.bool
         )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                ...,
+                cu_seqlens[i - 1] : cu_seqlens[i],
+                cu_seqlens[i - 1] : cu_seqlens[i],
+            ] = True
+
+        q = q.squeeze(0)
+        k = k.squeeze(0)
+        v = v.squeeze(0)
+        output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+        output = output.unsqueeze(0)
+        context_layer = rearrange(output, "b h s d -> b s h d ")
-        context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
         context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
 
         output, _ = self.proj(context_layer)
@@ -229,7 +248,7 @@ def __init__(
         dim: int,
         num_heads: int,
         mlp_ratio: float,
-        act_layer: Type[nn.Module] = QuickGELU,
+        act_layer: Type[nn.Module] = OriginalQuickGELUActivation,
         norm_layer: Type[nn.Module] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -328,38 +347,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class Qwen2VisionRotaryEmbedding(nn.Module):
-
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
-        self.dim = dim
         self.theta = theta
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs
+        self.dim = dim
 
     def forward(self, seqlen: int) -> torch.Tensor:
-        self.update_freqs_cache(seqlen)
-        return self._freqs_cached[:seqlen]
+        inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
+        )
+
+        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
+        freqs = torch.outer(seq, inv_freq)
+        return freqs
 
 
 class Qwen2VisionTransformer(nn.Module):
@@ -450,7 +450,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
@@ -499,7 +499,7 @@ def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         return num_image_tokens
 
     # Use grid_t * grid_w * grid_h to pad tokens for each image
-    # add replaced padding by unique image hash
+    # and replaced padding by unique image hash
    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
         image_grid_thws = image_inputs.image_grid_thws
         pad_values = image_inputs.pad_values
diff --git a/test/srt/test_qwen2vl.py b/test/srt/test_qwen2vl.py
new file mode 100644
index 00000000000..e3406b6fc4b
--- /dev/null
+++ b/test/srt/test_qwen2vl.py
@@ -0,0 +1,159 @@
+import asyncio
+import base64
+import json
+import math
+import os
+import tempfile
+import unittest
+from io import BytesIO
+from pathlib import Path
+from unittest.mock import AsyncMock, patch
+
+import numpy as np
+import requests
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.server_args import PortArgs, ServerArgs
+
+QWEN2_VL_MODEL = "Qwen/Qwen2-VL-7B-Instruct"
+
+
+class RawSGLangTest(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            QWEN2_VL_MODEL, trust_remote_code=True
+        )
+        self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]
+
+        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+            QWEN2_VL_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True
+        ).eval()
+        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+
+    async def test_vision_encoder(self):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            }
+        ]
+
+        # Apply chat template to get the text
+        text = self.processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        response = requests.get(messages[0]["content"][0]["image"])
+        main_image = Image.open(BytesIO(response.content))
+
+        # Process inputs using processor
+        inputs = self.processor(
+            text=[text],
+            images=[main_image],
+            padding=True,
+            return_tensors="pt",
+        )
+
+        with torch.no_grad():
+            hf_output = self.model.visual(
+                inputs["pixel_values"].to(self.device),
+                grid_thw=inputs["image_grid_thw"].to(self.device),
+            )
+
+        model_config = ModelConfig(QWEN2_VL_MODEL, model_override_args="{}")
+
+        server_args = ServerArgs(model_path=QWEN2_VL_MODEL)
+        model_runner = ModelRunner(
+            model_config=model_config,
+            mem_fraction_static=0.8,
+            gpu_id=0,
+            tp_rank=0,
+            tp_size=1,
+            nccl_port=12435,
+            server_args=server_args,
+        )
+
+        with torch.no_grad():
+            sglang_output = model_runner.model.visual(
+                inputs["pixel_values"].to(self.device),
+                grid_thw=inputs["image_grid_thw"].to(self.device),
+            )
+
+        # Convert to float32 for numerical stability if needed
+        hf = hf_output.float()
+        sg = sglang_output.float()
+
+        # Basic shape and dtype comparison
+        print("\n=== Basic Properties ===")
+        print(f"Shapes match: {hf.shape == sg.shape}")
+        print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
+        print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")
+
+        # Move tensors to CPU for numpy operations
+        hf_np = hf.cpu().numpy()
+        sg_np = sg.cpu().numpy()
+
+        # Statistical metrics
+        print("\n=== Statistical Metrics ===")
+        print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
+        print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
+        print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
+        print(
+            f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}"
+        )
+
+        # Cosine similarity (across feature dimension)
+        cos_sim = F.cosine_similarity(hf, sg)
+        print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
+        print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")
+
+        # Find largest absolute differences
+        print("\n=== Largest Absolute Differences ===")
+        diffs = torch.abs(hf - sg)
+        flat_diffs = diffs.flatten()
+
+        # Get indices of top 10 differences
+        top_k = 10
+        top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
+
+        # Convert flat indices to multidimensional indices
+        top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)
+
+        print(f"\nTop {top_k} largest absolute differences:")
+        print(
+            "Index".ljust(30)
+            + "Difference".ljust(15)
+            + "HF Value".ljust(15)
+            + "SGLang Value"
+        )
+        print("-" * 75)
+
+        for i in range(top_k):
+            # Get the index tuple for this difference
+            idx = tuple(dim[i] for dim in top_indices)
+            diff_val = top_values[i].item()
+            hf_val = hf[idx].item()
+            sg_val = sg[idx].item()
+
+            # Format the index tuple and values
+            idx_str = str(idx)
+            print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")
+
+        np.testing.assert_allclose(hf_np, sg_np)