@@ -437,21 +437,10 @@ def cleanup(self):
    def _create_auto_model(self, config: VLLMModelConfig) -> Optional[AsyncLLM]:
        """
-       Creates an instance of the async vllm model loaded from HF.
-
-       Args:
-           pretrained (str): The name or path of the pretrained model.
-           revision (str): The revision of the model.
-           subfolder (Optional[str], optional): The subfolder within the model. Defaults to None.
-           max_memory (Optional[dict], optional): The maximum memory to allocate for the model per GPU. Defaults to None.
-           device_map (Optional[dict], optional): The device mapping for the model. Defaults to None.
-           torch_dtype (Optional[Union[str, torch.dtype]], optional): The torch data type for the model. Defaults to None.
-           quantization_config (Optional[Union[BitsAndBytesConfig, GPTQConfig]], optional): The quantization configuration for the model. Defaults to None.
-           trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
-           cache_dir (str, optional): The cache directory for the model. Defaults to "/scratch".
+       Creates an instance of the async vLLM model loaded from HF. Requires vLLM v1.

        Returns:
-           transformers.PreTrainedModel: The created auto model instance.
+           AsyncLLM: The created async vLLM instance.
        """
        self.model_args = {
            "model": config.model_name,
@@ -519,14 +508,13 @@ async def _async_batch(self, requests: list[GreedyUntilRequest | LoglikelihoodRe
    async def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
-       override_bs: Optional[int] = None,
+       **kwargs,
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[Request]): list of requests containing the context and ending conditions.
-           override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerateReturn]: list of generated responses.
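The batch-size override disappears from the public signature; presumably scheduling is now left to the async engine. A hedged sketch of the call-site impact (`model` and `requests` are placeholders from elsewhere in the harness):

```python
import asyncio

# Placeholders: `model` is the async vLLM wrapper from this diff,
# `requests` a list[GreedyUntilRequest].
async def run(model, requests):
    # Preferred call after this change: no batch-size override.
    responses = await model.greedy_until(requests)
    # A stale caller passing override_bs still runs, since **kwargs
    # absorbs the argument, but the value is silently ignored.
    legacy = await model.greedy_until(requests, override_bs=8)
    return responses, legacy

# asyncio.run(run(model, requests))
```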
@@ -564,10 +552,20 @@ async def greedy_until(
    async def loglikelihood(
        self,
        requests: list[LoglikelihoodRequest],
-       override_bs: Optional[int] = None,
        return_bool_score: bool = True,
-       rolling: bool = False,
+       **kwargs,
    ) -> list[LoglikelihoodResponse]:
+       """
+       Greedily decodes each request until its ending conditions are met and
+       stores the logprobs of the generated tokens.
+
+       Args:
+           requests (list[Request]): list of requests containing the context and ending conditions.
+
+       Returns:
+           list[LoglikelihoodResponse]: list of generated responses.
+       """
+
        for request in requests:
            if request.context == "":
                request.tokenized_context = [self.tokenizer.eos_token_id]
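The empty-context guard above exists because an empty prompt tokenizes to nothing the engine can score, so the EOS token stands in as a one-token prompt. A runnable illustration (the model name is an assumption, chosen only to show the effect):

```python
from transformers import AutoTokenizer

# Assumed model for illustration; any HF tokenizer shows the same pattern.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

print(tokenizer("")["input_ids"])    # [] -> nothing for the engine to score
print([tokenizer.eos_token_id])      # [50256] -> a usable one-token prompt
```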