[BugFix] Potential bug fix for test_async_tp_pass_correctness
#33854
| Diff line change |
|---|
| `@@ -224,11 +224,13 @@ def _build_decode(` |
| `  f"{self.scheduler_metadata.shape[0]}"` |
| `  )` |
| `  self.scheduler_metadata[:n] = scheduler_metadata` |
| `- # NOTE(woosuk): We should zero out the rest of the scheduler` |
| `- # metadata to guarantee the correctness. Otherwise, some thread` |
| `- # blocks may use the invalid scheduler metadata and overwrite the` |
| `- # output buffer.` |
| `- self.scheduler_metadata[n:] = 0` |
| `+ # NOTE(woosuk, lucas): Zero from n-1 onwards. Positions >= n must be` |
| `+ # zeroed to prevent invalid metadata from being used. The` |
| `+ # semaphore at position n-1 must also be zeroed before each` |
| `+ # forward pass because when num_splits == 1, FA3's internal` |
| `+ # semaphore reset uses PyTorch zero_() which isn't captured in` |
| `+ # CUDA graphs.` |
| `+ self.scheduler_metadata[n - 1 :] = 0` |
| `  scheduler_metadata = self.scheduler_metadata[:n]` |
| |
| `  metadata = FlashAttnMLADecodeMetadata(` |

Contributor

This change seems correct, but it might introduce a bug if `n` can be 0. If `n=0`, `n-1` becomes `-1`, and `self.scheduler_metadata[-1:] = 0` will only zero out the last element of the buffer. The previous behavior for `n=0` was to zero out the entire buffer (`self.scheduler_metadata[0:] = 0`), which seems safer for resetting state when there are no requests.

While `n` is likely always `>= 1` (since `scheduler_metadata` size is `batch_size * 4 + 1`), it's safer to handle the `n=0` case explicitly to prevent potential issues. A more robust implementation would be:
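The reviewer's actual suggestion is not preserved in this capture. As an illustration only, the sketch below reproduces the slicing behavior described in the comment using a plain Python list instead of a tensor. The helper names `zero_tail_unguarded` and `zero_tail` and the `max(n - 1, 0)` clamp are hypothetical; they are not taken from the PR and may differ from what was proposed.

```python
def zero_tail_unguarded(buf: list[int], n: int) -> None:
    """Mimic `buf[n - 1:] = 0` on a tensor, including negative-index wraparound."""
    # Slicing a range follows the same start-index rules as slicing a tensor,
    # so n == 0 yields the slice [-1:], which selects only the last position.
    for i in range(len(buf))[n - 1:]:
        buf[i] = 0


def zero_tail(buf: list[int], n: int) -> None:
    """Zero positions n-1 onwards, clamping the start at 0 (hypothetical guard)."""
    start = max(n - 1, 0)  # assumption: n == 0 should clear the whole buffer
    for i in range(start, len(buf)):
        buf[i] = 0


stale = [7, 7, 7, 7, 7]  # stand-in for leftover scheduler metadata

unguarded = stale.copy()
zero_tail_unguarded(unguarded, 0)
print(unguarded)  # [7, 7, 7, 7, 0] -> stale entries survive

guarded = stale.copy()
zero_tail(guarded, 0)
print(guarded)  # [0, 0, 0, 0, 0]

typical = stale.copy()
zero_tail(typical, 3)
print(typical)  # [7, 7, 0, 0, 0] -> same as the unguarded form when n >= 1
```

For `n >= 1` the two helpers behave identically, so a clamp of this kind would only change the `n == 0` case that the comment is concerned about.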