@@ -437,21 +437,10 @@ def cleanup(self):
 
     def _create_auto_model(self, config: VLLMModelConfig) -> Optional[AsyncLLM]:
         """
-        Creates an instance of the async vllm model loaded from HF.
-
-        Args:
-            pretrained (str): The name or path of the pretrained model.
-            revision (str): The revision of the model.
-            subfolder (Optional[str], optional): The subfolder within the model. Defaults to None.
-            max_memory (Optional[dict], optional): The maximum memory to allocate for the model per GPU. Defaults to None.
-            device_map (Optional[dict], optional): The device mapping for the model. Defaults to None.
-            torch_dtype (Optional[Union[str, torch.dtype]], optional): The torch data type for the model. Defaults to None.
-            quantization_config (Optional[Union[BitsAndBytesConfig, GPTQConfig]], optional): The quantization configuration for the model. Defaults to None.
-            trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
-            cache_dir (str, optional): The cache directory for the model. Defaults to "/scratch".
+        Creates an instance of the async vLLM model loaded from HF. Requires vLLM v1.
 
         Returns:
-            transformers.PreTrainedModel: The created auto model instance.
+            AsyncLLM: The created async vLLM instance.
         """
         self.model_args = {
             "model": config.model_name,
@@ -519,7 +508,7 @@ async def _async_batch(self, requests: list[GreedyUntilRequest | LoglikelihoodRe
     async def greedy_until(
         self,
         requests: list[GreedyUntilRequest],
-        override_bs: Optional[int] = None,
+        **kwargs,
     ) -> list[GenerativeResponse]:
         """
         Generates responses using a greedy decoding strategy until certain ending conditions are met.
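Swapping `override_bs` for `**kwargs` keeps legacy call sites that still pass the argument from raising a `TypeError`, while making it explicit that the async path does not act on it. A small, self-contained illustration of the pattern (the function names here are hypothetical, not the wrapper's own):

```python
import asyncio
from typing import Optional


# Old-style signature: the batch-size override is an explicit parameter.
async def greedy_until_old(requests: list[str], override_bs: Optional[int] = None) -> list[str]:
    return [f"generation for {r!r}" for r in requests]


# New-style signature: unexpected keyword arguments are absorbed by **kwargs,
# so legacy callers that still pass override_bs do not break.
async def greedy_until_new(requests: list[str], **kwargs) -> list[str]:
    return [f"generation for {r!r}" for r in requests]


async def main() -> None:
    # A legacy call site keeps working against the new signature.
    print(await greedy_until_new(["2 + 2 ="], override_bs=8))


asyncio.run(main())
```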
@@ -564,9 +553,8 @@ async def greedy_until(
     async def loglikelihood(
         self,
         requests: list[LoglikelihoodRequest],
-        override_bs: Optional[int] = None,
         return_bool_score: bool = True,
-        rolling: bool = False,
+        **kwargs,
     ) -> list[LoglikelihoodResponse]:
         for request in requests:
             if request.context == "":
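`loglikelihood` gets the same treatment: `override_bs` and `rolling` are dropped from the signature and any such keywords are absorbed by `**kwargs`. If silently ignoring arguments is a concern, one option is to log them; a purely illustrative sketch (not something this diff does), using the standard `logging` module:

```python
import logging

logger = logging.getLogger(__name__)


def _warn_ignored_kwargs(method_name: str, kwargs: dict) -> None:
    """Log keyword arguments that the async backend accepts but does not use."""
    if kwargs:
        logger.warning("%s ignores unsupported arguments: %s", method_name, sorted(kwargs))


# Example: a legacy caller still passing the removed parameters.
_warn_ignored_kwargs("loglikelihood", {"override_bs": 8, "rolling": False})
```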