examples/llm-api/llm_multilora.py (39 additions, 11 deletions)

--- a/examples/llm-api/llm_multilora.py
+++ b/examples/llm-api/llm_multilora.py
@@ -1,24 +1,35 @@
 ### :section Customization
 ### :title Generate text with multiple LoRA adapters
 ### :order 5
 
+import argparse
+from typing import Optional
+
 from huggingface_hub import snapshot_download
 
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import LoRARequest
 from tensorrt_llm.lora_helper import LoraConfig
 
 
-def main():
+def main(chatbot_lora_dir: Optional[str], mental_health_lora_dir: Optional[str],
+         tarot_lora_dir: Optional[str]):
+
-    # Download the LoRA adapters from huggingface hub.
-    lora_dir1 = snapshot_download(repo_id="snshrivas10/sft-tiny-chatbot")
-    lora_dir2 = snapshot_download(
-        repo_id="givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
-    lora_dir3 = snapshot_download(repo_id="barissglc/tinyllama-tarot-v1")
+    # Download the LoRA adapters from huggingface hub, if not provided via command line args.
+    if chatbot_lora_dir is None:
+        chatbot_lora_dir = snapshot_download(
+            repo_id="snshrivas10/sft-tiny-chatbot")
+    if mental_health_lora_dir is None:
+        mental_health_lora_dir = snapshot_download(
+            repo_id=
+            "givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
+    if tarot_lora_dir is None:
+        tarot_lora_dir = snapshot_download(
+            repo_id="barissglc/tinyllama-tarot-v1")
 
     # Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
     # This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
-    lora_config = LoraConfig(lora_dir=[lora_dir1],
+    lora_config = LoraConfig(lora_dir=[chatbot_lora_dir],
                              max_lora_rank=64,
                              max_loras=3,
                              max_cpu_loras=3)
@@ -39,10 +50,11 @@ def main():
     for output in llm.generate(prompts,
                                lora_request=[
                                    None,
-                                   LoRARequest("chatbot", 1, lora_dir1), None,
-                                   LoRARequest("mental-health", 2, lora_dir2),
+                                   LoRARequest("chatbot", 1, chatbot_lora_dir),
                                    None,
-                                   LoRARequest("tarot", 3, lora_dir3)
+                                   LoRARequest("mental-health", 2,
+                                               mental_health_lora_dir), None,
+                                   LoRARequest("tarot", 3, tarot_lora_dir)
                                ]):
         prompt = output.prompt
         generated_text = output.outputs[0].text
@@ -58,4 +70,20 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    parser = argparse.ArgumentParser(
+        description="Generate text with multiple LoRA adapters")
+    parser.add_argument('--chatbot_lora_dir',
+                        type=str,
+                        default=None,
+                        help='Path to the chatbot LoRA directory')
+    parser.add_argument('--mental_health_lora_dir',
+                        type=str,
+                        default=None,
+                        help='Path to the mental health LoRA directory')
+    parser.add_argument('--tarot_lora_dir',
+                        type=str,
+                        default=None,
+                        help='Path to the tarot LoRA directory')
+    args = parser.parse_args()
+    main(args.chatbot_lora_dir, args.mental_health_lora_dir,
+         args.tarot_lora_dir)
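
With this change the example can run against pre-downloaded adapter checkouts instead of always hitting the Hugging Face Hub; any flag left unset still falls back to snapshot_download. A minimal sketch of driving the updated entry point directly, assuming hypothetical local paths under /data/loras and that examples/llm-api is importable:

    # Sketch only: the /data/loras paths are illustrative assumptions,
    # not directories shipped with the repository.
    from llm_multilora import main  # assumes examples/llm-api is on sys.path

    main(chatbot_lora_dir="/data/loras/sft-tiny-chatbot",
         mental_health_lora_dir="/data/loras/mental-health-conversational",
         tarot_lora_dir="/data/loras/tinyllama-tarot-v1")

Equivalently, pass the same paths via --chatbot_lora_dir, --mental_health_lora_dir, and --tarot_lora_dir; the new argparse block forwards them to main().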
tests/integration/defs/llmapi/test_llm_examples.py (10 additions, 1 deletion)

--- a/tests/integration/defs/llmapi/test_llm_examples.py
+++ b/tests/integration/defs/llmapi/test_llm_examples.py
@@ -110,7 +110,16 @@ def test_llmapi_example_inference_async_streaming(llm_root, engine_dir,
 
 
 def test_llmapi_example_multilora(llm_root, engine_dir, llm_venv):
-    _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_multilora.py")
+    cmd_line_args = [
+        "--chatbot_lora_dir",
+        f"{llm_models_root()}/llama-models-v2/sft-tiny-chatbot",
+        "--mental_health_lora_dir",
+        f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational",
+        "--tarot_lora_dir",
+        f"{llm_models_root()}/llama-models-v2/tinyllama-tarot-v1"
+    ]
+    _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_multilora.py",
+                        *cmd_line_args)
 
 
 def test_llmapi_example_guided_decoding(llm_root, engine_dir, llm_venv):
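
The test now pins each adapter to a checkout under llm_models_root(), so CI no longer depends on Hub downloads. The forwarding itself is plain argument splatting; below is a hypothetical stand-in for the harness call, assuming only that _run_llmapi_example ultimately executes the script with the extra argv entries appended (the real helper is defined elsewhere in this module and may differ):

    import os
    import subprocess
    import sys

    def run_example_with_args(llm_root: str, script: str, *cmd_line_args: str) -> None:
        # Illustrative stand-in for _run_llmapi_example: launch the example
        # script and append any forwarded command-line arguments.
        script_path = os.path.join(llm_root, "examples", "llm-api", script)
        subprocess.check_call([sys.executable, script_path, *cmd_line_args])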