
Commit 4f33d84

[https://nvbugs/5451280][fix] Reduce memory fragmentation by warming up with a large request
Signed-off-by: Jin Li <[email protected]>
1 parent e30d9ac commit 4f33d84

1 file changed: +66 -21 lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 66 additions & 21 deletions
@@ -649,7 +649,9 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 result = None
             return result
 
-        def get_warmup_request(num_tokens: int, num_gen_tokens: int):
+        def get_warmup_request(num_tokens: int,
+                               num_gen_tokens: int,
+                               least_requests: bool = True):
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.runtime_draft_len)
             available_blocks = kv_cache_manager.get_num_free_blocks()
@@ -673,13 +675,23 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
             num_left_over_tokens = 0
 
             if num_ctx_tokens > 0:
-                # We will try to assign as less context requests as possible to
-                # fill the num_ctx_tokens.
+                if least_requests:
+                    # We will try to assign as few context requests as possible
+                    # to fill the num_ctx_tokens.
 
-                # Num full sequences:
-                num_full_seqs = num_ctx_tokens // max_seq_len
-                num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
+                    # Num full sequences:
+                    num_full_seqs = num_ctx_tokens // max_seq_len
+                    num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
 
+                else:
+                    max_bs = min(num_ctx_tokens,
+                                 self.batch_size - num_gen_tokens)
+                    if num_ctx_tokens % max_bs == 0:
+                        num_full_seqs = max_bs
+                    else:
+                        num_full_seqs = max_bs - 1
+                    max_seq_len = num_ctx_tokens // num_full_seqs
+                    num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
             num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens
                                                 > 0 else 0)
 
@@ -754,6 +766,32 @@ def release_batch(result: ScheduledRequests | None):
         if cp_type == CpType.STAR:
             return
 
+        def general_warmup(reverse: bool = False):
+            warmup_requests = set([
+                (1, 1),  # Specialize for 1 token.
+                (self.batch_size,
+                 self.batch_size),  # max_batch_size, pure generation
+                (2, 0),  # Non-one, pure context
+                (curr_max_num_tokens, 0),  # max_num_tokens, pure context
+            ])
+            if reverse:
+                warmup_requests = sorted(list(warmup_requests), reverse=reverse)
+
+            for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
+                with release_batch(
+                        get_warmup_request(warmup_num_tokens,
+                                           warmup_num_gen_tokens)) as batch:
+                    if batch is None:
+                        # No KV cache space!
+                        continue
+                    logger.info(
+                        f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
+                    )
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
+                    torch.cuda.synchronize()
+
         if self._torch_compile_enabled:
 
             warmup_requests = set([
@@ -766,21 +804,7 @@ def release_batch(result: ScheduledRequests | None):
 
             # Disable cuda graph capture here so that we can properly capture it later
             with self.no_cuda_graph():
-                for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
-
-                    with release_batch(
-                            get_warmup_request(warmup_num_tokens,
-                                               warmup_num_gen_tokens)) as batch:
-                        if batch is None:
-                            # No KV cache space!
-                            continue
-                        logger.info(
-                            f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
-                        )
-                        self.forward(batch,
-                                     new_tensors_device=None,
-                                     resource_manager=resource_manager)
-                        torch.cuda.synchronize()
+                general_warmup()
 
         if self.pytorch_backend_config.enable_autotuner:
             with self.no_cuda_graph(), autotune():
@@ -867,6 +891,27 @@ def release_batch(result: ScheduledRequests | None):
             gc.collect()
             torch.cuda.empty_cache()
 
+            # When using piecewise CUDA graphs, the logits may suffer from severe memory fragmentation.
+            # As the number of requests grows, the blocks allocated by torch cannot be reused.
+            # So after piecewise CUDA graph capture, a warmup request with the most requests is
+            # triggered to make sure that a large enough block is allocated and can be correctly reused.
+            for num_tokens in piecewise_cuda_graph_num_tokens:
+                batch = get_warmup_request(num_tokens, 0, False)
+                if batch is None:
+                    continue
+                with release_batch(batch) as batch:
+                    logger.info(
+                        f"Run piecewise CUDA graph warmup for num tokens={num_tokens} with most requests"
+                    )
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
+
+            torch.cuda.synchronize()
+
+            # Also, run a general warmup from large to small to make sure that blocks are allocated well.
+            general_warmup(reverse=True)
+
         # Set the value back to the original value
         self.enable_spec_decode = self.is_spec_decode
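
For context, a minimal standalone sketch of the splitting arithmetic in the new least_requests=False branch: spread num_ctx_tokens over as many context requests as the batch allows, so the warmup batch carries the largest possible request count. The function name and the bare integer arguments here are illustrative stand-ins, not part of model_engine.py.

def split_ctx_tokens_most_requests(num_ctx_tokens: int, batch_size: int,
                                   num_gen_tokens: int):
    """Return (num_full_seqs, seq_len, num_left_over_tokens)."""
    # Use as many context requests as fit next to the generation requests.
    max_bs = min(num_ctx_tokens, batch_size - num_gen_tokens)
    if num_ctx_tokens % max_bs == 0:
        num_full_seqs = max_bs        # tokens divide evenly across max_bs requests
    else:
        num_full_seqs = max_bs - 1    # keep one request free for the leftover tokens
    seq_len = num_ctx_tokens // num_full_seqs
    num_left_over_tokens = num_ctx_tokens - seq_len * num_full_seqs
    return num_full_seqs, seq_len, num_left_over_tokens

# 10 context tokens, batch size 4, no generation slots:
# 3 full sequences of 3 tokens plus 1 leftover token -> 4 context requests.
print(split_ctx_tokens_most_requests(10, 4, 0))  # (3, 3, 1)

The least_requests=True path keeps the old behaviour (as few, as long as possible context requests); the False path is what the piecewise warmup hunk above uses to provoke the worst-case request count.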

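The large-then-small ordering (the extra piecewise warmup pass and general_warmup(reverse=True)) relies on how PyTorch's caching allocator reuses freed blocks: a block cached from a large allocation can be split to serve later, smaller requests, while blocks cached from small allocations cannot serve a later, larger one. A toy illustration of that effect (my own example, not from the commit; it needs a CUDA device, and the exact numbers depend on allocator settings such as max_split_size_mb):

import torch

def reserved_after(sizes_mb):
    """Allocate and free tensors of the given sizes (MiB) in order and report
    how much memory the CUDA caching allocator ends up holding."""
    torch.cuda.empty_cache()
    for mb in sizes_mb:
        t = torch.empty(mb * 1024 * 1024, dtype=torch.uint8, device="cuda")
        del t  # returned to the caching allocator, not to the driver
    torch.cuda.synchronize()
    return torch.cuda.memory_reserved() // 2**20

if __name__ == "__main__":
    # Growing sizes: each allocation needs a fresh segment, so roughly
    # 64 + 128 + 256 + 512 MiB stays reserved.
    print("small -> large:", reserved_after([64, 128, 256, 512]), "MiB")
    # Largest first: later, smaller allocations are carved out of the cached
    # 512 MiB block, so only about 512 MiB stays reserved.
    print("large -> small:", reserved_after([512, 256, 128, 64]), "MiB")

Warming up with the largest request first plays the same role for the logits buffers: the biggest block is allocated once up front and every smaller shape afterwards can reuse it.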