Fix missing dp_max_padding argument in set_dp_buffer_len #12812
Conversation
Summary of Changes

Hello @Chen-0210, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request resolves a crash stemming from a recent update that introduced a new required argument, `dp_max_padding`, to `set_dp_buffer_len`.
Code Review
This pull request correctly fixes a crash by passing the missing dp_max_padding argument to set_dp_buffer_len. However, I've identified a related latent bug where global_dp_buffer_len is None, which would cause a TypeError when data parallelism is enabled. My suggested change addresses both the original issue and this latent bug by correctly calculating global_dp_buffer_len and also passing the global_num_tokens argument, which is necessary for data parallel attention.
```python
def run_once():
    # Clean intermediate result cache for DP attention
    forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
    set_dp_buffer_len(global_dp_buffer_len, num_tokens)
```
While adding the dp_max_padding argument fixes the immediate TypeError, there's a latent bug here. The global_dp_buffer_len variable is initialized to None at line 331. When data parallelism is used (dp_size > 1), this None value will be passed to set_dp_buffer_len and eventually cause a TypeError inside get_global_dp_buffer when it's used to allocate a tensor.
To make this robust, we should calculate the correct global_dp_buffer_len here, which should be num_tokens * self.dp_size. Additionally, the global_num_tokens argument should be passed for data parallel attention to work correctly during graph capture.
```diff
-    set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+    set_dp_buffer_len(
+        num_tokens * self.dp_size,
+        num_tokens,
+        forward_batch.dp_padding_mode.is_max_len(),
+        [num_tokens] * self.dp_size,
+    )
```
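The latent bug described in this review can be reproduced in a minimal, self-contained sketch. All names and signatures below are simplified assumptions for illustration, not the project's actual implementation: a module-level buffer length is stored by `set_dp_buffer_len` and later used by `get_global_dp_buffer` to size an allocation, which raises `TypeError` if the stored global length is still `None`.

```python
# Hypothetical, simplified stand-ins for the real functions.
_global_dp_buffer_len = None
_local_dp_buffer_len = None

def set_dp_buffer_len(global_len, local_len, dp_max_padding=False,
                      global_num_tokens=None):
    """Store the DP buffer lengths at module level (simplified sketch)."""
    global _global_dp_buffer_len, _local_dp_buffer_len
    _global_dp_buffer_len = global_len
    _local_dp_buffer_len = local_len

def get_global_dp_buffer(hidden_size=4):
    # Stand-in for the tensor allocation: multiplying by None raises TypeError.
    return [0.0] * (_global_dp_buffer_len * hidden_size)

# Latent bug: the global length is left as None, so allocation crashes.
set_dp_buffer_len(None, 8)
try:
    get_global_dp_buffer()
    crashed = False
except TypeError:
    crashed = True

# Suggested fix: compute the global length as num_tokens * dp_size.
num_tokens, dp_size = 8, 2
set_dp_buffer_len(num_tokens * dp_size, num_tokens)
buffer = get_global_dp_buffer()
```

With the `None` length the allocation raises `TypeError`; with the computed global length it succeeds, which is exactly the distinction the suggested change addresses.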
Motivation
PR #12572 added a new argument `dp_max_padding` to `set_dp_buffer_len`, which caused #12796.

Fixes #12796.
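The failure mode can be sketched in a few lines: adding a new required positional parameter to a function makes every call site that was not updated raise `TypeError` until the argument is passed through. The signature below is a simplified assumption that only mirrors the names from the PR:

```python
# Hypothetical, simplified signature after PR #12572 added dp_max_padding.
def set_dp_buffer_len(global_dp_buffer_len, num_tokens, dp_max_padding):
    return (global_dp_buffer_len, num_tokens, dp_max_padding)

# A pre-#12572 call site still passes only two arguments and crashes.
try:
    set_dp_buffer_len(128, 16)
    crashed = False
except TypeError:
    crashed = True

# The fixed call site forwards the new argument explicitly.
result = set_dp_buffer_len(128, 16, False)
```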
Modifications
Pass the missing `dp_max_padding` argument to `set_dp_buffer_len`.

Accuracy Tests
Benchmarking and Profiling
Checklist