[BugFix] Potential bug fix for test_async_tp_pass_correctness #33854

LucasWilkinson wants to merge 1 commit into vllm-project:main

Conversation
Code Review
This pull request applies a bug fix for CUDA graph correctness in FlashAttention by adjusting the range of scheduler metadata that is zeroed out, ensuring a semaphore is correctly reset. My review identifies a potential edge case where `n` could be 0, which might lead to incorrect behavior. I've suggested a more robust implementation to handle this case gracefully in both modified files.
```python
# forward pass because when num_splits == 1, FA3's internal
# semaphore reset uses PyTorch zero_() which isn't captured in
# CUDA graphs.
self.scheduler_metadata[n - 1 :] = 0
```
This change seems correct, but it might introduce a bug if `n` can be 0. If `n = 0`, then `n - 1` becomes `-1`, and `self.scheduler_metadata[-1:] = 0` will zero out only the last element of the buffer. The previous behavior for `n = 0` was to zero out the entire buffer (`self.scheduler_metadata[0:] = 0`), which seems safer for resetting state when there are no requests.

While `n` is likely always >= 1 (since the `scheduler_metadata` size is `batch_size * 4 + 1`), it's safer to handle the `n = 0` case explicitly to prevent potential issues. A more robust implementation would be:
```diff
- self.scheduler_metadata[n - 1 :] = 0
+ self.scheduler_metadata[max(0, n - 1):] = 0
```
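The slicing pitfall described above is easy to see with a plain Python list, since negative slice starts behave the same way as on a PyTorch tensor. The sketch below is illustrative only; `buf` is a hypothetical stand-in for `self.scheduler_metadata`, and the helper names are invented for this example.

```python
# Demonstrates why `buf[n - 1:] = 0` misbehaves when n == 0.

def zero_tail_unsafe(buf, n):
    # When n == 0, n - 1 == -1, so the slice covers only the
    # LAST element instead of the whole buffer.
    buf[n - 1:] = [0] * len(buf[n - 1:])
    return buf

def zero_tail_safe(buf, n):
    # Clamping the start index to 0 makes n == 0 zero the
    # entire buffer, matching the old `buf[0:] = 0` behavior.
    start = max(0, n - 1)
    buf[start:] = [0] * (len(buf) - start)
    return buf

print(zero_tail_unsafe([1, 2, 3, 4], 0))  # [1, 2, 3, 0] -- bug
print(zero_tail_safe([1, 2, 3, 4], 0))    # [0, 0, 0, 0] -- full reset
print(zero_tail_safe([1, 2, 3, 4], 3))    # [1, 2, 0, 0] -- normal case
```

For `n >= 1` both variants behave identically, so the clamp only changes the degenerate empty-batch case.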
see: #34043
Potential fix for: #33802