Fixes in vLLM
`instruct_in_prompt` is a boolean option; when set, the system instruction
is prepended to the user message, for models that lack a system role.
The default is `False`, meaning the system instruction is inserted as a
separate system message.
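
For illustration, a minimal usage sketch of the option. This is an assumption-laden example: the class name `VLLMInference`, the model id, and the file paths are guesses for illustration, and only parameters visible in the diff below are used.

```python
# Hypothetical usage sketch; the class name VLLMInference and all
# paths/values are assumptions, not taken from the commit.
from ragfoundry.models.vllm import VLLMInference

model = VLLMInference(
    model_name_or_path="mistralai/Mistral-7B-Instruct-v0.2",  # any HF model id
    instruction="configs/instruction.txt",  # file holding the system instruction
    instruct_in_prompt=True,  # fold the instruction into the user turn
    num_gpus=1,
)
print(model.generate("What is retrieval-augmented generation?"))
```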
danielfleischer committed Aug 25, 2024
1 parent d9b9c67 commit ac98c58
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions ragfoundry/models/vllm.py
@@ -27,7 +27,7 @@ def __init__(
         self,
         model_name_or_path: str,
         instruction: Path,
-        instruct_in_prompt: False,
+        instruct_in_prompt: bool = False,
         template: Path = None,
         num_gpus: int = 1,
         llm_params: Dict = {},
@@ -39,28 +39,27 @@ def __init__(
         )
         from vllm import LLM, SamplingParams

-        logger.info(f"Using the following instruction: {self.instruction}")
-
         self.model_name = model_name_or_path
         self.instruct_in_prompt = instruct_in_prompt
         self.template = open(template).read() if template else None
         self.instruction = open(instruction).read()
+        logger.info(f"Using the following instruction: {self.instruction}")

         self.sampling_params = SamplingParams(**generation)
-        self.llm = LLM(
-            model=model_name_or_path, tensor_parallel_size=num_gpus, **llm_params
-        )
-        if self.instruct_in_prompt:
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-            self.config = AutoConfig.from_pretrained(self.model_name)
+        self.llm = LLM(model=self.model_name, tensor_parallel_size=num_gpus, **llm_params)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.config = AutoConfig.from_pretrained(self.model_name)

     def generate(self, prompt: str) -> str:
         """
         Generates text based on the given prompt.
         """
         if self.template:
             prompt = self.template.format(instruction=self.instruction, query=prompt)
-        elif self.instruct_in_prompt:
-            prompt = self.instruction + "\n" + prompt
         else:
+            if self.instruct_in_prompt:
+                prompt = self.instruction + "\n" + prompt
             messages = [
                 {"role": "system", "content": self.instruction},
                 {"role": "user", "content": prompt},
@@ -72,7 +71,7 @@ def generate(self, prompt: str) -> str:
             add_generation_prompt=True,
             truncation=True,
             max_length=(
-                self.config.max_position_embeddings - self.sampling_param.max_tokens
+                self.config.max_position_embeddings - self.sampling_params.max_tokens
             ),
         )
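
Beyond the attribute-name fix (`sampling_param` → `sampling_params`), this expression budgets the context window: the prompt is truncated so that `max_tokens` of generation still fit. A quick worked example with assumed numbers:

```python
# Illustrative numbers only; neither value comes from the commit.
max_position_embeddings = 4096  # model context window
max_tokens = 512                # generation budget from SamplingParams

# The prompt may occupy at most the remainder of the window.
max_length = max_position_embeddings - max_tokens
assert max_length == 3584
```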

