python/sglang/srt/models/qwen2_vl.py (80 changes: 40 additions & 40 deletions)

@@ -30,10 +30,12 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.config import CacheConfig, MultiModalConfig

[Contributor] remove unused imports

from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

[Contributor] remove unused imports


from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
@@ -52,7 +54,15 @@

logger = init_logger(__name__)


# === Vision Inputs === #
+class OriginalQuickGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return input * torch.sigmoid(1.702 * input)

[Contributor] try torch.compile to fuse them?
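
A minimal sketch of what that suggestion could look like (assuming PyTorch 2.x; the fusion behavior is what torch.compile's inductor backend is expected to do here, not something measured in this PR):

import torch
import torch.nn.functional as F

# Compiling the module lets the backend fuse the multiply and the sigmoid
# into a single elementwise kernel instead of two.
act = torch.compile(OriginalQuickGELUActivation())

x = torch.randn(1024, device="cuda")
# Sanity check: the sigmoid form tracks exact GELU to within roughly 1e-2.
print((act(x) - F.gelu(x)).abs().max().item())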



class Qwen2VLImageInputs(TypedDict):
@@ -91,7 +101,7 @@ def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
-        act_layer: Type[nn.Module] = QuickGELU,
+        act_layer: Type[nn.Module] = OriginalQuickGELUActivation,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
@@ -201,21 +211,30 @@ def forward(
        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
        batch_size = q.shape[1]

-        q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
+        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

-        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
-        max_seqlen = (seq_lens).max().item()
-        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
-        output = torch.empty_like(q)
-        context_attention_fwd(
-            q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
-        )
+        seq_length = q.size(1)
+        q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
+        attention_mask = torch.zeros(
+            [1, seq_length, seq_length], device=q.device, dtype=torch.bool
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                ...,
+                cu_seqlens[i - 1] : cu_seqlens[i],
+                cu_seqlens[i - 1] : cu_seqlens[i],
+            ] = True
+
+        q = q.squeeze(0)
+        k = k.squeeze(0)
+        v = v.squeeze(0)
+        output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)

[Contributor] I think PyTorch SDPA should also be fast; the problem is probably how you prepare the attention mask. Can you vectorize the code more (use fewer Python for-loops), write a Triton kernel for it (see example), or cache the results so we can reuse them across layers?

[Contributor (author)] Hmm, I have tried caching the attention mask, but it doesn't seem to impact performance much.

The issue I see is that for one layer, the context_attention_fwd kernel in sglang matches torch's scaled_dot_product_attention pretty closely, within 1e-2 for each activation. But qwen2vl has 32 layers, and after a while the absolute difference accumulates, closer to +/- 1.0 max absolute difference in the activations.
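
For reference, a sketch of how the mask construction could be vectorized, as the reviewer suggests; this is an illustration, not code from the PR, and it assumes cu_seqlens is the same 1-D tensor of cumulative sequence lengths used in forward() above:

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, seq_length: int) -> torch.Tensor:
    # Label each position with the index of the sequence it falls in, then
    # compare labels pairwise: two positions may attend to each other only
    # when they belong to the same sequence.
    positions = torch.arange(seq_length, device=cu_seqlens.device)
    seq_ids = torch.bucketize(positions, cu_seqlens[1:-1], right=True)
    return (seq_ids[:, None] == seq_ids[None, :]).unsqueeze(0)

For cu_seqlens = [0, 3, 5] and seq_length = 5 this yields the same [1, 5, 5] block-diagonal boolean mask as the loop, with no Python-level iteration over sequences.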

+        output = output.unsqueeze(0)
+        context_layer = rearrange(output, "b h s d -> b s h d ")

-        context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
@@ -229,7 +248,7 @@ def __init__(
        dim: int,
        num_heads: int,
        mlp_ratio: float,
-        act_layer: Type[nn.Module] = QuickGELU,
+        act_layer: Type[nn.Module] = OriginalQuickGELUActivation,
        norm_layer: Type[nn.Module] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
@@ -328,38 +347,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class Qwen2VisionRotaryEmbedding(nn.Module):
-
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
-        self.dim = dim
        self.theta = theta
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs
+        self.dim = dim

    def forward(self, seqlen: int) -> torch.Tensor:
-        self.update_freqs_cache(seqlen)
-        return self._freqs_cached[:seqlen]
+        inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
+        )
+
+        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
+        freqs = torch.outer(seq, inv_freq)
+        return freqs
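
For reference, the recomputed table holds one rotation angle per (position, frequency pair); a small shape check with an illustrative dim value (the vision tower is believed to construct this module with head_dim // 2, which is not shown in this diff):

rope = Qwen2VisionRotaryEmbedding(dim=32)
freqs = rope(seqlen=64)
# torch.outer(arange(64), inv_freq) with 16 frequency pairs
assert freqs.shape == (64, 16)

Since inv_freq is now built on the default device rather than living in a registered buffer, the result starts on CPU, which is likely why rot_pos_emb below gains an explicit .to(grid_thw.device).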


class Qwen2VisionTransformer(nn.Module):
@@ -450,7 +450,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

@@ -499,7 +499,7 @@ def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
        return num_image_tokens

    # Use grid_t * grid_w * grid_h to pad tokens for each image
-    # add replaced padding by unique image hash
+    # and replaced padding by unique image hash
    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        image_grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values
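As a rough illustration of the token count this padding works with (merge_size = 2 is assumed here as Qwen2-VL's default; it is not shown in the diff):

grid_t, grid_h, grid_w, merge_size = 1, 40, 60, 2  # hypothetical image grid
num_image_tokens = grid_t * grid_h * grid_w // merge_size**2  # -> 600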
test/srt/test_qwen2vl.py (159 changes: 159 additions & 0 deletions)

@@ -0,0 +1,159 @@
import asyncio
import base64
import json
import math
import os
import tempfile
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import AsyncMock, patch

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs

QWEN2_VL_MODEL = "Qwen/Qwen2-VL-7B-Instruct"


class RawSGLangTest(unittest.IsolatedAsyncioTestCase):

[Contributor] This looks good. Maybe give it a better name.

Now we have a very good reference script for text-only models and a very good model support guide:
https://github.com/sgl-project/sglang/blob/main/scripts/playground/reference_hf.py
https://sgl-project.github.io/references/supported_models.html#how-to-support-a-new-model

Are you willing to help here to add some scripts/docs similar to the above ones, but for vision language models?

    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            QWEN2_VL_MODEL, trust_remote_code=True
        )
        self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            QWEN2_VL_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True
        ).eval()
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    async def test_vision_encoder(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ]

        # Apply chat template to get the text
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        response = requests.get(messages[0]["content"][0]["image"])
        main_image = Image.open(BytesIO(response.content))

        # Process inputs using processor
        inputs = self.processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            hf_output = self.model.visual(
                inputs["pixel_values"].to(self.device),
                grid_thw=inputs["image_grid_thw"].to(self.device),
            )

        model_config = ModelConfig(QWEN2_VL_MODEL, model_override_args="{}")

        server_args = ServerArgs(model_path=QWEN2_VL_MODEL)
        model_runner = ModelRunner(
            model_config=model_config,
            mem_fraction_static=0.8,
            gpu_id=0,
            tp_rank=0,
            tp_size=1,
            nccl_port=12435,
            server_args=server_args,
        )

        with torch.no_grad():
            sglang_output = model_runner.model.visual(
                inputs["pixel_values"].to(self.device),
                grid_thw=inputs["image_grid_thw"].to(self.device),
            )

        # Convert to float32 for numerical stability if needed
        hf = hf_output.float()
        sg = sglang_output.float()

        # Basic shape and dtype comparison
        print("\n=== Basic Properties ===")
        print(f"Shapes match: {hf.shape == sg.shape}")
        print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
        print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")

        # Move tensors to CPU for numpy operations
        hf_np = hf.cpu().numpy()
        sg_np = sg.cpu().numpy()

        # Statistical metrics
        print("\n=== Statistical Metrics ===")
        print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
        print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
        print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
        print(
            f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}"
        )

        # Cosine similarity (across feature dimension)
        cos_sim = F.cosine_similarity(hf, sg)
        print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
        print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")

        # Find largest absolute differences
        print("\n=== Largest Absolute Differences ===")
        diffs = torch.abs(hf - sg)
        flat_diffs = diffs.flatten()

        # Get indices of top 10 differences
        top_k = 10
        top_values, top_flat_indices = torch.topk(flat_diffs, top_k)

        # Convert flat indices to multidimensional indices
        top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)

        print(f"\nTop {top_k} largest absolute differences:")
        print(
            "Index".ljust(30)
            + "Difference".ljust(15)
            + "HF Value".ljust(15)
            + "SGLang Value"
        )
        print("-" * 75)

        for i in range(top_k):
            # Get the index tuple for this difference
            idx = tuple(dim[i] for dim in top_indices)
            diff_val = top_values[i].item()
            hf_val = hf[idx].item()
            sg_val = sg[idx].item()

            # Format the index tuple and values
            idx_str = str(idx)
            print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")

        np.testing.assert_allclose(hf_np, sg_np)
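
Assuming a CUDA machine with access to the Qwen/Qwen2-VL-7B-Instruct weights, a hypothetical invocation (not documented in the PR) would be:

    python -m unittest test.srt.test_qwen2vl -v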