diff --git a/README.md b/README.md index ad8087ee60..6dfe9c89b7 100644 --- a/README.md +++ b/README.md @@ -216,7 +216,7 @@ List of command-line flags | Flag | Description | |--------------------------------------------|-------------| -| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlama_HF, ExLlamav2_HF, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ExLlama, ExLlamav2, ctransformers, QuIP#. | +| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlama_HF, ExLlamav2_HF, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ExLlama, ExLlamav2, ctransformers, QuIP#, vllm. | #### Accelerate/transformers @@ -320,6 +320,16 @@ List of command-line flags | `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". | | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | +#### VLLM + +| Flag | Description | +|------------------|-------------| +| `--max-model-len MAX_MODEL_LEN` | Model context length. If unspecified, will be automatically derived from the model config.| +| `--dtype “float16”` | Data type for model weights and activations.“auto” will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.| +| `--gpu-memory-utilization ` | The percentage of GPU memory to be used for the model executor.| + +Refer https://docs.vllm.ai/en/latest/models/engine_args.html for more details. All arguments can simply be passed to textgen webui. + #### RoPE (for llama.cpp, ExLlama, ExLlamaV2, and transformers) | Flag | Description | diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 3a212dd9c2..941429609f 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -8,7 +8,7 @@ class GenerationOptions(BaseModel): preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.") min_p: float = 0 - top_k: int = 0 + top_k: int = 1 repetition_penalty: float = 1 repetition_penalty_range: int = 1024 typical_p: float = 1 diff --git a/modules/loaders.py b/modules/loaders.py index 9f1c70d121..967b7b2715 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -155,6 +155,9 @@ 'trust_remote_code', 'no_use_fast', 'no_flash_attn', + ], + 'Vllm': [ + # Use Vllm's default settings ] }) @@ -503,6 +506,9 @@ 'skip_special_tokens', 'auto_max_new_tokens', }, + 'Vllm': { + # Use Vllm's default settings + }, } loaders_model_types = { diff --git a/modules/models.py b/modules/models.py index 49e5f818fa..575cb9ead2 100644 --- a/modules/models.py +++ b/modules/models.py @@ -73,6 +73,7 @@ def load_model(model_name, loader=None): 'ctransformers': ctransformers_loader, 'AutoAWQ': AutoAWQ_loader, 'QuIP#': QuipSharp_loader, + 'Vllm': Vllm_loader, } metadata = get_model_metadata(model_name) @@ -427,6 +428,11 @@ def RWKV_loader(model_name): tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir)) return model, tokenizer +def Vllm_loader(model_name): + from modules.vllm import VllmModel + + return VllmModel.from_pretrained(model_name) + def get_max_memory_dict(): max_memory = {} diff --git a/modules/shared.py b/modules/shared.py index edd74af132..aa03dbbbf9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -191,7 +191,9 @@ parser.add_argument('--llama_cpp_seed', type=int, default=0, help='DEPRECATED') parser.add_argument('--use_fast', action='store_true', help='DEPRECATED') -args = parser.parse_args() +args, unknown = parser.parse_known_args() +if unknown: + logger.warning(f'Textgen-webui has been provided with unknown arguments: {unknown}') args_defaults = parser.parse_args([]) provided_arguments = [] for arg in sys.argv[1:]: @@ -246,6 +248,8 @@ def fix_loader_name(name): return 'AutoAWQ' elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']: return 'QuIP#' + elif name in ['vllm', 'Vllm', 'VLLM']: + return 'Vllm' def add_extension(name, last=False): diff --git a/modules/text_generation.py b/modules/text_generation.py index 72ccf99600..e306612e08 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -44,7 +44,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap yield '' return - if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']: + if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel', 'VllmModel']: generate_func = generate_reply_custom else: generate_func = generate_reply_HF @@ -115,7 +115,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt if shared.tokenizer is None: raise ValueError('No tokenizer is loaded') - if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model']: + if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model', 'VllmModel']: input_ids = shared.tokenizer.encode(str(prompt)) if shared.model.__class__.__name__ not in ['Exllamav2Model']: input_ids = np.array(input_ids).reshape(1, len(input_ids)) @@ -129,7 +129,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt if truncation_length is not None: input_ids = input_ids[:, -truncation_length:] - if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel'] or shared.args.cpu: + if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel', 'VllmModel'] or shared.args.cpu: return input_ids elif shared.args.deepspeed: return input_ids.to(device=local_rank) diff --git a/modules/vllm.py b/modules/vllm.py new file mode 100644 index 0000000000..c89108e47c --- /dev/null +++ b/modules/vllm.py @@ -0,0 +1,158 @@ + +from pathlib import Path +import argparse +import threading + +from typing import List +from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput +from vllm.utils import random_uuid + +from modules import shared +from modules.logging_colors import logger + +__VLLM_DEBUG__ = False + +# Lock vllm to prevent multiple threads from using it at the same time +class LockContextManager: + def __init__(self, lock): + self.lock = lock + + def __enter__(self): + self.lock.acquire() + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.lock.release() + +class VllmModel: + + def __init__(self): + self.inference_lock = threading.Lock() + pass + + @classmethod + def from_pretrained(self, path_to_model): + + # Parse the arguments, but ignore textgen arguments, only parse vllm arguments + vllm_parser = argparse.ArgumentParser( + description='VllmModel underlyingly uses the Vllm LLMEngine class directly, we will use Vllm argparser to parse the arguments instead') + vllm_parser = EngineArgs.add_cli_args(vllm_parser) + vllm_args, unknown = vllm_parser.parse_known_args() + + # Check if the model exists + path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model) + assert path_to_model.exists(), f'Model {path_to_model} does not exist' + + # Set the model path + vllm_args.model = str(path_to_model.absolute()) + + # Log the parsed arguments + logger.info(f'Parsed vllm_args : {vllm_args}') + # Create an engine from the parsed arguments + engine_args = EngineArgs.from_cli_args(vllm_args) + engine = LLMEngine.from_engine_args(engine_args) + + # Create a result object + result = self() + # Set the engine and tokenizer + result.engine = engine + result.tokenizer = engine.tokenizer + + # Log the loaded model + logger.info(f'Loaded model into \n{result.engine}, \n{result.tokenizer}') + + # Return the result object and the tokenizer + return result, result.tokenizer + + + def generate_with_streaming(self, prompt, state): + + # Generate vllm's own sampling settings from textgen state + settings = SamplingParams() + for key, value in state.items(): + if hasattr(settings, key) and value is not None: + setattr(settings, key, value) + if __VLLM_DEBUG__: + logger.debug(f'Setting {key} to {value}') + + # use Vllm's own verification method to verify the settings + try: + settings._verify_args() + except ValueError as e: + settings = SamplingParams() + logger.warning(f'Vllm Error verifying settings, useing default sampler settings instead {settings}: {e}') + + # Get prompt token prompt_token_ids + prompt_token_ids = self.tokenizer.encode(prompt) + # Get max new tokens + if state['auto_max_new_tokens']: + max_new_tokens = state['truncation_length'] - len(prompt_token_ids) + else: + max_new_tokens = state['max_new_tokens'] + if max_new_tokens < 0: + logger.warning(f'Max new tokens {max_new_tokens} < 0, setting to 0') + max_new_tokens = 0 + settings.max_tokens = max_new_tokens + + if __VLLM_DEBUG__: + logger.debug(f'Generating with streaming, max_tokens={settings.max_tokens}') + logger.debug(f'Prompt token ids {prompt_token_ids}') + logger.debug(f'Prompt token ids length {len(prompt_token_ids)}') + logger.debug(f'settings {settings}') + + # Can only handle 1 sample per generation + assert settings.n == 1, f'Only 1 sample per generation is supported, got {settings.n}' + + # request_id as random uuid + request_id = f"{random_uuid()}" + with LockContextManager(self.inference_lock): + # Add request to vllm engine + self.engine.add_request(request_id=request_id, + prompt=prompt, + sampling_params=settings, + prompt_token_ids=prompt_token_ids) + + while True: + # Abort generation if we are stopping everything + if shared.stop_everything: + with LockContextManager(self.inference_lock): + self.engine.abort(request_id) + if __VLLM_DEBUG__: + logger.debug(f'Aborted generation') + break + + # Get the next output + target_request_output = None + with LockContextManager(self.inference_lock): + # Call vllm engine step to generate next token + request_outputs: List[RequestOutput] = self.engine.step() + + # Find our request output (should be trivial because we only server 1 batch at a time) + for request_output in request_outputs: + if request_output.request_id != request_id: + if __VLLM_DEBUG__: + logger.warning(f'Request id mismatch, expected {request_id}, got {request_output.request_id}') + continue + # Can only handle 1 sample per generation + assert len(request_output.outputs) == 1, f'Only 1 sample per generation is supported, got {len(request_output.outputs)}' + target_request_output = request_output + + # Get the output + output = target_request_output.outputs[0] + decoded_text = output.text + if shared.args.verbose and __VLLM_DEBUG__: + logger.debug(f'{decoded_text}') + yield decoded_text + + # Quit if we are finished + if target_request_output.finished: + if __VLLM_DEBUG__: + logger.debug(f'Finished generation') + break + + + def generate(self, prompt, state): + output = '' + for output in self.generate_with_streaming(prompt, state): + pass + + return output