Merged
@@ -382,7 +382,11 @@ def capture_one_batch_size(self, num_tokens: int):
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
-            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
Contributor comment (critical):
While adding the `dp_max_padding` argument fixes the immediate `TypeError`, there's a latent bug here. The `global_dp_buffer_len` variable is initialized to `None` at line 331. When data parallelism is used (`dp_size > 1`), this `None` value will be passed to `set_dp_buffer_len` and eventually cause a `TypeError` inside `get_global_dp_buffer` when it is used to allocate a tensor.

To make this robust, we should calculate the correct `global_dp_buffer_len` here, which should be `num_tokens * self.dp_size`. Additionally, the `global_num_tokens` argument should be passed for data-parallel attention to work correctly during graph capture.

Suggested change:

-            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            set_dp_buffer_len(
+                num_tokens * self.dp_size,
+                num_tokens,
+                forward_batch.dp_padding_mode.is_max_len(),
+                [num_tokens] * self.dp_size,
+            )

+            set_dp_buffer_len(
+                global_dp_buffer_len,
+                num_tokens,
+                forward_batch.dp_padding_mode.is_max_len(),
+            )
             # FIXME: the implementation is hacky. `is_extend_in_batch` is for determining the deepep mode.
             # It is True in this context but we need to set it to use low latency deepep mode.
             set_is_extend_in_batch(False)