@@ -60,9 +60,7 @@ def get_spec_metadata(spec_config,
6060 return None
6161
6262
63- def get_spec_resource_manager (model_engine ,
64- draft_model_engine = None ,
65- drafter = None ):
63+ def get_spec_resource_manager (model_engine , draft_model_engine = None ):
6664 spec_config = model_engine .spec_config
6765 if spec_config is None :
6866 return None
@@ -98,9 +96,10 @@ def get_spec_resource_manager(model_engine,
9896 max_seq_len ,
9997 max_num_tokens ,
10098 )
101- if spec_dec_mode .is_ngram () or spec_dec_mode .is_user_provided ():
102- assert drafter is not None , "Drafter is required for ngram or user provided speculative decoding."
103- return drafter .spec_resource_manager
99+ if spec_dec_mode .is_ngram ():
100+ return NGramPoolManager (spec_config , max_num_requests )
101+ if spec_dec_mode .is_user_provided ():
102+ return spec_config .resource_manager
104103 return None
105104
106105
@@ -117,16 +116,13 @@ def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig):
117116 f"Unsupported speculative decoding mode: { spec_config .spec_dec_mode } " )
118117
119118
120- def get_spec_drafter (model_engine ):
119+ def get_spec_drafter (model_engine , spec_resource_manager ):
121120 spec_config = model_engine .spec_config
122- max_num_requests = model_engine .batch_size
121+ model_engine .batch_size
123122 if spec_config is None :
124123 return None
125124 if spec_config .spec_dec_mode .is_ngram ():
126- return NGramDrafter (spec_config ,
127- NGramPoolManager (spec_config , max_num_requests ))
128- if spec_config .spec_dec_mode .is_user_provided ():
129- return spec_config .drafter
125+ return NGramDrafter (spec_config , spec_resource_manager )
130126 return None
131127
132128
0 commit comments