Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions python/sglang/srt/layers/attention/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,12 @@ def __init__(
customized_position_embedding_applier: Callable[
[torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
] = None,
use_data_parallel: bool = False,
**kwargs,
):
super().__init__()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.tp_size = attn_tp_size
self.tp_rank = attn_tp_rank
self.tp_size = 1 if use_data_parallel else get_attention_tp_size()
self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank()
self.dropout = dropout
self.head_size = embed_dim // num_heads
self.hidden_size_per_attention_head = dist_utils.divide(
Expand Down
46 changes: 44 additions & 2 deletions python/sglang/srt/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
Qwen2_5_VisionRotaryEmbedding,
)

from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.distributed.parallel_state import get_pp_group
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
Expand All @@ -62,6 +66,8 @@
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.utils import permute_inv
from sglang.srt.multimodal.mm_utils import run_dp_sharded_mrope_vision_model
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)
Expand All @@ -76,21 +82,30 @@ def __init__(
hidden_act="silu",
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)
self.down_proj = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)
self.act = ACT2FN[hidden_act]

Expand All @@ -115,6 +130,7 @@ def __init__(
prefix: str = "",
num_dummy_heads: int = 0,
rms_norm_eps: float = 1e-6,
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
Expand All @@ -130,13 +146,15 @@ def __init__(
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
num_dummy_heads=num_dummy_heads,
use_data_parallel=use_data_parallel,
)
self.mlp = Qwen2_5_VLMLP(
dim,
intermediate_dim,
hidden_act=hidden_act,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
use_data_parallel=use_data_parallel,
)

def forward(
Expand Down Expand Up @@ -180,10 +198,13 @@ def __init__(
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.ln_q = RMSNorm(context_dim, eps=1e-6)
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
self.mlp = nn.ModuleList(
[
ColumnParallelLinear(
Expand All @@ -192,6 +213,8 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=add_prefix("mlp.0", prefix),
tp_size=tp_size,
tp_rank=tp_rank,
),
nn.GELU(),
RowParallelLinear(
Expand All @@ -200,6 +223,8 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=add_prefix("mlp.2", prefix),
tp_size=tp_size,
tp_rank=tp_rank,
),
]
)
Expand All @@ -225,6 +250,7 @@ def __init__(
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()

Expand All @@ -241,6 +267,8 @@ def __init__(
self.window_size = vision_config.window_size
self.patch_size = vision_config.patch_size
mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.out_hidden_size
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
Expand All @@ -261,6 +289,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=add_prefix(f"blocks.{i}", prefix),
use_data_parallel=use_data_parallel,
)
for i in range(depth)
]
Expand All @@ -271,6 +300,7 @@ def __init__(
spatial_merge_size=spatial_merge_size,
quant_config=quant_config,
prefix=add_prefix("merger", prefix),
use_data_parallel=use_data_parallel,
)

def get_window_index(self, grid_thw):
Expand Down Expand Up @@ -461,13 +491,15 @@ def __init__(

self.pp_group = get_pp_group()
self.config = config
self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder
self.visual = Qwen2_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
use_data_parallel=self.use_data_parallel,
)

self.model = Qwen2Model(
Expand Down Expand Up @@ -510,7 +542,12 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, image_grid_thw.tolist(), rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds

def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
Expand All @@ -521,7 +558,12 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d"
)
else:
video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
return video_embeds

def get_input_embeddings(self):
Expand Down
Loading
Loading