-
-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[Fix] Share decode output buffer across MLA layers to reduce memory #37805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8e86f6e
b1df6ec
1d2a31f
230f55e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -103,6 +103,45 @@ def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": | |
| ) | ||
|
|
||
|
|
||
| class FlashInferDecodeOutBuffer: | ||
| """Shared decode output buffer for all FlashInferMLAImpl instances. | ||
|
|
||
| The NaN fix (PR #37442) allocated a persistent _decode_out buffer per | ||
| attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of | ||
| GPU memory that is not accounted for during profiling, causing OOM. | ||
|
|
||
| This shared buffer reduces memory from ~15 GiB to ~256 MiB since all | ||
| layers write sequentially (no concurrent access in forward pass). | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| self._buffer: torch.Tensor | None = None | ||
|
|
||
| def get( | ||
| self, | ||
| batch_size: int, | ||
| num_heads: int, | ||
| kv_lora_rank: int, | ||
| dtype: torch.dtype, | ||
| device: torch.device, | ||
| ): | ||
| if ( | ||
| self._buffer is None | ||
| or self._buffer.shape[0] < batch_size | ||
| or self._buffer.shape[1] < num_heads | ||
| or self._buffer.shape[2] < kv_lora_rank | ||
| or self._buffer.dtype != dtype | ||
| or self._buffer.device != device | ||
| ): | ||
| self._buffer = torch.zeros( | ||
| (batch_size, num_heads, kv_lora_rank), dtype=dtype, device=device | ||
| ) | ||
| return self._buffer[:batch_size, :num_heads, :kv_lora_rank] | ||
|
Comment on lines
+106
to
+139
Contributor
class FlashInferDecodeOutBuffer:
    """Shared decode output buffer for all FlashInferMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        # Backing storage, created lazily on the first get() call.
        self._buffer: torch.Tensor | None = None
        # Per-dimension high-water marks: for a fixed dtype/device the
        # buffer only ever grows, so repeated small requests reuse it.
        self._max_batch_size = 0
        self._max_num_heads = 0
        self._max_kv_lora_rank = 0

    def get(
        self,
        batch_size: int,
        num_heads: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Return a (batch_size, num_heads, kv_lora_rank) view of the buffer.

        The backing tensor is zero-initialized when (re)allocated (keeps
        padding slots NaN-free per the PR #37442 fix) and is reallocated
        only when dtype/device changes or a requested dimension exceeds
        every size seen so far.
        """
        # A dtype or device mismatch invalidates the backing tensor
        # entirely, so the previous high-water marks no longer apply.
        stale = (
            self._buffer is None
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        )
        if stale:
            self._max_batch_size = 0
            self._max_num_heads = 0
            self._max_kv_lora_rank = 0

        current = (
            self._max_batch_size,
            self._max_num_heads,
            self._max_kv_lora_rank,
        )
        wanted = (
            max(self._max_batch_size, batch_size),
            max(self._max_num_heads, num_heads),
            max(self._max_kv_lora_rank, kv_lora_rank),
        )
        if stale or wanted != current:
            (
                self._max_batch_size,
                self._max_num_heads,
                self._max_kv_lora_rank,
            ) = wanted
            self._buffer = torch.zeros(wanted, dtype=dtype, device=device)
        return self._buffer[:batch_size, :num_heads, :kv_lora_rank]
||
|
|
||
|
|
||
| g_fi_decode_out = FlashInferDecodeOutBuffer() | ||
|
|
||
|
|
||
| class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): | ||
| def __init__( | ||
| self, | ||
|
|
@@ -152,10 +191,9 @@ def __init__( | |
| self.bmm1_scale: float | None = None | ||
| self.bmm2_scale: float | None = None | ||
|
|
||
| # Pre-allocated output buffer, lazily sized on first call. | ||
| # Zero-init once to prevent NaN in padding slots (seq_lens=0) | ||
| # from contaminating downstream per-tensor reductions. | ||
| self._decode_out: torch.Tensor | None = None | ||
| # Use shared decode output buffer across all layers to reduce memory. | ||
| # Each layer writes sequentially, so sharing is safe. | ||
| self._decode_out_buffer = g_fi_decode_out | ||
|
|
||
| def forward_mqa( | ||
| self, | ||
|
|
@@ -209,19 +247,10 @@ def forward_mqa( | |
| if is_quantized_kv_cache(self.kv_cache_dtype) | ||
| else q.dtype | ||
| ) | ||
| if ( | ||
| self._decode_out is None | ||
| or self._decode_out.shape[0] < B | ||
| or self._decode_out.dtype != dtype | ||
| ): | ||
| self._decode_out = torch.zeros( | ||
| B, | ||
| q.shape[2], | ||
| self.kv_lora_rank, | ||
| dtype=dtype, | ||
| device=q.device, | ||
| ) | ||
| out_kwargs["out"] = self._decode_out[:B] | ||
| # Use shared buffer to reduce memory footprint. | ||
| out_kwargs["out"] = self._decode_out_buffer.get( | ||
| B, q.shape[2], self.kv_lora_rank, dtype, q.device | ||
| ) | ||
|
|
||
| o = trtllm_batch_decode_with_kv_cache_mla( | ||
| query=q, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation of `FlashInferDecodeOutBuffer.get` can be inefficient. When the buffer needs to be reallocated due to a change in `dtype` or `device`, it uses the current `batch_size`, which might be small. A subsequent call with a larger `batch_size` would trigger another reallocation. To avoid these potentially frequent reallocations, the buffer should always grow to the maximum `batch_size` seen so far for a given `dtype` and `device`.