[Perf] replace Conv3d with Linear in VisionPatchEmbed #132
Conversation
Conv3d with kernel_size == stride does not slide over the input, making it equivalent to a reshape + linear projection. Replace it with nn.Linear after loading the HF checkpoint weights for significantly faster patch embedding, especially on older cuDNN versions where Conv3d with large non-sliding kernels is extremely slow.

Safety checks ensure the optimization only applies when kernel_size == stride, padding == 0, dilation == 1, and groups == 1.

Reference: sgl-project/sglang#19788

Made-with: Cursor
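The equivalence claimed here is easy to verify numerically. The sketch below uses illustrative shapes (3 channels, temporal patch 2, spatial patch 16x16, the dimensions discussed later in this thread) and is not the PR's actual code: the Conv3d kernel is folded into a Linear weight, and both paths produce the same output on non-overlapping patches.

```python
import torch
import torch.nn as nn

# Illustrative shapes; a sketch of the equivalence, not the PR's actual code.
in_ch, t, h, w, out_dim = 3, 2, 16, 16, 1152
conv = nn.Conv3d(in_ch, out_dim, kernel_size=(t, h, w), stride=(t, h, w), bias=False)

# Fold the (out, in, t, h, w) conv kernel into a (out, in*t*h*w) Linear weight.
linear = nn.Linear(in_ch * t * h * w, out_dim, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(out_dim, -1))

x = torch.randn(4, in_ch, t, h, w)   # one non-overlapping patch per sample
y_conv = conv(x).view(4, out_dim)    # kernel covers the whole input -> (N, out_dim)
y_lin = linear(x.flatten(1))         # reshape + matmul, same numbers
assert torch.allclose(y_conv, y_lin, atol=1e-5)
```

Because the kernel covers the entire input and the stride equals the kernel size, each output element is a single dot product over the flattened patch, which is exactly what the Linear layer computes.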
Code Review
The pull request successfully implements a performance optimization by replacing Conv3d with nn.Linear in VisionPatchEmbed under specific conditions, which is a clever way to leverage the equivalence of these operations for speedup. The safety checks to ensure the optimization only applies when kernel_size == stride, padding == 0, dilation == 1, and groups == 1 are well-implemented. However, there is a critical correctness issue in the _patch_embed_forward function where the input tensor is not reshaped to match the expected input dimensions of the nn.Linear layer, which will lead to a runtime error.
```python
def _patch_embed_forward(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    """Optimized PatchEmbed forward using Linear instead of Conv3d."""
    return self.linear(hidden_states.to(dtype=self.linear.weight.dtype))
```
The _patch_embed_forward function is intended to replace the original forward method of a PatchEmbed module, which likely received a 5D tensor (batch, channels, depth, height, width) for the Conv3d layer. The new self.linear layer, however, expects a 2D tensor (batch, flattened_features).
The current implementation does not reshape hidden_states before passing it to self.linear. This will cause a runtime error because the nn.Linear layer will receive an incorrectly shaped input tensor.
To fix this, hidden_states needs to be reshaped from its original 5D format to a 2D format (batch_size, in_features) where in_features is the product of the original channels, depth, height, and width, as calculated in _optimize_patch_embed.
```python
return self.linear(hidden_states.flatten(1).to(dtype=self.linear.weight.dtype))
```
The caller passes hidden_states as a 2D tensor [N, 1536] (already flattened per patch). The original HF forward reshapes it to 5D [N, 3, 2, 16, 16] internally just to satisfy Conv3d, then reshapes the output back to 2D [N, 1152].
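The two data flows described in this reply can be sketched as follows, using the shapes quoted above ([N, 1536] in, [N, 1152] out, with 1536 = 3 * 2 * 16 * 16). The variable names are illustrative, not the actual HF code:

```python
import torch
import torch.nn as nn

# Shapes quoted in the comment above; names are illustrative, not the HF code.
N = 8
conv = nn.Conv3d(3, 1152, kernel_size=(2, 16, 16), stride=(2, 16, 16), bias=False)
x2d = torch.randn(N, 1536)  # caller already provides patches flattened to 2D

# Original HF-style path: 2D -> 5D just to satisfy Conv3d -> back to 2D
y_conv = conv(x2d.view(N, 3, 2, 16, 16)).view(N, 1152)

# Optimized path: fold the conv kernel into a Linear weight, skip both reshapes
linear = nn.Linear(1536, 1152, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(1152, -1))
y_lin = linear(x2d)
assert torch.allclose(y_conv, y_lin, atol=1e-5)
```

This is why no extra `flatten(1)` is needed in the optimized forward: the input is already 2D, so the suggested reshape would be a no-op at best.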
Could you provide accuracy test results?

Added.
Benchmark on H20, linear vs conv3d:

Accuracy:

- Test with uniform distribution:
- Test with 24168 real image patches from the image processor: