diff --git a/ragfoundry/models/vllm.py b/ragfoundry/models/vllm.py
index 269cdb9..19f2142 100644
--- a/ragfoundry/models/vllm.py
+++ b/ragfoundry/models/vllm.py
@@ -27,7 +27,7 @@ def __init__(
         self,
         model_name_or_path: str,
         instruction: Path,
-        instruct_in_prompt: False,
+        instruct_in_prompt: bool = False,
         template: Path = None,
         num_gpus: int = 1,
         llm_params: Dict = {},
@@ -39,19 +39,17 @@ def __init__(
         )
         from vllm import LLM, SamplingParams
 
-        logger.info(f"Using the following instruction: {self.instruction}")
-
+        self.model_name = model_name_or_path
         self.instruct_in_prompt = instruct_in_prompt
         self.template = open(template).read() if template else None
         self.instruction = open(instruction).read()
+        logger.info(f"Using the following instruction: {self.instruction}")
 
         self.sampling_params = SamplingParams(**generation)
-        self.llm = LLM(
-            model=model_name_or_path, tensor_parallel_size=num_gpus, **llm_params
-        )
-        if self.instruct_in_prompt:
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        self.config = AutoConfig.from_pretrained(self.model_name)
+        self.llm = LLM(model=self.model_name, tensor_parallel_size=num_gpus, **llm_params)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.config = AutoConfig.from_pretrained(self.model_name)
 
     def generate(self, prompt: str) -> str:
         """
@@ -59,8 +57,9 @@ def generate(self, prompt: str) -> str:
         """
         if self.template:
             prompt = self.template.format(instruction=self.instruction, query=prompt)
-        elif self.instruct_in_prompt:
-            prompt = self.instruction + "\n" + prompt
+        else:
+            if self.instruct_in_prompt:
+                prompt = self.instruction + "\n" + prompt
         messages = [
             {"role": "system", "content": self.instruction},
             {"role": "user", "content": prompt},
@@ -72,7 +71,7 @@ def generate(self, prompt: str) -> str:
             add_generation_prompt=True,
             truncation=True,
             max_length=(
-                self.config.max_position_embeddings - self.sampling_param.max_tokens
+                self.config.max_position_embeddings - self.sampling_params.max_tokens
            ),
        )
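
For context, a minimal sketch of the prompt-length budgeting that the last hunk fixes: the chat template is capped at the model's context window minus the generation budget (sampling_params.max_tokens). The model id "my-org/my-model" and the max_new_tokens value below are illustrative assumptions, not part of the patch.

# Sketch of the max_length budget mirrored from the patched apply_chat_template call.
from transformers import AutoConfig, AutoTokenizer

model_name = "my-org/my-model"   # hypothetical HF model id, stands in for model_name_or_path
instruction = "Answer the question using only the retrieved context."
query = "What does instruct_in_prompt control?"

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

max_new_tokens = 256  # stands in for sampling_params.max_tokens

messages = [
    {"role": "system", "content": instruction},
    {"role": "user", "content": query},
]

# max_length caps the prompt at context window minus generation budget,
# i.e. max_position_embeddings - max_tokens, as in the fixed expression.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    truncation=True,
    max_length=config.max_position_embeddings - max_new_tokens,
)
print(prompt)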