@@ -649,7 +649,9 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 result = None
             return result

-        def get_warmup_request(num_tokens: int, num_gen_tokens: int):
+        def get_warmup_request(num_tokens: int,
+                               num_gen_tokens: int,
+                               least_requests: bool = True):
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.runtime_draft_len)
             available_blocks = kv_cache_manager.get_num_free_blocks()
@@ -673,13 +675,23 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
             num_left_over_tokens = 0

             if num_ctx_tokens > 0:
-                # We will try to assign as less context requests as possible to
-                # fill the num_ctx_tokens.
+                if least_requests:
+                    # We will try to assign as few context requests as possible to
+                    # fill the num_ctx_tokens.

-                # Num full sequences:
-                num_full_seqs = num_ctx_tokens // max_seq_len
-                num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
+                    # Number of full sequences:
+                    num_full_seqs = num_ctx_tokens // max_seq_len
+                    num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len

+                else:
+                    max_bs = min(num_ctx_tokens,
+                                 self.batch_size - num_gen_tokens)
+                    if num_ctx_tokens % max_bs == 0:
+                        num_full_seqs = max_bs
+                    else:
+                        num_full_seqs = max_bs - 1
+                    max_seq_len = num_ctx_tokens // num_full_seqs
+                    num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
             num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens
                                                 > 0 else 0)

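As a worked example of the two branches in this hunk, the splitting logic can be lifted into a standalone sketch. The helper name `split_ctx_tokens` and its plain-integer signature are assumptions for illustration only; `request_budget` stands in for `self.batch_size - num_gen_tokens`.

    # Standalone sketch of the two context-splitting strategies above; names
    # and signature are illustrative, not part of the patch.
    def split_ctx_tokens(num_ctx_tokens: int,
                         max_seq_len: int,
                         request_budget: int,
                         least_requests: bool = True):
        if least_requests:
            # Pack the tokens into as few context requests as possible:
            # full-length sequences plus one shorter leftover request.
            num_full_seqs = num_ctx_tokens // max_seq_len
            left_over = num_ctx_tokens - num_full_seqs * max_seq_len
        else:
            # Spread the tokens over as many requests as the budget allows,
            # shrinking the per-request sequence length accordingly.
            max_bs = min(num_ctx_tokens, request_budget)
            num_full_seqs = max_bs if num_ctx_tokens % max_bs == 0 else max_bs - 1
            max_seq_len = num_ctx_tokens // num_full_seqs
            left_over = num_ctx_tokens - max_seq_len * num_full_seqs
        num_requests = num_full_seqs + (1 if left_over > 0 else 0)
        return num_requests, max_seq_len, left_over

    # 1000 context tokens, max_seq_len 256, room for 8 requests:
    print(split_ctx_tokens(1000, 256, 8, least_requests=True))   # (4, 256, 232)
    print(split_ctx_tokens(1000, 256, 8, least_requests=False))  # (8, 125, 0)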
@@ -754,6 +766,32 @@ def release_batch(result: ScheduledRequests | None):
         if cp_type == CpType.STAR:
             return

+        def general_warmup(reverse: bool = False):
+            warmup_requests = set([
+                (1, 1),  # Specialize for 1 token.
+                (self.batch_size,
+                 self.batch_size),  # max_batch_size, pure generation
+                (2, 0),  # Non-one, pure context
+                (curr_max_num_tokens, 0),  # max_num_tokens, pure context
+            ])
+            if reverse:
+                warmup_requests = sorted(list(warmup_requests), reverse=reverse)
+
+            for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
+                with release_batch(
+                        get_warmup_request(warmup_num_tokens,
+                                           warmup_num_gen_tokens)) as batch:
+                    if batch is None:
+                        # No KV cache space!
+                        continue
+                    logger.info(
+                        f"Run warmup with {warmup_num_tokens} tokens, including {warmup_num_gen_tokens} generation tokens"
+                    )
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
+                    torch.cuda.synchronize()
+
         if self._torch_compile_enabled:

             warmup_requests = set([
@@ -766,21 +804,7 @@ def release_batch(result: ScheduledRequests | None):

             # Disable cuda graph capture here so that we can properly capture it later
             with self.no_cuda_graph():
-                for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
-
-                    with release_batch(
-                            get_warmup_request(warmup_num_tokens,
-                                               warmup_num_gen_tokens)) as batch:
-                        if batch is None:
-                            # No KV cache space!
-                            continue
-                        logger.info(
-                            f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
-                        )
-                        self.forward(batch,
-                                     new_tensors_device=None,
-                                     resource_manager=resource_manager)
-                        torch.cuda.synchronize()
+                general_warmup()

         if self.pytorch_backend_config.enable_autotuner:
             with self.no_cuda_graph(), autotune():
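The `general_warmup` helper introduced above iterates a small set of (num_tokens, num_gen_tokens) shapes; with reverse=True the tuples are sorted in descending order so the largest shape runs first. A minimal sketch of that ordering, with illustrative values standing in for `self.batch_size` and `curr_max_num_tokens`:

    # Illustrative values; in the engine these come from self.batch_size and
    # the current max_num_tokens budget.
    batch_size = 8
    curr_max_num_tokens = 8192

    warmup_requests = {
        (1, 1),                     # one token, one generation token
        (batch_size, batch_size),   # max batch size, pure generation
        (2, 0),                     # more than one token, pure context
        (curr_max_num_tokens, 0),   # max_num_tokens, pure context
    }

    # Plain iteration order over a set is arbitrary; reverse=True switches to a
    # deterministic, largest-first order, which the post-capture warmup uses so
    # the biggest buffers are allocated before the smaller ones.
    print(sorted(warmup_requests, reverse=True))
    # [(8192, 0), (8, 8), (2, 0), (1, 1)]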
@@ -867,6 +891,27 @@ def release_batch(result: ScheduledRequests | None):
         gc.collect()
         torch.cuda.empty_cache()

+        # When using piecewise cuda graph, the logits buffer may suffer from severe memory fragmentation.
+        # As the number of requests grows, the blocks already allocated by torch cannot be reused.
+        # So after piecewise cuda graph capture, a warmup batch with as many requests as possible is run
+        # to make sure that a large enough block is allocated and can be correctly reused.
+        for num_tokens in piecewise_cuda_graph_num_tokens:
+            batch = get_warmup_request(num_tokens, 0, False)
+            if batch is None:
+                continue
+            with release_batch(batch) as batch:
+                logger.info(
+                    f"Run piecewise CUDA graph warmup for num_tokens={num_tokens} with the most requests"
+                )
+                self.forward(batch,
+                             new_tensors_device=None,
+                             resource_manager=resource_manager)
+
+                torch.cuda.synchronize()
+
+        # Also run a general warmup from large to small to make sure that blocks are allocated well.
+        general_warmup(reverse=True)
+
         # Set the value back to the original value
         self.enable_spec_decode = self.is_spec_decode

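The comment in the last hunk rests on how PyTorch's caching allocator reuses freed blocks: a block freed by a smaller allocation cannot satisfy a larger request, so growing the warmup sizes tends to keep reserving new segments, whereas allocating the largest size first lets the later, smaller warmups reuse that block. The sketch below is a rough, self-contained way to observe this; the sizes and helper names are made up, and the exact numbers depend on the allocator configuration.

    import torch

    def reserved_mib() -> float:
        # Total memory currently reserved by the CUDA caching allocator.
        return torch.cuda.memory_reserved() / (1024**2)

    def run(sizes):
        torch.cuda.empty_cache()
        for n in sizes:
            x = torch.empty(n, 4096, device="cuda")  # stand-in for a logits tensor
            del x
        torch.cuda.synchronize()
        return reserved_mib()

    if torch.cuda.is_available():
        growing = run([1024, 2048, 4096, 8192])        # small -> large
        largest_first = run([8192, 4096, 2048, 1024])  # large first, as the warmup now does
        print(f"reserved after growing sizes: {growing:.1f} MiB")
        print(f"reserved after largest-first: {largest_first:.1f} MiB")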