diff --git a/examples/llm-api/llm_multilora.py b/examples/llm-api/llm_multilora.py
index 0c6fa4f5417..525f839d79d 100644
--- a/examples/llm-api/llm_multilora.py
+++ b/examples/llm-api/llm_multilora.py
@@ -1,6 +1,10 @@
 ### :section Customization
 ### :title Generate text with multiple LoRA adapters
 ### :order 5
+
+import argparse
+from typing import Optional
+
 from huggingface_hub import snapshot_download
 
 from tensorrt_llm import LLM
@@ -8,17 +12,24 @@
 from tensorrt_llm.lora_helper import LoraConfig
 
 
-def main():
+def main(chatbot_lora_dir: Optional[str], mental_health_lora_dir: Optional[str],
+         tarot_lora_dir: Optional[str]):
 
-    # Download the LoRA adapters from huggingface hub.
-    lora_dir1 = snapshot_download(repo_id="snshrivas10/sft-tiny-chatbot")
-    lora_dir2 = snapshot_download(
-        repo_id="givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
-    lora_dir3 = snapshot_download(repo_id="barissglc/tinyllama-tarot-v1")
+    # Download the LoRA adapters from huggingface hub, if not provided via command line args.
+    if chatbot_lora_dir is None:
+        chatbot_lora_dir = snapshot_download(
+            repo_id="snshrivas10/sft-tiny-chatbot")
+    if mental_health_lora_dir is None:
+        mental_health_lora_dir = snapshot_download(
+            repo_id=
+            "givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
+    if tarot_lora_dir is None:
+        tarot_lora_dir = snapshot_download(
+            repo_id="barissglc/tinyllama-tarot-v1")
 
     # Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
     # This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
-    lora_config = LoraConfig(lora_dir=[lora_dir1],
+    lora_config = LoraConfig(lora_dir=[chatbot_lora_dir],
                              max_lora_rank=64,
                              max_loras=3,
                              max_cpu_loras=3)
@@ -39,10 +50,11 @@ def main():
     for output in llm.generate(prompts,
                                lora_request=[
                                    None,
-                                   LoRARequest("chatbot", 1, lora_dir1), None,
-                                   LoRARequest("mental-health", 2, lora_dir2),
+                                   LoRARequest("chatbot", 1, chatbot_lora_dir),
                                    None,
-                                   LoRARequest("tarot", 3, lora_dir3)
+                                   LoRARequest("mental-health", 2,
+                                               mental_health_lora_dir), None,
+                                   LoRARequest("tarot", 3, tarot_lora_dir)
                                ]):
         prompt = output.prompt
         generated_text = output.outputs[0].text
@@ -58,4 +70,20 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    parser = argparse.ArgumentParser(
+        description="Generate text with multiple LoRA adapters")
+    parser.add_argument('--chatbot_lora_dir',
+                        type=str,
+                        default=None,
+                        help='Path to the chatbot LoRA directory')
+    parser.add_argument('--mental_health_lora_dir',
+                        type=str,
+                        default=None,
+                        help='Path to the mental health LoRA directory')
+    parser.add_argument('--tarot_lora_dir',
+                        type=str,
+                        default=None,
+                        help='Path to the tarot LoRA directory')
+    args = parser.parse_args()
+    main(args.chatbot_lora_dir, args.mental_health_lora_dir,
+         args.tarot_lora_dir)
diff --git a/tests/integration/defs/llmapi/test_llm_examples.py b/tests/integration/defs/llmapi/test_llm_examples.py
index 1935acda092..f06c153b3b6 100644
--- a/tests/integration/defs/llmapi/test_llm_examples.py
+++ b/tests/integration/defs/llmapi/test_llm_examples.py
@@ -110,7 +110,16 @@ def test_llmapi_example_inference_async_streaming(llm_root, engine_dir,
 
 
 def test_llmapi_example_multilora(llm_root, engine_dir, llm_venv):
-    _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_multilora.py")
+    cmd_line_args = [
+        "--chatbot_lora_dir",
+        f"{llm_models_root()}/llama-models-v2/sft-tiny-chatbot",
+        "--mental_health_lora_dir",
+        f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational",
+        "--tarot_lora_dir",
+        f"{llm_models_root()}/llama-models-v2/tinyllama-tarot-v1"
+    ]
+    _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_multilora.py",
+                        *cmd_line_args)
 
 
 def test_llmapi_example_guided_decoding(llm_root, engine_dir, llm_venv):