12 changes: 11 additions & 1 deletion README.md
@@ -216,7 +216,7 @@ List of command-line flags

| Flag | Description |
|--------------------------------------------|-------------|
| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlama_HF, ExLlamav2_HF, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ExLlama, ExLlamav2, ctransformers, QuIP#. |
| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlama_HF, ExLlamav2_HF, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ExLlama, ExLlamav2, ctransformers, QuIP#, vllm. |

#### Accelerate/transformers

@@ -320,6 +320,16 @@ List of command-line flags
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |

#### vLLM

| Flag | Description |
|------------------|-------------|
| `--max-model-len MAX_MODEL_LEN` | Model context length. If unspecified, it is derived automatically from the model config. |
| `--dtype DTYPE` | Data type for model weights and activations. "auto" uses FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. |
| `--gpu-memory-utilization GPU_MEMORY_UTILIZATION` | The fraction of GPU memory (between 0 and 1) to be used for the model executor. |

Refer to https://docs.vllm.ai/en/latest/models/engine_args.html for the full list of engine arguments; all of them can be passed directly on the text-generation-webui command line.
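
This works because the vLLM loader in this PR builds its own parser from `EngineArgs.add_cli_args()` and calls `parse_known_args()`, so any engine flag on the command line is picked up by vLLM while text-generation-webui's own flags are ignored by it. A minimal standalone sketch of that pattern (the flag values are illustrative, and exact `EngineArgs` behavior may vary between vLLM versions):

```python
import argparse

from vllm import EngineArgs

# A parser that knows every vLLM engine flag (--max-model-len, --dtype, --gpu-memory-utilization, ...)
parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)

# parse_known_args() keeps the vLLM flags and returns everything else
# (e.g. --loader, which belongs to the webui) as "unknown" instead of erroring out.
vllm_args, unknown = parser.parse_known_args(
    ['--loader', 'vllm', '--max-model-len', '4096', '--gpu-memory-utilization', '0.9']
)

engine_args = EngineArgs.from_cli_args(vllm_args)
print(engine_args.max_model_len, engine_args.gpu_memory_utilization)  # 4096 0.9
print(unknown)  # ['--loader', 'vllm']
```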

#### RoPE (for llama.cpp, ExLlama, ExLlamaV2, and transformers)

| Flag | Description |
2 changes: 1 addition & 1 deletion extensions/openai/typing.py
@@ -8,7 +8,7 @@
class GenerationOptions(BaseModel):
preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
min_p: float = 0
top_k: int = 0
top_k: int = 1
repetition_penalty: float = 1
repetition_penalty_range: int = 1024
typical_p: float = 1
6 changes: 6 additions & 0 deletions modules/loaders.py
@@ -155,6 +155,9 @@
'trust_remote_code',
'no_use_fast',
'no_flash_attn',
],
'Vllm': [
# Use Vllm's default settings
]
})

@@ -503,6 +506,9 @@
'skip_special_tokens',
'auto_max_new_tokens',
},
'Vllm': {
# Use Vllm's default settings
},
}

loaders_model_types = {
6 changes: 6 additions & 0 deletions modules/models.py
@@ -73,6 +73,7 @@ def load_model(model_name, loader=None):
'ctransformers': ctransformers_loader,
'AutoAWQ': AutoAWQ_loader,
'QuIP#': QuipSharp_loader,
'Vllm': Vllm_loader,
}

metadata = get_model_metadata(model_name)
@@ -427,6 +428,11 @@ def RWKV_loader(model_name):
tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
return model, tokenizer

def Vllm_loader(model_name):
from modules.vllm import VllmModel

return VllmModel.from_pretrained(model_name)


def get_max_memory_dict():
max_memory = {}
6 changes: 5 additions & 1 deletion modules/shared.py
@@ -191,7 +191,9 @@
parser.add_argument('--llama_cpp_seed', type=int, default=0, help='DEPRECATED')
parser.add_argument('--use_fast', action='store_true', help='DEPRECATED')

args = parser.parse_args()
args, unknown = parser.parse_known_args()
if unknown:
logger.warning(f'Textgen-webui has been provided with unknown arguments: {unknown}')
args_defaults = parser.parse_args([])
provided_arguments = []
for arg in sys.argv[1:]:
@@ -246,6 +248,8 @@ def fix_loader_name(name):
return 'AutoAWQ'
elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
return 'QuIP#'
elif name in ['vllm', 'Vllm', 'VLLM']:
return 'Vllm'


def add_extension(name, last=False):
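
The change from `parse_args()` to `parse_known_args()` above is what lets flags that text-generation-webui does not define (such as the vLLM engine arguments) pass through with only a warning instead of aborting startup. A small illustrative sketch of that argparse behavior, independent of the webui code:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--loader')

# parse_args() would exit with "unrecognized arguments: --dtype float16".
# parse_known_args() returns the leftovers so another parser can consume them later.
args, unknown = parser.parse_known_args(['--loader', 'vllm', '--dtype', 'float16'])
print(args.loader)  # vllm
print(unknown)      # ['--dtype', 'float16']
```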
6 changes: 3 additions & 3 deletions modules/text_generation.py
@@ -44,7 +44,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
yield ''
return

if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']:
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel', 'VllmModel']:
generate_func = generate_reply_custom
else:
generate_func = generate_reply_HF
@@ -115,7 +115,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if shared.tokenizer is None:
raise ValueError('No tokenizer is loaded')

if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model']:
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model', 'VllmModel']:
input_ids = shared.tokenizer.encode(str(prompt))
if shared.model.__class__.__name__ not in ['Exllamav2Model']:
input_ids = np.array(input_ids).reshape(1, len(input_ids))
@@ -129,7 +129,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:]

if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel'] or shared.args.cpu:
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel', 'VllmModel'] or shared.args.cpu:
return input_ids
elif shared.args.deepspeed:
return input_ids.to(device=local_rank)
158 changes: 158 additions & 0 deletions modules/vllm.py
@@ -0,0 +1,158 @@

import argparse
import threading
from pathlib import Path
from typing import List

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils import random_uuid

from modules import shared
from modules.logging_colors import logger

__VLLM_DEBUG__ = False


# Lock vLLM to prevent multiple threads from using it at the same time
class LockContextManager:
    def __init__(self, lock):
        self.lock = lock

    def __enter__(self):
        self.lock.acquire()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.lock.release()


class VllmModel:

    def __init__(self):
        self.inference_lock = threading.Lock()

    @classmethod
    def from_pretrained(cls, path_to_model):

        # Parse only the vLLM engine arguments and ignore the textgen arguments
        vllm_parser = argparse.ArgumentParser(
            description='VllmModel uses the vLLM LLMEngine class directly, so the vLLM argument parser is used for the engine arguments')
        vllm_parser = EngineArgs.add_cli_args(vllm_parser)
        vllm_args, unknown = vllm_parser.parse_known_args()

        # Check that the model exists
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
        assert path_to_model.exists(), f'Model {path_to_model} does not exist'

        # Set the model path
        vllm_args.model = str(path_to_model.absolute())

        # Log the parsed arguments
        logger.info(f'Parsed vllm_args: {vllm_args}')

        # Create an engine from the parsed arguments
        engine_args = EngineArgs.from_cli_args(vllm_args)
        engine = LLMEngine.from_engine_args(engine_args)

        # Create the model object and attach the engine and tokenizer
        result = cls()
        result.engine = engine
        result.tokenizer = engine.tokenizer

        # Log the loaded model
        logger.info(f'Loaded model into \n{result.engine}, \n{result.tokenizer}')

        # Return the model object and the tokenizer
        return result, result.tokenizer


    def generate_with_streaming(self, prompt, state):

        # Build vLLM's own sampling settings from the textgen state
        settings = SamplingParams()
        for key, value in state.items():
            if hasattr(settings, key) and value is not None:
                setattr(settings, key, value)
                if __VLLM_DEBUG__:
                    logger.debug(f'Setting {key} to {value}')

        # Use vLLM's own verification method to validate the settings
        try:
            settings._verify_args()
        except ValueError as e:
            settings = SamplingParams()
            logger.warning(f'vLLM could not verify the sampling settings, using the default settings {settings} instead: {e}')

        # Get the prompt token ids
        prompt_token_ids = self.tokenizer.encode(prompt)

        # Get the maximum number of new tokens
        if state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - len(prompt_token_ids)
        else:
            max_new_tokens = state['max_new_tokens']

        if max_new_tokens < 0:
            logger.warning(f'Max new tokens {max_new_tokens} < 0, setting to 0')
            max_new_tokens = 0

        settings.max_tokens = max_new_tokens

        if __VLLM_DEBUG__:
            logger.debug(f'Generating with streaming, max_tokens={settings.max_tokens}')
            logger.debug(f'Prompt token ids {prompt_token_ids}')
            logger.debug(f'Prompt token ids length {len(prompt_token_ids)}')
            logger.debug(f'settings {settings}')

        # Can only handle 1 sample per generation
        assert settings.n == 1, f'Only 1 sample per generation is supported, got {settings.n}'

        # Use a random uuid as the request id
        request_id = f"{random_uuid()}"
        with LockContextManager(self.inference_lock):
            # Add the request to the vLLM engine
            self.engine.add_request(request_id=request_id,
                                    prompt=prompt,
                                    sampling_params=settings,
                                    prompt_token_ids=prompt_token_ids)

        while True:
            # Abort the generation if everything is being stopped
            if shared.stop_everything:
                with LockContextManager(self.inference_lock):
                    self.engine.abort(request_id)
                if __VLLM_DEBUG__:
                    logger.debug('Aborted generation')
                break

            # Get the next output
            target_request_output = None
            with LockContextManager(self.inference_lock):
                # Call the vLLM engine step to generate the next token
                request_outputs: List[RequestOutput] = self.engine.step()

            # Find our request output (trivial because we only serve 1 request at a time)
            for request_output in request_outputs:
                if request_output.request_id != request_id:
                    if __VLLM_DEBUG__:
                        logger.warning(f'Request id mismatch, expected {request_id}, got {request_output.request_id}')
                    continue

                # Can only handle 1 sample per generation
                assert len(request_output.outputs) == 1, f'Only 1 sample per generation is supported, got {len(request_output.outputs)}'
                target_request_output = request_output

            # Get the output and yield the text decoded so far
            output = target_request_output.outputs[0]
            decoded_text = output.text
            if shared.args.verbose and __VLLM_DEBUG__:
                logger.debug(f'{decoded_text}')
            yield decoded_text

            # Quit if we are finished
            if target_request_output.finished:
                if __VLLM_DEBUG__:
                    logger.debug('Finished generation')
                break


    def generate(self, prompt, state):
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output
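
For context, here is a rough sketch of how this loader is exercised end to end. The model folder name and the `state` keys below are illustrative assumptions (only the fields actually read by `generate_with_streaming` are shown), not the full state dictionary that text-generation-webui passes:

```python
from modules.vllm import VllmModel

# Mirrors Vllm_loader in modules/models.py: looks for the folder under shared.args.model_dir
model, tokenizer = VllmModel.from_pretrained('my-llama-model')  # hypothetical folder name

# Minimal subset of the textgen state used by generate_with_streaming
state = {
    'auto_max_new_tokens': False,
    'max_new_tokens': 64,
    'truncation_length': 4096,
    'temperature': 0.7,
    'top_p': 0.9,
}

# Streamed generation: each yield is the cumulative decoded text so far
final_text = ''
for final_text in model.generate_with_streaming('Hello, my name is', state):
    pass
print(final_text)

# generate() runs the same loop internally and returns only the final text
print(model.generate('Hello, my name is', state))
```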