diff --git a/python/sglang/srt/layers/conv.py b/python/sglang/srt/layers/conv.py new file mode 100644 index 000000000000..d2885f0efc89 --- /dev/null +++ b/python/sglang/srt/layers/conv.py @@ -0,0 +1,300 @@ +""" +Conv2d/Conv3d layers with unfold+linear optimization for patch embeddings. + +When kernel_size == stride, padding == 0, dilation == 1, groups == 1, the conv +is equivalent to unfold + F.linear, which is significantly faster on CUDA and +also avoids the PyTorch 2.9.1 + CuDNN < 9.15 Conv3d bug +(https://github.com/pytorch/pytorch/issues/168167). +""" + +import math +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sglang.srt.layers.utils.multi_platform import MultiPlatformOp + +_VALID_PADDING_STRINGS = {"same", "valid"} +_VALID_PADDING_MODES = {"zeros", "reflect", "replicate", "circular"} + + +def _tuplify(val, n: int) -> tuple: + if isinstance(val, (list, tuple)): + assert len(val) == n + return tuple(val) + return (val,) * n + + +def _check_enable_linear( + kernel_size: tuple, + stride: tuple, + padding: tuple, + dilation: tuple, + groups: int, +) -> bool: + """Check if conv can be replaced with unfold + F.linear.""" + return ( + kernel_size == stride + and all(p == 0 for p in padding) + and all(d == 1 for d in dilation) + and groups == 1 + ) + + +def _reverse_repeat_tuple(t: tuple) -> tuple: + """(1, 2, 3) -> (3, 3, 2, 2, 1, 1). Used for F.pad with non-zeros padding_mode.""" + return tuple(x for x in reversed(t) for _ in range(2)) + + +def _compute_same_padding_for_pad(kernel_size: tuple, dilation: tuple) -> tuple: + """Compute _reversed_padding_repeated_twice for padding='same'. + + This mirrors PyTorch's nn.Conv*d behavior: pre-compute the exact pad + amounts so that F.pad can be called before F.conv*d(padding=0). + """ + pad = [] + for k, d in zip(reversed(kernel_size), reversed(dilation)): + total = d * (k - 1) + pad.append(total // 2) + pad.append(total - total // 2) + return tuple(pad) + + +def _validate_conv_args( + in_channels: int, + out_channels: int, + groups: int, + padding, + padding_mode: str, + stride: tuple, +) -> None: + if in_channels % groups != 0: + raise ValueError( + f"in_channels ({in_channels}) must be divisible by groups ({groups})" + ) + if out_channels % groups != 0: + raise ValueError( + f"out_channels ({out_channels}) must be divisible by groups ({groups})" + ) + if padding_mode not in _VALID_PADDING_MODES: + raise ValueError( + f"padding_mode must be one of {_VALID_PADDING_MODES}, got '{padding_mode}'" + ) + if isinstance(padding, str): + if padding not in _VALID_PADDING_STRINGS: + raise ValueError( + f"padding must be one of {_VALID_PADDING_STRINGS}, got '{padding}'" + ) + if padding == "same" and any(s != 1 for s in stride): + raise ValueError("padding='same' is not supported for strided convolutions") + + +class Conv2dLayer(MultiPlatformOp): + """Drop-in replacement for nn.Conv2d. Linear optimization disabled by default.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int], str] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + disable_linear: bool = True, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _tuplify(kernel_size, 2) + self.stride = _tuplify(stride, 2) + self.dilation = _tuplify(dilation, 2) + self.groups = groups + self.padding_mode = padding_mode + + _validate_conv_args( + in_channels, out_channels, groups, padding, padding_mode, self.stride + ) + + if isinstance(padding, str): + self.padding = (0, 0) if padding == "valid" else padding + else: + self.padding = _tuplify(padding, 2) + + # Pre-compute pad tuple for padding_mode != "zeros" (mirrors nn.Conv2d). + # When padding="same", we need numeric values for F.pad; + # when padding is already numeric, _reverse_repeat_tuple handles it. + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = _compute_same_padding_for_pad( + self.kernel_size, self.dilation + ) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding) + + padding_tuple = self.padding if isinstance(self.padding, tuple) else (1, 1) + self.enable_linear = not disable_linear and _check_enable_linear( + self.kernel_size, self.stride, padding_tuple, self.dilation, groups + ) + + self.weight = nn.Parameter( + torch.empty(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + self._reset_parameters() + + def _reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in = nn.init._calculate_correct_fan(self.weight, "fan_in") + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor: + K1, K2 = self.kernel_size + x = x.unfold(2, K1, K1).unfold(3, K2, K2) + N, _, Hp, Wp = x.shape[:4] + x = x.permute(0, 2, 3, 1, 4, 5).reshape(N, Hp, Wp, -1) + x = F.linear(x, self.weight.reshape(self.out_channels, -1), self.bias) + return x.permute(0, 3, 1, 2) + + def _forward_conv(self, x: torch.Tensor) -> torch.Tensor: + if self.padding_mode != "zeros": + return F.conv2d( + F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.weight, + self.bias, + self.stride, + (0, 0), + self.dilation, + self.groups, + ) + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + if self.enable_linear: + return self._forward_mulmat(x) + return self._forward_conv(x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + if self.enable_linear: + return self._forward_mulmat(x) + return self._forward_conv(x) + + +class Conv3dLayer(MultiPlatformOp): + """Drop-in replacement for nn.Conv3d with automatic linear optimization.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int], str] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + disable_linear: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _tuplify(kernel_size, 3) + self.stride = _tuplify(stride, 3) + self.dilation = _tuplify(dilation, 3) + self.groups = groups + self.padding_mode = padding_mode + + _validate_conv_args( + in_channels, out_channels, groups, padding, padding_mode, self.stride + ) + + if isinstance(padding, str): + self.padding = (0, 0, 0) if padding == "valid" else padding + else: + self.padding = _tuplify(padding, 3) + + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = _compute_same_padding_for_pad( + self.kernel_size, self.dilation + ) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding) + + padding_tuple = self.padding if isinstance(self.padding, tuple) else (1, 1, 1) + self.enable_linear = not disable_linear and _check_enable_linear( + self.kernel_size, self.stride, padding_tuple, self.dilation, groups + ) + + self.weight = nn.Parameter( + torch.empty(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + self._reset_parameters() + + def _reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in = nn.init._calculate_correct_fan(self.weight, "fan_in") + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor: + K1, K2, K3 = self.kernel_size + x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3) + N, Dp, Hp, Wp = x.shape[0], x.shape[2], x.shape[3], x.shape[4] + x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(N, Dp, Hp, Wp, -1) + x = F.linear(x, self.weight.reshape(self.out_channels, -1), self.bias) + return x.permute(0, 4, 1, 2, 3) + + def _forward_conv(self, x: torch.Tensor) -> torch.Tensor: + if self.padding_mode != "zeros": + return F.conv3d( + F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.weight, + self.bias, + self.stride, + (0, 0, 0), + self.dilation, + self.groups, + ) + return F.conv3d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + if self.enable_linear: + return self._forward_mulmat(x) + return self._forward_conv(x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + if self.enable_linear: + return self._forward_mulmat(x) + return self._forward_conv(x) diff --git a/python/sglang/srt/models/clip.py b/python/sglang/srt/models/clip.py index 9294e6f8807f..6aa7b792a70e 100644 --- a/python/sglang/srt/models/clip.py +++ b/python/sglang/srt/models/clip.py @@ -11,6 +11,7 @@ from sglang.srt.layers.activation import QuickGELU from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -32,7 +33,7 @@ def __init__(self, config: CLIPVisionConfig): self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/python/sglang/srt/models/dots_vlm_vit.py b/python/sglang/srt/models/dots_vlm_vit.py index 873994e0b769..caf6e38b1f50 100644 --- a/python/sglang/srt/models/dots_vlm_vit.py +++ b/python/sglang/srt/models/dots_vlm_vit.py @@ -11,6 +11,7 @@ from sglang.srt.configs.dots_vlm import DotsVisionConfig from sglang.srt.distributed import parallel_state from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.utils import add_prefix, is_npu @@ -113,7 +114,7 @@ def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): self.temporal_patch_size = config.temporal_patch_size self.embed_dim = config.embed_dim self.config = config - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( config.num_channels, config.embed_dim, kernel_size=(config.patch_size, config.patch_size), diff --git a/python/sglang/srt/models/glm4v.py b/python/sglang/srt/models/glm4v.py index 8941efeada84..7cfa1e71c1d7 100644 --- a/python/sglang/srt/models/glm4v.py +++ b/python/sglang/srt/models/glm4v.py @@ -35,6 +35,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv3dLayer from sglang.srt.layers.layernorm import LayerNorm, RMSNorm from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -203,7 +204,7 @@ def __init__( self.in_channels = in_channels kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( + self.proj = Conv3dLayer( in_channels, hidden_size, kernel_size=kernel_size, @@ -211,26 +212,17 @@ def __init__( bias=True, ) - k = self.in_channels * self.temporal_patch_size * self.patch_size**2 - self.linear = nn.Linear( - in_features=k, - out_features=self.hidden_size, - bias=True, - dtype=self.proj.weight.dtype, - ) - - def copy_conv3d_weight_to_linear(self): - # Call this after weight loading - with torch.no_grad(): - self.linear.weight.copy_(self.proj.weight.view(self.hidden_size, -1)) - self.linear.bias.copy_(self.proj.bias) - del self.proj - def forward(self, x: torch.Tensor) -> torch.Tensor: - # After copy_conv3d_weight_to_linear(), self.linear exists and - # self.proj has been deleted. Input x is already 2-D: - # (num_patches, C * T * P * P) - return self.linear(x) + # Input x is 2-D: (num_patches, C * T * P * P) + # Reshape to 5-D for Conv3dLayer, then flatten back. + x = x.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + return self.proj(x).view(-1, self.hidden_size) class Glm4vPatchMerger(nn.Module): @@ -456,16 +448,10 @@ def __init__( @property def dtype(self) -> torch.dtype: - # After Conv3d to Linear conversion, self.proj is deleted and - # self.linear takes its place. - if hasattr(self.patch_embed, "linear"): - return self.patch_embed.linear.weight.dtype return self.patch_embed.proj.weight.dtype @property def device(self) -> torch.device: - if hasattr(self.patch_embed, "linear"): - return self.patch_embed.linear.weight.device return self.patch_embed.proj.weight.device def rot_pos_emb( @@ -815,7 +801,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config, name, loaded_weight ) weight_loader(param, loaded_weight) - self.visual.patch_embed.copy_conv3d_weight_to_linear() def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight diff --git a/python/sglang/srt/models/idefics2.py b/python/sglang/srt/models/idefics2.py index c16c86d1073a..7288cf4f3f9f 100644 --- a/python/sglang/srt/models/idefics2.py +++ b/python/sglang/srt/models/idefics2.py @@ -26,6 +26,7 @@ from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import add_prefix, is_npu @@ -193,7 +194,7 @@ def __init__(self, config: PretrainedConfig): self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/python/sglang/srt/models/internvl.py b/python/sglang/srt/models/internvl.py index e5d6a5b70a78..e5a71a37d767 100644 --- a/python/sglang/srt/models/internvl.py +++ b/python/sglang/srt/models/internvl.py @@ -17,6 +17,7 @@ from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -113,7 +114,7 @@ def __init__(self, config: PretrainedConfig): torch.randn(1, 1, self.embed_dim), ) - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/python/sglang/srt/models/kimi_k25.py b/python/sglang/srt/models/kimi_k25.py index 7825389fba28..9e7688719e3d 100644 --- a/python/sglang/srt/models/kimi_k25.py +++ b/python/sglang/srt/models/kimi_k25.py @@ -10,6 +10,7 @@ from sglang.srt.configs.kimi_k25 import KimiK25Config, KimiK25VisionConfig from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternMultimodalTokens, @@ -401,7 +402,7 @@ def __init__( ), f"Expected patch_size to be a tuple of 2, got {patch_size}" self.patch_size = patch_size - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( in_dim, out_dim, kernel_size=patch_size, stride=patch_size ) diff --git a/python/sglang/srt/models/kimi_vl_moonvit.py b/python/sglang/srt/models/kimi_vl_moonvit.py index 9a0e5c4059f6..e41f85695233 100644 --- a/python/sglang/srt/models/kimi_vl_moonvit.py +++ b/python/sglang/srt/models/kimi_vl_moonvit.py @@ -58,6 +58,7 @@ flash_attn_varlen_func = None from sglang.srt.configs import MoonViTConfig +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.layers.quantization.modelslim.modelslim import ModelSlimConfig @@ -250,7 +251,7 @@ def __init__( ), f"Expected patch_size to be a tuple of 2, got {patch_size}" self.patch_size = patch_size - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( in_dim, out_dim, kernel_size=patch_size, stride=patch_size ) diff --git a/python/sglang/srt/models/midashenglm.py b/python/sglang/srt/models/midashenglm.py index 2698fd724edc..bc758a2c3086 100644 --- a/python/sglang/srt/models/midashenglm.py +++ b/python/sglang/srt/models/midashenglm.py @@ -10,6 +10,7 @@ from transformers import PretrainedConfig from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( @@ -79,7 +80,7 @@ def __init__( ) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( in_chans, embed_dim, kernel_size=self.patch_size, diff --git a/python/sglang/srt/models/paddleocr_vl.py b/python/sglang/srt/models/paddleocr_vl.py index 456fb19ab378..53cb8d741ca0 100644 --- a/python/sglang/srt/models/paddleocr_vl.py +++ b/python/sglang/srt/models/paddleocr_vl.py @@ -26,6 +26,7 @@ from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( @@ -113,7 +114,7 @@ def __init__(self, config): self.image_size = config.image_size self.patch_size = config.patch_size - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py index 265801901421..2ce96da00c97 100644 --- a/python/sglang/srt/models/pixtral.py +++ b/python/sglang/srt/models/pixtral.py @@ -35,6 +35,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -328,7 +329,7 @@ class VisionTransformer(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args - self.patch_conv = nn.Conv2d( + self.patch_conv = Conv2dLayer( in_channels=args.num_channels, out_channels=args.hidden_size, kernel_size=args.patch_size, @@ -850,7 +851,7 @@ def __init__( self.image_size = config.image_size self.patch_size = config.patch_size - self.patch_conv = nn.Conv2d( + self.patch_conv = Conv2dLayer( in_channels=config.num_channels, out_channels=config.hidden_size, kernel_size=config.patch_size, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 94e6a48bd76a..3bf0367102d3 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -35,6 +35,7 @@ from sglang.srt.layers.activation import QuickGELU from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv3dLayer from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType @@ -190,7 +191,7 @@ def __init__( self.embed_dim = embed_dim kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d( + self.proj = Conv3dLayer( in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False ) diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 38a4767b0d7f..2d0298d8017b 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -37,6 +37,7 @@ FLASHINFER_WORKSPACE_SIZE_BYTES, VisionAttention, ) +from sglang.srt.layers.conv import Conv3dLayer from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, @@ -139,7 +140,7 @@ def __init__(self, config) -> None: self.embed_dim = config.hidden_size kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] - self.proj = nn.Conv3d( + self.proj = Conv3dLayer( self.in_channels, self.embed_dim, kernel_size=kernel_size, diff --git a/python/sglang/srt/models/siglip.py b/python/sglang/srt/models/siglip.py index 34afe07f8e4d..989fcd9fc722 100644 --- a/python/sglang/srt/models/siglip.py +++ b/python/sglang/srt/models/siglip.py @@ -10,6 +10,7 @@ from sglang.srt.layers.activation import QuickGELU from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding @@ -26,7 +27,7 @@ def __init__(self, config: SiglipVisionConfig): self.image_size = config.image_size self.patch_size = config.patch_size - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py index 5ac9528f94dd..3ab2354637e8 100644 --- a/python/sglang/srt/models/step3_vl.py +++ b/python/sglang/srt/models/step3_vl.py @@ -24,6 +24,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, @@ -616,7 +617,7 @@ def __init__(self, config: Step3VisionEncoderConfig): self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim)) - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/python/sglang/srt/models/step3_vl_10b.py b/python/sglang/srt/models/step3_vl_10b.py index c043191ce9ef..474f377315a9 100644 --- a/python/sglang/srt/models/step3_vl_10b.py +++ b/python/sglang/srt/models/step3_vl_10b.py @@ -13,6 +13,7 @@ from sglang.srt.configs.step3_vl import Step3VLConfig from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.conv import Conv2dLayer from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( @@ -316,7 +317,7 @@ def __init__( raise ValueError("use_rope2d must be True") self.image_size = config.image_size - self.conv1 = nn.Conv2d( + self.conv1 = Conv2dLayer( in_channels=3, out_channels=config.width, kernel_size=config.patch_size, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6bfa8ecb30ea..aaabe4749046 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -5666,11 +5666,6 @@ def check_server_args(self): # Check LoRA self.check_lora_server_args() - # torch 2.9.1 has compatibility issues with cuDNN 9.14 and below, - # causing extremely slow nn.Conv3d performance. - # TODO(yhyang201): Remove this check when sglang no longer uses torch 2.9.1. - self.check_torch_2_9_1_cudnn_compatibility() - # Check speculative decoding if self.speculative_algorithm is not None: assert ( @@ -5779,49 +5774,6 @@ def check_server_args(self): "When enabling two batch overlap, moe_a2a_backend cannot be 'none'." ) - def check_torch_2_9_1_cudnn_compatibility(self): - if get_bool_env_var("SGLANG_DISABLE_CUDNN_CHECK"): - return - - if self.get_model_config().is_multimodal: - import torch - - if torch_release[:3] == (2, 9, 1): - cudnn_version = None - try: - cudnn_version = torch.backends.cudnn.version() - except Exception: - cudnn_version = None - if cudnn_version is not None: - version_float = float(str(cudnn_version)[:3]) / 100 - if version_float < 9.15: - RED = "\033[91m" - BOLD = "\033[1m" - RESET = "\033[0m" - msg = ( - f"{RED}{BOLD}" - "CRITICAL WARNING: PyTorch 2.9.1 & CuDNN Compatibility Issue Detected\n" - "--------------------------------------------------------------------------------\n" - f"Current Environment: PyTorch {torch.__version__} | CuDNN {version_float:.2f}\n\n" - "Issue: There is a KNOWN BUG in PyTorch 2.9.1's `nn.Conv3d` implementation\n" - " when used with CuDNN versions older than 9.15. This can cause\n" - " SEVERE PERFORMANCE DEGRADATION and EXCESSIVE MEMORY USAGE.\n\n" - "Reference: https://github.com/pytorch/pytorch/issues/168167\n\n" - "Solution: You MUST upgrade CuDNN to version 9.15+ to ensure correctness.\n\n" - "Run the following command immediately to fix:\n" - " pip install nvidia-cudnn-cu12==9.16.0.29\n\n" - "Or you can disable this check by setting env var SGLANG_DISABLE_CUDNN_CHECK=1\n" - "--------------------------------------------------------------------------------\n" - f"{RESET}" - ) - raise RuntimeError(msg) - else: - RED = "\033[91m" - RESET = "\033[0m" - logger.warning( - f"{RED}WARNING: Could not determine CuDNN version for torch==2.9.1. Please ensure CuDNN >= 9.15 to avoid nn.Conv3d bugs.{RESET}" - ) - def check_lora_server_args(self): assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive" diff --git a/test/unit/test_conv_layer.py b/test/unit/test_conv_layer.py new file mode 100644 index 000000000000..ff4b80bda4b3 --- /dev/null +++ b/test/unit/test_conv_layer.py @@ -0,0 +1,363 @@ +import unittest + +import torch +import torch.nn as nn + +from sglang.srt.layers.conv import Conv2dLayer, Conv3dLayer + + +def _copy_weights(src, dst_nn): + """Copy weights from Conv*dLayer to nn.Conv*d for comparison.""" + with torch.no_grad(): + dst_nn.weight.copy_(src.weight) + if src.bias is not None: + dst_nn.bias.copy_(src.bias) + + +class TestConv2dLayer(unittest.TestCase): + + def test_basic_patch_embedding(self): + layer = Conv2dLayer(3, 768, kernel_size=14, stride=14, bias=False) + ref = nn.Conv2d(3, 768, kernel_size=14, stride=14, bias=False) + self.assertFalse(layer.enable_linear) + _copy_weights(layer, ref) + x = torch.randn(2, 3, 224, 224) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_enable_linear(self): + layer = Conv2dLayer( + 3, 768, kernel_size=14, stride=14, bias=True, disable_linear=False + ) + ref = nn.Conv2d(3, 768, kernel_size=14, stride=14, bias=True) + self.assertTrue(layer.enable_linear) + _copy_weights(layer, ref) + x = torch.randn(1, 3, 224, 224) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_padding_valid(self): + layer = Conv2dLayer(3, 768, kernel_size=14, stride=14, padding="valid") + self.assertFalse(layer.enable_linear) + self.assertEqual(layer.padding, (0, 0)) + + def test_padding_same_disables_linear(self): + layer = Conv2dLayer(3, 64, kernel_size=3, stride=1, padding="same") + self.assertFalse(layer.enable_linear) + + def test_non_matching_stride_disables_linear(self): + layer = Conv2dLayer(3, 64, kernel_size=3, stride=1, padding=1) + self.assertFalse(layer.enable_linear) + + def test_groups_disable_linear(self): + layer = Conv2dLayer(4, 8, kernel_size=2, stride=2, groups=2) + self.assertFalse(layer.enable_linear) + + def test_default_disables_linear(self): + layer = Conv2dLayer(3, 768, kernel_size=14, stride=14) + self.assertFalse(layer.enable_linear) + + def test_dilation_disables_linear(self): + layer = Conv2dLayer(3, 64, kernel_size=3, stride=3, dilation=2) + self.assertFalse(layer.enable_linear) + + def test_padding_mode_reflect(self): + layer = Conv2dLayer( + 3, 64, kernel_size=3, stride=1, padding=1, padding_mode="reflect", bias=True + ) + ref = nn.Conv2d( + 3, 64, kernel_size=3, stride=1, padding=1, padding_mode="reflect", bias=True + ) + self.assertFalse(layer.enable_linear) + _copy_weights(layer, ref) + x = torch.randn(1, 3, 16, 16) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_conv_path_with_padding(self): + layer = Conv2dLayer(3, 64, kernel_size=3, stride=1, padding=1, bias=True) + ref = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True) + _copy_weights(layer, ref) + x = torch.randn(1, 3, 32, 32) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_mulmat_matches_conv(self): + layer = Conv2dLayer( + 3, 768, kernel_size=14, stride=14, bias=True, disable_linear=False + ) + self.assertTrue(layer.enable_linear) + x = torch.randn(2, 3, 224, 224) + with torch.no_grad(): + torch.testing.assert_close( + layer._forward_mulmat(x), + layer._forward_conv(x), + rtol=1e-4, + atol=1e-4, + ) + + def test_forward_cuda_uses_mulmat_when_enabled(self): + layer = Conv2dLayer( + 3, 64, kernel_size=4, stride=4, bias=False, disable_linear=False + ) + self.assertTrue(layer.enable_linear) + x = torch.randn(1, 3, 16, 16) + with torch.no_grad(): + torch.testing.assert_close(layer.forward_cuda(x), layer._forward_mulmat(x)) + + def test_forward_cuda_uses_conv_when_not_eligible(self): + layer = Conv2dLayer(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.assertFalse(layer.enable_linear) + x = torch.randn(1, 3, 16, 16) + with torch.no_grad(): + torch.testing.assert_close(layer.forward_cuda(x), layer._forward_conv(x)) + + def test_tuple_kernel_size(self): + layer = Conv2dLayer( + 3, + 768, + kernel_size=(14, 14), + stride=(14, 14), + bias=False, + disable_linear=False, + ) + self.assertTrue(layer.enable_linear) + ref = nn.Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14), bias=False) + _copy_weights(layer, ref) + x = torch.randn(1, 3, 224, 224) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_output_shape(self): + layer = Conv2dLayer(3, 768, kernel_size=16, stride=16, bias=False) + x = torch.randn(4, 3, 224, 224) + out = layer.forward_native(x) + self.assertEqual(out.shape, (4, 768, 14, 14)) + + def test_no_bias_parameter(self): + layer = Conv2dLayer(3, 64, kernel_size=4, stride=4, bias=False) + self.assertIsNone(layer.bias) + + +class TestConvValidation(unittest.TestCase): + + def test_in_channels_not_divisible_by_groups(self): + with self.assertRaises(ValueError): + Conv2dLayer(3, 64, kernel_size=3, stride=1, groups=2) + + def test_out_channels_not_divisible_by_groups(self): + with self.assertRaises(ValueError): + Conv2dLayer(4, 6, kernel_size=3, stride=1, groups=4) + + def test_invalid_padding_string(self): + with self.assertRaises(ValueError): + Conv2dLayer(3, 64, kernel_size=3, stride=1, padding="full") + + def test_padding_same_with_stride(self): + with self.assertRaises(ValueError): + Conv2dLayer(3, 64, kernel_size=3, stride=2, padding="same") + + def test_padding_same_with_non_zeros_padding_mode(self): + layer = Conv2dLayer( + 3, + 64, + kernel_size=3, + stride=1, + padding="same", + padding_mode="reflect", + bias=True, + ) + ref = nn.Conv2d( + 3, + 64, + kernel_size=3, + stride=1, + padding="same", + padding_mode="reflect", + bias=True, + ) + self.assertFalse(layer.enable_linear) + _copy_weights(layer, ref) + x = torch.randn(1, 3, 16, 16) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_invalid_padding_mode(self): + with self.assertRaises(ValueError): + Conv3dLayer(3, 64, kernel_size=3, stride=1, padding_mode="invalid") + + def test_conv3d_in_channels_not_divisible_by_groups(self): + with self.assertRaises(ValueError): + Conv3dLayer(3, 64, kernel_size=3, stride=1, groups=2) + + +class TestConv3dLayer(unittest.TestCase): + + def test_basic_temporal_patch_embedding(self): + layer = Conv3dLayer( + 3, 1152, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=False + ) + ref = nn.Conv3d( + 3, 1152, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=False + ) + self.assertTrue(layer.enable_linear) + _copy_weights(layer, ref) + x = torch.randn(1, 3, 2, 14, 14) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_with_bias(self): + layer = Conv3dLayer( + 3, 1536, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=True + ) + ref = nn.Conv3d(3, 1536, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=True) + self.assertTrue(layer.enable_linear) + _copy_weights(layer, ref) + x = torch.randn(4, 3, 2, 14, 14) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_mulmat_matches_conv(self): + layer = Conv3dLayer( + 3, 1152, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=True + ) + self.assertTrue(layer.enable_linear) + x = torch.randn(2, 3, 2, 14, 14) + with torch.no_grad(): + torch.testing.assert_close( + layer._forward_mulmat(x), + layer._forward_conv(x), + rtol=1e-4, + atol=1e-4, + ) + + def test_non_matching_stride_disables_linear(self): + layer = Conv3dLayer(3, 64, kernel_size=3, stride=1, padding=1) + self.assertFalse(layer.enable_linear) + + def test_dilation_disables_linear(self): + layer = Conv3dLayer(3, 64, kernel_size=3, stride=3, dilation=2) + self.assertFalse(layer.enable_linear) + + def test_disable_linear(self): + layer = Conv3dLayer( + 3, + 1152, + kernel_size=[2, 14, 14], + stride=[2, 14, 14], + bias=False, + disable_linear=True, + ) + self.assertFalse(layer.enable_linear) + ref = nn.Conv3d( + 3, 1152, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=False + ) + _copy_weights(layer, ref) + x = torch.randn(1, 3, 2, 14, 14) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_conv_path_with_padding(self): + layer = Conv3dLayer(3, 64, kernel_size=3, stride=1, padding=1, bias=True) + ref = nn.Conv3d(3, 64, kernel_size=3, stride=1, padding=1, bias=True) + _copy_weights(layer, ref) + x = torch.randn(1, 3, 4, 8, 8) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_output_shape(self): + layer = Conv3dLayer( + 3, 1152, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=False + ) + x = torch.randn(1, 3, 2, 14, 14) + out = layer.forward_native(x) + self.assertEqual(out.shape, (1, 1152, 1, 1, 1)) + + def test_batch_processing(self): + layer = Conv3dLayer( + 3, 1536, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=True + ) + ref = nn.Conv3d(3, 1536, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=True) + _copy_weights(layer, ref) + x = torch.randn(8, 3, 2, 14, 14) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), ref(x), rtol=1e-4, atol=1e-4 + ) + + def test_forward_native_uses_mulmat_when_eligible(self): + layer = Conv3dLayer(3, 128, kernel_size=[2, 4, 4], stride=[2, 4, 4], bias=True) + self.assertTrue(layer.enable_linear) + x = torch.randn(1, 3, 2, 4, 4) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x), layer._forward_mulmat(x) + ) + + def test_padding_valid(self): + layer = Conv3dLayer( + 3, 64, kernel_size=[2, 4, 4], stride=[2, 4, 4], padding="valid" + ) + self.assertTrue(layer.enable_linear) + self.assertEqual(layer.padding, (0, 0, 0)) + + def test_weight_shape(self): + layer = Conv3dLayer( + 3, 1152, kernel_size=[2, 14, 14], stride=[2, 14, 14], bias=False + ) + self.assertEqual(layer.weight.shape, (1152, 3, 2, 14, 14)) + + def test_glm4v_workflow(self): + """GLM4V-style: 2D input -> reshape to 5D -> Conv3dLayer -> flatten.""" + in_channels, temporal_patch_size, patch_size = 3, 2, 14 + hidden_size = 1536 + layer = Conv3dLayer( + in_channels, + hidden_size, + kernel_size=[temporal_patch_size, patch_size, patch_size], + stride=[temporal_patch_size, patch_size, patch_size], + bias=True, + ) + ref = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=[temporal_patch_size, patch_size, patch_size], + stride=[temporal_patch_size, patch_size, patch_size], + bias=True, + ) + _copy_weights(layer, ref) + num_patches = 4 + flat_dim = in_channels * temporal_patch_size * patch_size * patch_size + x_2d = torch.randn(num_patches, flat_dim) + x_5d = x_2d.view(-1, in_channels, temporal_patch_size, patch_size, patch_size) + with torch.no_grad(): + torch.testing.assert_close( + layer.forward_native(x_5d).view(-1, hidden_size), + ref(x_5d).view(-1, hidden_size), + rtol=1e-4, + atol=1e-4, + ) + + +if __name__ == "__main__": + unittest.main()