Parallel patch embed for faster execution #17

Merged 4 commits on Jun 15, 2023

rejuvyesh
Collaborator

Grouped convolutions let us compute the patch embeddings of all variables in parallel with a single GPU kernel call, instead of looping over the variables serially.
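A minimal sketch of the idea, not the code in this PR: it assumes each variable is a single input channel with its own `Conv2d` patch embedding, and checks that one grouped convolution reproduces the per-variable loop. The sizes (`num_vars`, `embed_dim`, `patch`) and all variable names are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_vars, embed_dim, patch = 3, 8, 4   # hypothetical sizes
x = torch.randn(2, num_vars, 32, 64)   # (batch, variables, height, width)

# Serial baseline: one single-channel patch-embedding conv per variable.
convs = nn.ModuleList(
    [nn.Conv2d(1, embed_dim, kernel_size=patch, stride=patch) for _ in range(num_vars)]
)
serial = torch.stack(
    [convs[i](x[:, i : i + 1]) for i in range(num_vars)], dim=1
)  # (batch, variables, embed_dim, height / patch, width / patch)

# Parallel version: concatenate the per-variable filters along the output-channel
# axis; groups=num_vars makes each block of embed_dim filters see only its own
# input channel, so a single conv2d call covers all variables.
weight = torch.cat([c.weight for c in convs], dim=0)  # (num_vars * embed_dim, 1, patch, patch)
bias = torch.cat([c.bias for c in convs], dim=0)      # (num_vars * embed_dim,)
out = F.conv2d(x, weight, bias, stride=patch, groups=num_vars)
parallel = out.reshape(x.shape[0], num_vars, embed_dim, *out.shape[-2:])

# Largest elementwise difference between the two paths (expected to be tiny).
max_abs_diff = (serial - parallel).abs().max().item()
```

Because the kernel size equals the stride, every patch is projected independently, so the grouped call should differ from the loop only by floating-point rounding.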

TODO:

  • Correctly convert checkpoints trained with the non-parallel patch embed
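One way the conversion could look, as a rough sketch only: it assumes the old checkpoint stores one conv per variable under keys like `token_embeds.{i}.proj.weight`, and that the parallel module wants them stacked along a new leading variable axis under `token_embeds.proj_weights` / `token_embeds.proj_biases`. Those key names and the helper `convert_patch_embed_weights` are hypothetical and would need to be matched against the real modules.

```python
import re
import torch


def convert_patch_embed_weights(state_dict, prefix="token_embeds"):
    """Stack per-variable conv weights/biases into single tensors.

    Key names are hypothetical: `{prefix}.{i}.proj.weight` / `.bias` in the old
    checkpoint, `{prefix}.proj_weights` / `.proj_biases` in the converted one.
    """
    pattern = re.compile(rf"^{re.escape(prefix)}\.(\d+)\.proj\.(weight|bias)$")
    per_var = {"weight": {}, "bias": {}}
    converted = {}
    for key, tensor in state_dict.items():
        match = pattern.match(key)
        if match is None:
            converted[key] = tensor  # unrelated parameters pass through unchanged
        else:
            per_var[match.group(2)][int(match.group(1))] = tensor
    for kind, new_name in (("weight", "proj_weights"), ("bias", "proj_biases")):
        if per_var[kind]:
            # Stack in variable-index order so row i belongs to variable i.
            ordered = [per_var[kind][i] for i in sorted(per_var[kind])]
            converted[f"{prefix}.{new_name}"] = torch.stack(ordered, dim=0)
    return converted


# Toy "old" checkpoint with three variables and one unrelated parameter.
old = {"head.weight": torch.randn(5, 8)}
for i in range(3):
    old[f"token_embeds.{i}.proj.weight"] = torch.randn(8, 1, 4, 4)
    old[f"token_embeds.{i}.proj.bias"] = torch.randn(8)

new = convert_patch_embed_weights(old)
```

Stacking along a leading axis keeps the per-variable ordering explicit; the stacked tensor can be flattened over its first two axes to get the weight layout a grouped `conv2d` expects.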

@rejuvyesh rejuvyesh assigned tung-nd and unassigned tung-nd Apr 24, 2023
@rejuvyesh rejuvyesh requested a review from tung-nd April 24, 2023 02:42

rejuvyesh commented Apr 24, 2023

Some rough timing:

parallel_patch_embed: True, img_size: [32, 64]
forward time: 0.009
---
parallel_patch_embed: False, img_size: [32, 64]
forward time: 0.021
---
parallel_patch_embed: True, img_size: [128, 256]
forward time: 0.097
---
parallel_patch_embed: False, img_size: [128, 256]
forward time: 0.109
---
parallel_patch_embed: True, img_size: [32, 64]
backward time: 0.034
---
parallel_patch_embed: False, img_size: [32, 64]
backward time: 0.052
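
The comment does not say how these numbers were measured. For anyone wanting comparable figures, a timing helper could look roughly like the sketch below; `mean_time`, the warm-up and iteration counts, and the toy grouped conv are all assumptions, not the script used for the numbers above.

```python
import time
import torch


def mean_time(fn, n_warmup=3, n_iters=10):
    """Average wall-clock seconds per call of fn, synchronizing CUDA if present."""

    def sync():
        # CUDA kernels launch asynchronously, so wait for them before reading the clock.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(n_warmup):  # warm-up calls are excluded from the measurement
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    sync()
    return (time.perf_counter() - start) / n_iters


device = "cuda" if torch.cuda.is_available() else "cpu"
# Toy stand-in for a parallel patch embed: 3 variables, 8 output channels each.
conv = torch.nn.Conv2d(3, 24, kernel_size=4, stride=4, groups=3).to(device)
x = torch.randn(2, 3, 32, 64, device=device)

with torch.no_grad():
    forward_time = mean_time(lambda: conv(x))
print(f"forward time: {forward_time:.3f}")
```

Without the synchronize calls a GPU measurement would mostly capture kernel launch overhead rather than execution time.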

@brandstetter-johannes brandstetter-johannes merged commit efd6de4 into main Jun 15, 2023
@rejuvyesh rejuvyesh deleted the jkg/fast branch June 28, 2023 15:47