Qwen2vl vision encoder fix #2365
Changes from all commits: a14c88c, a88f0fc, 8ac651c, d81cc75, 5402c43
```diff
@@ -30,10 +30,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
```
Contributor: remove unused imports
```diff
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
```
```diff
@@ -52,7 +54,15 @@
 logger = init_logger(__name__)


 # === Vision Inputs === #
+class OriginalQuickGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return input * torch.sigmoid(1.702 * input)
```
Contributor: try torch.compile to fuse them?
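A minimal sketch of the reviewer's torch.compile suggestion (assumes PyTorch 2.x; the standalone module and shapes are illustrative, not code from this PR). Compiling the activation lets the inductor backend fuse the pointwise sigmoid and multiply into a single kernel:

```python
import torch
import torch.nn as nn


class OriginalQuickGELUActivation(nn.Module):
    """GELU approximation used above: x * sigmoid(1.702 * x)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input * torch.sigmoid(1.702 * input)


act = OriginalQuickGELUActivation()
compiled_act = torch.compile(act)  # fuses the two pointwise ops into one kernel

x = torch.randn(4, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
# The compiled module should match eager mode numerically.
torch.testing.assert_close(act(x), compiled_act(x))
```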
```diff
 class Qwen2VLImageInputs(TypedDict):
```
```diff
@@ -91,7 +101,7 @@ def __init__(
     self,
     in_features: int,
     hidden_features: int = None,
-    act_layer: Type[nn.Module] = QuickGELU,
+    act_layer: Type[nn.Module] = OriginalQuickGELUActivation,
     quant_config: Optional[QuantizationConfig] = None,
 ):
     super().__init__()
```
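To make the docstring's "fast but somewhat inaccurate" concrete, here is a quick standalone comparison of the sigmoid approximation against the exact erf-based GELU (the grid and the printed bound are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=1001)
quick = x * torch.sigmoid(1.702 * x)  # the approximation swapped in above
exact = F.gelu(x)                     # erf-based GELU
# The maximum pointwise deviation is small (roughly 0.02) but nonzero, so the
# exact variant a model was trained with is the one worth reproducing.
print(f"max |quick - exact| = {(quick - exact).abs().max():.4f}")
```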
```diff
@@ -201,21 +211,30 @@ def forward(
     q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
     batch_size = q.shape[1]

-    q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
+    q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
     if rotary_pos_emb is not None:
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

-    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
-    max_seqlen = (seq_lens).max().item()
-    q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
-    output = torch.empty_like(q)
-    context_attention_fwd(
-        q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
-    )
+    seq_length = q.size(1)
+    q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
+    attention_mask = torch.zeros(
+        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
+    )
+    for i in range(1, len(cu_seqlens)):
+        attention_mask[
+            ...,
+            cu_seqlens[i - 1] : cu_seqlens[i],
+            cu_seqlens[i - 1] : cu_seqlens[i],
+        ] = True
+
+    q = q.squeeze(0)
+    k = k.squeeze(0)
+    v = v.squeeze(0)
+    output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
```
Contributor: I think PyTorch SDPA should also be fast; the problem is probably how you prepare the attention mask?

Contributor (Author): Hmm, I have tried caching the attention mask, but it doesn't seem to impact performance much. The issue I see is that for one layer, the …
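One way to act on the reviewer's point is to build the block-diagonal mask in a vectorized way rather than in a Python loop. A minimal sketch (the helper name is hypothetical, not code from this PR):

```python
import torch


def make_block_diagonal_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Boolean [1, S, S] mask that is True inside each image's token block.

    cu_seqlens holds cumulative sequence lengths, e.g. [0, 4, 10] for two
    images of 4 and 6 tokens; a token may only attend within its own image.
    """
    total = int(cu_seqlens[-1])
    # block_ids[t] = index of the image that token t belongs to
    block_ids = torch.bucketize(
        torch.arange(total, device=cu_seqlens.device), cu_seqlens[1:], right=True
    )
    # True exactly where the query and key tokens fall in the same block
    return (block_ids[:, None] == block_ids[None, :]).unsqueeze(0)


cu = torch.tensor([0, 4, 10])
mask = make_block_diagonal_mask(cu)
assert mask.shape == (1, 10, 10)
assert mask[0, 0, 3] and not mask[0, 0, 4]
```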
```diff
+    output = output.unsqueeze(0)
+    context_layer = rearrange(output, "b h s d -> b s h d ")

-    context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
     context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

     output, _ = self.proj(context_layer)
```
```diff
@@ -229,7 +248,7 @@ def __init__(
     dim: int,
     num_heads: int,
     mlp_ratio: float,
-    act_layer: Type[nn.Module] = QuickGELU,
+    act_layer: Type[nn.Module] = OriginalQuickGELUActivation,
     norm_layer: Type[nn.Module] = None,
     quant_config: Optional[QuantizationConfig] = None,
 ) -> None:
```
```diff
@@ -328,38 +347,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class Qwen2VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
-        self.dim = dim
         self.theta = theta
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs
+        self.dim = dim

     def forward(self, seqlen: int) -> torch.Tensor:
-        self.update_freqs_cache(seqlen)
-        return self._freqs_cached[:seqlen]
+        inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
+        )
+        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
+        freqs = torch.outer(seq, inv_freq)
+        return freqs


 class Qwen2VisionTransformer(nn.Module):
```
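For reference, the rewritten forward above is just the standard rotary frequency table, now computed statelessly on each call instead of cached; a tiny standalone version (the helper name and shapes are illustrative):

```python
import torch


def rotary_freqs(seqlen: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Outer product of positions with inverse frequencies over half the
    # head dimension, the same math as Qwen2VisionRotaryEmbedding.forward.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    seq = torch.arange(seqlen, dtype=inv_freq.dtype)
    return torch.outer(seq, inv_freq)


freqs = rotary_freqs(seqlen=16, dim=8)
assert freqs.shape == (16, 4)  # [seqlen, dim // 2]
```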
```diff
@@ -450,7 +450,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
     pos_ids = torch.cat(pos_ids, dim=0)
     max_grid_size = grid_thw[:, 1:].max()
-    rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+    rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device)
     rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
     return rotary_pos_emb
```
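The added .to(grid_thw.device) matters because the rewritten rotary forward allocates its tensors without an explicit device, so the table lands on the CPU while pos_ids may live on the GPU; a small illustration (devices assumed):

```python
import torch

# torch.arange with no device argument allocates on the CPU
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2, dtype=torch.float) / 8))
assert inv_freq.device.type == "cpu"

table = torch.outer(torch.arange(4, dtype=inv_freq.dtype), inv_freq)
if torch.cuda.is_available():
    # Indexing a CPU table with CUDA pos_ids would fail; move it first.
    table = table.to("cuda")
```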
```diff
@@ -499,7 +499,7 @@ def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
     return num_image_tokens

 # Use grid_t * grid_w * grid_h to pad tokens for each image
-# add replaced padding by unique image hash
+# and replaced padding by unique image hash
 def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
     image_grid_thws = image_inputs.image_grid_thws
     pad_values = image_inputs.pad_values
```
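For context on the padding comment above: in Qwen2-VL the number of placeholder tokens per image follows from the patch grid after spatial merging (merge size 2 is the model's default). A hedged sketch of that arithmetic; the helper is illustrative, not code from this PR:

```python
def num_image_tokens(grid_t: int, grid_h: int, grid_w: int, merge_size: int = 2) -> int:
    # Each frame contributes grid_h * grid_w patches; spatial merging then
    # collapses each merge_size x merge_size patch group into one token.
    return grid_t * grid_h * grid_w // (merge_size**2)


# e.g. one frame with a 16x16 patch grid -> 64 image tokens
assert num_image_tokens(1, 16, 16) == 64
```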
@@ -0,0 +1,159 @@ (new test file; every line below is an addition)

```python
import asyncio
import base64
import json
import math
import os
import tempfile
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import AsyncMock, patch

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs

QWEN2_VL_MODEL = "Qwen/Qwen2-VL-7B-Instruct"


class RawSGLangTest(unittest.IsolatedAsyncioTestCase):
```
Contributor: This looks good. Maybe give it a better name. Now we have a very good reference script for text-only models and a very good model support guide. Are you willing to help here to add some scripts/docs similar to the above ones, but for vision language models?
```python
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            QWEN2_VL_MODEL, trust_remote_code=True
        )
        self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            QWEN2_VL_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True
        ).eval()
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    async def test_vision_encoder(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ]

        # Apply chat template to get the text
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        response = requests.get(messages[0]["content"][0]["image"])
        main_image = Image.open(BytesIO(response.content))

        # Process inputs using processor
        inputs = self.processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            hf_output = self.model.visual(
                inputs["pixel_values"].to(self.device),
                grid_thw=inputs["image_grid_thw"].to(self.device),
            )

        model_config = ModelConfig(QWEN2_VL_MODEL, model_override_args="{}")

        server_args = ServerArgs(model_path=QWEN2_VL_MODEL)
        model_runner = ModelRunner(
            model_config=model_config,
            mem_fraction_static=0.8,
            gpu_id=0,
            tp_rank=0,
            tp_size=1,
            nccl_port=12435,
            server_args=server_args,
        )

        with torch.no_grad():
            sglang_output = model_runner.model.visual(
                inputs["pixel_values"].to(self.device),
                grid_thw=inputs["image_grid_thw"].to(self.device),
            )

        # Convert to float32 for numerical stability if needed
        hf = hf_output.float()
        sg = sglang_output.float()

        # Basic shape and dtype comparison
        print("\n=== Basic Properties ===")
        print(f"Shapes match: {hf.shape == sg.shape}")
        print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
        print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")

        # Move tensors to CPU for numpy operations
        hf_np = hf.cpu().numpy()
        sg_np = sg.cpu().numpy()

        # Statistical metrics
        print("\n=== Statistical Metrics ===")
        print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
        print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
        print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
        print(
            f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}"
        )

        # Cosine similarity (across feature dimension)
        cos_sim = F.cosine_similarity(hf, sg)
        print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
        print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")

        # Find largest absolute differences
        print("\n=== Largest Absolute Differences ===")
        diffs = torch.abs(hf - sg)
        flat_diffs = diffs.flatten()

        # Get indices of top 10 differences
        top_k = 10
        top_values, top_flat_indices = torch.topk(flat_diffs, top_k)

        # Convert flat indices to multidimensional indices
        top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)

        print(f"\nTop {top_k} largest absolute differences:")
        print(
            "Index".ljust(30)
            + "Difference".ljust(15)
            + "HF Value".ljust(15)
            + "SGLang Value"
        )
        print("-" * 75)

        for i in range(top_k):
            # Get the index tuple for this difference
            idx = tuple(dim[i] for dim in top_indices)
            diff_val = top_values[i].item()
            hf_val = hf[idx].item()
            sg_val = sg[idx].item()

            # Format the index tuple and values
            idx_str = str(idx)
            print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")

        np.testing.assert_allclose(hf_np, sg_np)
```
Contributor: remove unused imports