@@ -531,14 +531,17 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
             return result
 
         def get_warmup_request(num_tokens: int,
-                               num_gen_tokens: int,
+                               num_gen_requests: int,
                                least_requests: bool = True):
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.runtime_draft_len)
             available_blocks = kv_cache_manager.get_num_free_blocks()
             if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
                 return None
-            if num_gen_tokens > self.batch_size:
+            if num_gen_requests > self.batch_size:
+                return None
+            num_gen_tokens = num_gen_requests * (1 + self.runtime_draft_len)
+            if num_gen_tokens > self.max_num_tokens:
                 return None
 
             num_extra_decoding_steps = get_num_extra_decoding_steps()
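
A note on the rename above: with speculative decoding, each warmup generation request expands to `1 + runtime_draft_len` tokens (one target token plus its draft tokens), so the feasibility check converts requests to tokens before comparing against the token budget. Below is a minimal standalone sketch of that check; `batch_size`, `max_num_tokens`, and `runtime_draft_len` are plain arguments standing in for the engine attributes referenced in the diff.

```python
def fits_warmup_budget(num_tokens: int, num_gen_requests: int,
                       batch_size: int, max_num_tokens: int,
                       runtime_draft_len: int) -> bool:
    """Illustrative only: mirrors the warmup feasibility check above."""
    if num_tokens > max_num_tokens:
        return False
    if num_gen_requests > batch_size:
        return False
    # Each generation request carries one target token plus its draft tokens.
    num_gen_tokens = num_gen_requests * (1 + runtime_draft_len)
    return num_gen_tokens <= max_num_tokens


# e.g. 8 generation requests with 3 draft tokens each account for 32 tokens.
assert fits_warmup_budget(num_tokens=64, num_gen_requests=8, batch_size=16,
                          max_num_tokens=256, runtime_draft_len=3)
```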
@@ -548,7 +551,8 @@ def get_warmup_request(num_tokens: int,
                 # during warmup.
                 return None
 
-            num_ctx_tokens = num_tokens - num_gen_tokens
+            num_ctx_tokens = num_tokens - num_gen_requests * (
+                1 + self.runtime_draft_len)
             num_ctx_requests = 0
             ctx_requests = []
             gen_requests = []
@@ -557,7 +561,7 @@ def get_warmup_request(num_tokens: int,
             num_full_seqs = 0
             num_left_over_tokens = 0
 
-            max_context_requests = self.batch_size - num_gen_tokens
+            max_context_requests = self.batch_size - num_gen_requests
             if max_context_requests * max_seq_len < num_ctx_tokens:
                 return None
 
@@ -572,7 +576,7 @@ def get_warmup_request(num_tokens: int,
 
             else:
                 max_bs = min(num_ctx_tokens,
-                             self.batch_size - num_gen_tokens)
+                             self.batch_size - num_gen_requests)
                 if num_ctx_tokens % max_bs == 0:
                     num_full_seqs = max_bs
                 else:
@@ -583,13 +587,13 @@ def get_warmup_request(num_tokens: int,
                                                 > 0 else 0)
 
             # We do not have enough batch to fill the request
-            if num_ctx_requests + num_gen_tokens > self.batch_size:
+            if num_ctx_requests + num_gen_requests > self.batch_size:
                 return None
 
             blocks_to_use = num_full_seqs * math.ceil(
                 max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
                     num_left_over_tokens /
-                    kv_cache_manager.tokens_per_block) + num_gen_tokens
+                    kv_cache_manager.tokens_per_block) + num_gen_requests
 
             if blocks_to_use > available_blocks:
                 return None
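
The block estimate above likewise switches from per-token to per-request accounting on the generation side: each full-length context sequence needs ceil(max_seq_len / tokens_per_block) blocks, the leftover context tokens need one partial allocation, and each dummy generation request claims a single block. An illustrative standalone version of that arithmetic, with the same quantities as in the diff passed as plain arguments:

```python
import math


def warmup_blocks_needed(num_full_seqs: int, max_seq_len: int,
                         num_left_over_tokens: int, num_gen_requests: int,
                         tokens_per_block: int) -> int:
    """Illustrative only: mirrors the blocks_to_use computation above."""
    full_seq_blocks = num_full_seqs * math.ceil(max_seq_len / tokens_per_block)
    leftover_blocks = math.ceil(num_left_over_tokens / tokens_per_block)
    # One KV-cache block per dummy generation request (single new token each).
    return full_seq_blocks + leftover_blocks + num_gen_requests


# e.g. 2 full 1024-token sequences, 100 leftover tokens, 4 generation requests,
# 64 tokens per block: 2 * 16 + 2 + 4 = 38 blocks.
assert warmup_blocks_needed(2, 1024, 100, 4, 64) == 38
```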
@@ -604,25 +608,29 @@ def get_warmup_request(num_tokens: int,
                 token_nums=ctx_token_nums,
                 is_gen=False,
                 max_num_draft_tokens=self.runtime_draft_len,
-                use_mrope=self.use_mrope)
+                use_mrope=self.use_mrope,
+                max_beam_width=self.max_beam_width,
+                num_extra_decoding_steps=num_extra_decoding_steps)
 
             if spec_resource_manager is not None:
                 spec_resource_manager.add_dummy_requests(
                     request_ids=list(range(num_ctx_requests)))
 
-            if num_gen_tokens > 0:
+            if num_gen_requests > 0:
                 gen_requests = kv_cache_manager.add_dummy_requests(
                     list(
                         range(num_ctx_requests,
-                              num_ctx_requests + num_gen_tokens)),
-                    token_nums=[1] * num_gen_tokens,
+                              num_ctx_requests + num_gen_requests)),
+                    token_nums=[1] * num_gen_requests,
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
-                    use_mrope=self.use_mrope)
+                    use_mrope=self.use_mrope,
+                    max_beam_width=self.max_beam_width,
+                    num_extra_decoding_steps=num_extra_decoding_steps)
                 if spec_resource_manager is not None:
                     spec_resource_manager.add_dummy_requests(request_ids=list(
                         range(num_ctx_requests, num_ctx_requests +
-                              num_gen_tokens)))
+                              num_gen_requests)))
 
             result = ScheduledRequests()
             result.context_requests = ctx_requests
@@ -655,15 +663,18 @@ def release_batch(result: ScheduledRequests | None):
             return
 
         def general_warmup(reverse: bool = False):
+            max_batch_size = min(
+                self.batch_size,
+                curr_max_num_tokens // (1 + self.runtime_draft_len))
             warmup_requests = set([
                 (1, 1),  # Specialize for 1 token.
-                (self.batch_size,
-                 self.batch_size),  # max_batch_size, pure generation
+                (max_batch_size,
+                 max_batch_size),  # max_batch_size, pure generation
                 (2, 0),  # Non-one, pure context
                 (curr_max_num_tokens, 0),  # max_num_tokens, pure context
             ])
-            if reverse:
-                warmup_requests = sorted(list(warmup_requests), reverse=reverse)
+
+            warmup_requests = sorted(list(warmup_requests), reverse=reverse)
 
             for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
                 with release_batch(
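
The new clamp keeps the pure-generation warmup shape inside the token budget: each generation request expands to 1 + runtime_draft_len tokens, so the largest batch that still fits is curr_max_num_tokens // (1 + runtime_draft_len). Sorting is also applied unconditionally, which makes the warmup order deterministic for both directions of `reverse`. A small sketch of how the warmup shapes could be built under those assumptions (`build_warmup_requests` is a hypothetical helper, not part of the diff):

```python
def build_warmup_requests(batch_size: int, curr_max_num_tokens: int,
                          runtime_draft_len: int, reverse: bool = False):
    """Hypothetical helper: builds (num_tokens, num_gen_requests) warmup shapes."""
    # Clamp the pure-generation batch so its expanded token count fits the budget.
    max_batch_size = min(batch_size,
                         curr_max_num_tokens // (1 + runtime_draft_len))
    warmup_requests = {
        (1, 1),  # Specialize for 1 token.
        (max_batch_size, max_batch_size),  # max_batch_size, pure generation
        (2, 0),  # Non-one, pure context
        (curr_max_num_tokens, 0),  # max_num_tokens, pure context
    }
    # Sort unconditionally so the order is deterministic either way.
    return sorted(warmup_requests, reverse=reverse)


# Large-to-small order, as used by general_warmup(reverse=True):
print(build_warmup_requests(batch_size=16, curr_max_num_tokens=8192,
                            runtime_draft_len=3, reverse=True))
# [(8192, 0), (16, 16), (2, 0), (1, 1)]
```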
@@ -817,6 +828,7 @@ def _update_draft_inference_state(is_first_draft: bool,
         # Also, we run a general warmup from large to small to make sure that blocks are allocated well.
         # The cudagraph and piecewise cuda graph capture calls torch.cuda.empty_cache(), and blocks may already
         # be freed even when we call general_warmup for torch compile.
+        # The additional warmup also helps trigger the runtime JIT, avoiding runtime JIT overhead later.
         general_warmup(reverse=True)
 
         # Set the value back to the original value