
[Perf] replace Conv3d with Linear in VisionPatchEmbed#132

Merged
FrankLeeeee merged 1 commit into main from optimze/conv3d (Mar 6, 2026)

Conversation

@zhou9402 (Collaborator) commented Mar 6, 2026

Conv3d with kernel_size == stride does not slide over the input, making it equivalent to a reshape + linear projection. Replace it with nn.Linear after loading HF checkpoint weights for significantly faster patch embedding, especially on older cuDNN versions where Conv3d with large non-sliding kernels is extremely slow.

Safety checks ensure the optimization only applies when kernel_size == stride, padding == 0, dilation == 1, and groups == 1.

Reference: sgl-project/sglang#19788
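To see the equivalence concretely, here is a self-contained sketch (sizes are illustrative, not the model's exact config): because kernel_size == stride, the Conv3d kernel never overlaps itself, so its output over a whole input matches an nn.Linear applied to each flattened patch.

```python
import torch
from torch import nn

# Illustrative sizes: 3 channels, temporal patch 2, spatial patch 16, embed dim 1152.
C, T, P, D = 3, 2, 16, 1152
conv = nn.Conv3d(C, D, kernel_size=(T, P, P), stride=(T, P, P))

# Fold the kernel into a Linear layer: one weight row per output channel.
linear = nn.Linear(C * T * P * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(D, -1))  # (D, C, T, P, P) -> (D, C*T*P*P)
    linear.bias.copy_(conv.bias)

# A 64x64 frame pair -> a 4x4 grid of 16 non-overlapping patches.
img = torch.randn(1, C, T, 64, 64)
out_conv = conv(img)                                  # (1, D, 1, 4, 4)

# Cut the same image into patches and flatten each in (C, T, P, P) order,
# matching the weight layout above.
patches = (img.unfold(3, P, P).unfold(4, P, P)        # (1, C, T, 4, 4, P, P)
              .permute(0, 3, 4, 1, 2, 5, 6)           # (1, 4, 4, C, T, P, P)
              .reshape(16, -1))                       # (16, C*T*P*P)
out_linear = linear(patches)                          # (16, D)

print(torch.allclose(out_conv.permute(0, 2, 3, 4, 1).reshape(16, D),
                     out_linear, atol=1e-5))  # True
```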

Benchmark on H20, Linear vs Conv3d:

  size   patches   Conv3d/ms   Linear/ms   Speedup
------------------------------------------------------
 32x32           4       0.211       0.019     11.3x
 64x64          16       0.216       0.020     10.9x
 96x96          36       0.231       0.018     12.5x
128x128         64       0.253       0.019     13.4x
160x160        100       0.253       0.019     13.5x
192x192        144       0.254       0.019     13.6x
224x224        196       0.254       0.019     13.7x
256x256        256       0.254       0.019     13.7x
288x288        324       0.254       0.019     13.5x
320x320        400       0.255       0.019     13.8x
352x352        484       0.256       0.019     13.8x
384x384        576       0.362       0.020     18.2x
416x416        676       0.365       0.023     15.9x
448x448        784       0.368       0.025     15.0x
480x480        900       0.368       0.027     13.5x
512x512       1024       0.368       0.031     11.7x
544x544       1156       0.490       0.037     13.2x
576x576       1296       0.492       0.040     12.2x
608x608       1444       0.492       0.040     12.2x
640x640       1600       0.493       0.046     10.8x
672x672       1764       0.625       0.051     12.3x
704x704       1936       0.628       0.056     11.2x
736x736       2116       0.628       0.061     10.2x
768x768       2304       0.860       0.067     12.9x
800x800       2500       1.146       0.072     15.9x
832x832       2704       1.186       0.072     16.5x
864x864       2916       1.106       0.081     13.7x
896x896       3136       1.101       0.088     12.5x
928x928       3364       1.103       0.091     12.1x
960x960       3600       1.210       0.099     12.2x
992x992       3844       1.213       0.106     11.5x
1024x1024      4096       1.248       0.110     11.4x

Accuracy:

Test with uniform distribution:

device  dtype       patches     max_abs    mean_abs
----------------------------------------------------
cuda    bfloat16          1    1.56e-02    6.68e-04
cuda    bfloat16         10    1.56e-02    6.29e-04
cuda    bfloat16         50    1.56e-02    6.52e-04
cuda    bfloat16        100    1.56e-02    6.48e-04
cuda    bfloat16        500    1.56e-02    6.48e-04
cuda    bfloat16       1000    1.56e-02    6.48e-04
cuda    bfloat16       2000    1.56e-02    6.48e-04
cuda    bfloat16       4000    1.56e-02    6.48e-04
cuda    bfloat16       6000    1.56e-02    6.48e-04
cuda    bfloat16       8000    1.56e-02    6.52e-04
cuda    bfloat16      12000    1.56e-02    6.48e-04
cuda    bfloat16      16000    1.56e-02    6.52e-04
cuda    bfloat16      20000    1.56e-02    6.48e-04
cuda    bfloat16      24168    1.56e-02    6.52e-04
cuda    float16           1    1.95e-03    7.70e-05
cuda    float16          10    1.95e-03    7.74e-05
cuda    float16          50    1.95e-03    7.99e-05
cuda    float16         100    1.95e-03    7.95e-05
cuda    float16         500    1.95e-03    7.96e-05
cuda    float16        1000    1.95e-03    7.98e-05
cuda    float16        2000    1.95e-03    7.97e-05
cuda    float16        4000    1.95e-03    7.97e-05
cuda    float16        6000    1.95e-03    7.96e-05
cuda    float16        8000    1.95e-03    7.95e-05
cuda    float16       12000    1.95e-03    7.96e-05
cuda    float16       16000    1.95e-03    7.96e-05
cuda    float16       20000    1.95e-03    7.96e-05
cuda    float16       24168    1.95e-03    7.96e-05

Test with 24168 real image patches from the image processor:

dtype          max_abs    mean_abs    match%
--------------------------------------------
float32       7.73e-04    8.28e-05    100.0%
bfloat16      1.56e-02    4.46e-04    100.0%
float16       1.95e-03    5.48e-05    100.0%

@gemini-code-assist (bot)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of the VisionPatchEmbed component by intelligently replacing Conv3d layers with nn.Linear layers where their mathematical operations are equivalent. This change targets a substantial speedup in patch embedding without altering the model's output, particularly beneficial for environments with older cuDNN versions.

Highlights

  • Performance Optimization: Replaced Conv3d with nn.Linear in the VisionPatchEmbed component, yielding a 10-15x speedup for patch embedding, especially beneficial on older cuDNN versions.
  • Conditional Application: The optimization is applied only when the Conv3d layer's kernel_size equals its stride, and padding, dilation, and groups are set to their default non-sliding values, ensuring functional equivalence.
  • Weight Migration: Existing Conv3d weights are loaded for checkpoint compatibility and then safely transferred to the newly created nn.Linear layer.
Changelog
  • sglang_omni/models/qwen3_omni/components/image_encoder.py
    • Added logging and types imports to support the new optimization logic.
    • Implemented _patch_embed_forward and _optimize_patch_embed functions to conditionally replace Conv3d with nn.Linear for performance.
    • Modified _build_visual to apply the _optimize_patch_embed function after loading the visual module.
Activity
  • The pull request was created with assistance from Cursor.
  • No other specific review comments or activities have been recorded yet.


@gemini-code-assist (bot) left a comment


Code Review

The pull request successfully implements a performance optimization by replacing Conv3d with nn.Linear in VisionPatchEmbed under specific conditions, which is a clever way to leverage the equivalence of these operations for speedup. The safety checks to ensure the optimization only applies when kernel_size == stride, padding == 0, dilation == 1, and groups == 1 are well-implemented. However, there is a critical correctness issue in the _patch_embed_forward function where the input tensor is not reshaped to match the expected input dimensions of the nn.Linear layer, which will lead to a runtime error.


def _patch_embed_forward(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""Optimized PatchEmbed forward using Linear instead of Conv3d."""
return self.linear(hidden_states.to(dtype=self.linear.weight.dtype))

critical

The _patch_embed_forward function is intended to replace the original forward method of a PatchEmbed module, which likely received a 5D tensor (batch, channels, depth, height, width) for the Conv3d layer. The new self.linear layer, however, expects a 2D tensor (batch, flattened_features).

The current implementation does not reshape hidden_states before passing it to self.linear. This will cause a runtime error because the nn.Linear layer will receive an incorrectly shaped input tensor.

To fix this, hidden_states needs to be reshaped from its original 5D format to a 2D format (batch_size, in_features) where in_features is the product of the original channels, depth, height, and width, as calculated in _optimize_patch_embed.

        return self.linear(hidden_states.flatten(1).to(dtype=self.linear.weight.dtype))

@zhou9402 (Collaborator, Author) replied:
The caller passes hidden_states as a 2D tensor [N, 1536] (already flattened per patch). The original HF forward reshapes it to 5D [N, 3, 2, 16, 16] internally just to satisfy Conv3d, then reshapes the output back to 2D [N, 1152].
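To make the shape argument concrete, here is a minimal sketch (dimensions taken from this reply; the module construction is illustrative): the 2D [N, 1536] input feeds nn.Linear directly, while the Conv3d path needed a 5D round trip, so no extra flatten is required in the optimized forward.

```python
import torch
from torch import nn

N, D = 8, 1152
conv = nn.Conv3d(3, D, kernel_size=(2, 16, 16), stride=(2, 16, 16))
hidden_states = torch.randn(N, 1536)  # 1536 = 3 * 2 * 16 * 16, flattened per patch

# HF-style path: reshape 2D -> 5D just to satisfy Conv3d,
# then squeeze the (1, 1, 1) output grid back to 2D.
out_conv = conv(hidden_states.view(N, 3, 2, 16, 16)).view(N, D)

# Optimized path: the same 2D tensor goes straight into nn.Linear.
linear = nn.Linear(1536, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(D, -1))
    linear.bias.copy_(conv.bias)
out_linear = linear(hidden_states)

print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True
```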

zhou9402 changed the title from "[Perf] replace Conv3d with Linear in VisionPatchEmbed for 10-15x speedup" to "[Perf] replace Conv3d with Linear in VisionPatchEmbed speedup" (Mar 6, 2026)
zhou9402 changed the title to "[Perf] replace Conv3d with Linear in VisionPatchEmbe" (Mar 6, 2026)
zhou9402 changed the title to "[Perf] replace Conv3d with Linear in VisionPatchEmbed" (Mar 6, 2026)
zhou9402 requested review from FrankLeeeee and shuaills and removed the request for FrankLeeeee and shuaills (March 6, 2026 10:30)
@yuan-luo (Collaborator) commented Mar 6, 2026

Could you provide accuracy test result?

@zhou9402 (Collaborator, Author) commented Mar 6, 2026

> Could you provide accuracy test result?

Added.

FrankLeeeee merged commit b82d434 into main (Mar 6, 2026)
5 checks passed
zhaochenyang20 deleted the optimze/conv3d branch (March 25, 2026)