@@ -55,9 +55,7 @@ def get_spec_metadata(spec_config,
     return None
 
 
-def get_spec_resource_manager(model_engine,
-                              draft_model_engine=None,
-                              drafter=None):
+def get_spec_resource_manager(model_engine, draft_model_engine=None):
     spec_config = model_engine.spec_config
     if spec_config is None:
         return None
@@ -93,9 +91,10 @@ def get_spec_resource_manager(model_engine,
             max_seq_len,
             max_num_tokens,
         )
-    if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided():
-        assert drafter is not None, "Drafter is required for ngram or user provided speculative decoding."
-        return drafter.spec_resource_manager
+    if spec_dec_mode.is_ngram():
+        return NGramPoolManager(spec_config, max_num_requests)
+    if spec_dec_mode.is_user_provided():
+        return spec_config.resource_manager
     return None
 
 
@@ -113,14 +112,12 @@ def get_spec_decoder(sampler_args: TorchSampler.Args,
         f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
 
 
-def get_spec_drafter(model_engine):
+def get_spec_drafter(model_engine, spec_resource_manager):
     spec_config = model_engine.spec_config
-    max_num_requests = model_engine.batch_size
     if spec_config is None:
         return None
     if spec_config.spec_dec_mode.is_ngram():
-        return NGramDrafter(spec_config,
-                            NGramPoolManager(spec_config, max_num_requests))
+        return NGramDrafter(spec_config, spec_resource_manager)
     if spec_config.spec_dec_mode.is_user_provided():
         return spec_config.drafter
     return None
0 commit comments