@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
 def create_sampling_metadata(
     all_greedy: bool,
     temperature: Optional[torch.Tensor] = None,
+    top_k: Optional[torch.Tensor] = None,
+    top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
     """Create a v1 sampling metadata object with all_greedy set
@@ -52,8 +54,8 @@ def create_sampling_metadata(
         temperature=temperature,
         all_greedy=all_greedy,
         all_random=not all_greedy,
-        top_p=None,
-        top_k=None,
+        top_p=top_p,
+        top_k=top_k,
         min_p=torch.empty(1, ),
         generators=generators,
         max_num_logprobs=0,
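
With the two new parameters threaded through, the helper can now express per-request top-k/top-p limits. A minimal usage sketch (assuming the test module's `DEVICE` constant; the values are illustrative, not taken from the tests):

    import torch

    batch_size = 4
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    # Per-request limits: one top-k / top-p entry per request in the batch.
    metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_k=torch.tensor([5] * batch_size, dtype=torch.int64, device=DEVICE),
        top_p=torch.tensor([0.9] * batch_size, dtype=torch.float32, device=DEVICE),
    )
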
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
                            density=True)
 
     return hist.hist
+
+
+def _test_masked_logits(
+    rejection_sampler,
+    batch_size: int,
+    num_draft_tokens: int,
+    vocab_size: int,
+    target_logits: torch.Tensor,
+    unmasked_indices: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+):
+    # Set up test parameters.
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random draft probabilities.
+    draft_probs = torch.rand((num_tokens, vocab_size),
+                             dtype=torch.float32,
+                             device=DEVICE)
+    draft_probs = F.softmax(draft_probs, dim=-1)
+
+    # Randomly sample draft token ids from the draft probs.
+    draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
+    draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
+    draft_token_ids = draft_token_ids.tolist()
+
+    # Bonus token ids are not used by this test, but the sampler requires them.
+    bonus_token_ids = torch.zeros((batch_size, 1),
+                                  dtype=torch.int64,
+                                  device=DEVICE)
+
+    # Create spec decode metadata.
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        draft_token_ids,
+        device=DEVICE,
+    )
+
+    # Run rejection sampling.
+    output_token_ids = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=draft_probs,
+        target_logits=target_logits,
+        bonus_token_ids=bonus_token_ids,
+        sampling_metadata=sampling_metadata,
+    )
+
+    # Remove the bonus tokens and flatten.
+    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
+
+    # Check that all sampled tokens are within the unmasked indices.
+    for i in range(num_tokens):
+        token_id = output_token_ids[i]
+        if token_id == PLACEHOLDER_TOKEN_ID:
+            continue
+        assert token_id in unmasked_indices[i]
+
+
+@pytest.mark.parametrize("top_k", [1, 5, 99])
+def test_top_k(rejection_sampler, top_k):
+    """Test rejection sampling with top-k sampling."""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Randomly create top-k indices.
+    top_k_indices = [
+        torch.randperm(vocab_size, device=DEVICE)[:top_k]
+        for _ in range(num_tokens)
+    ]
+    top_k_indices = torch.stack(top_k_indices)
+
+    # Create uniform logits (all zeros).
+    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
+
+    # Give the top-k indices a slightly higher logit than the rest. If the
+    # masking is effective, the non-top-k indices will never be sampled
+    # despite the small difference in logits.
+    for i in range(num_tokens):
+        target_logits[i, top_k_indices[i]] += 0.1
+
+    # Create sampling metadata.
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_k=torch.tensor([top_k] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.int64),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_k_indices,
+        sampling_metadata=sampling_metadata,
+    )
+
+
+@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
+def test_top_p(rejection_sampler, top_p):
+    """Test rejection sampling with top-p sampling."""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create logits from a standard normal distribution.
+    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    rescaled_logits = target_logits / temperature
+
+    # Sort ascending so the cumulative sum identifies the low-probability
+    # tail that top-p discards.
+    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - top_p
+    # Always keep at least one token unmasked.
+    top_p_mask[:, -1] = False
+
+    # Get the top-p indices.
+    top_p_indices = []
+    for i in range(num_tokens):
+        top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist())
+
+    # Create sampling metadata.
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_p=torch.tensor([top_p] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.float32),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_p_indices,
+        sampling_metadata=sampling_metadata,
+    )
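
The ascending-sort construction in `test_top_p` mirrors the usual top-p masking trick: after sorting ascending, the cumulative probability at a position is the total mass of that token plus every less-likely token, so entries whose cumulative sum stays at or below `1 - top_p` lie entirely in the tail that top-p discards. A self-contained sketch on a toy tensor (values are illustrative):

    import torch

    logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
    top_p = 0.9

    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    mask = probs_sum <= 1 - top_p  # True marks the discarded tail.
    mask[:, -1] = False            # Always keep the most likely token.
    print(logits_idx[~mask])       # tensor([1, 2, 0]): token 3 is filtered out.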
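Similarly, the logit construction in `test_top_k` can be checked in isolation: starting from all-zero (uniform) logits and nudging a random set of `top_k` indices up by 0.1 guarantees that any correct top-k mask keeps exactly that set. A quick CPU-only sketch (illustrative, independent of the test fixtures):

    import torch

    vocab_size, top_k = 100, 5
    target_logits = torch.zeros(vocab_size)
    top_k_indices = torch.randperm(vocab_size)[:top_k]
    target_logits[top_k_indices] += 0.1

    kept = target_logits.topk(top_k).indices
    assert set(kept.tolist()) == set(top_k_indices.tolist())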