83 changes: 83 additions & 0 deletions examples/llm-api/quickstart_multimodal.py
@@ -108,6 +108,15 @@ def add_multimodal_args(parser):
type=str,
default="cpu",
help="The device to have the input on.")
# Add multiturn conversation related parameters
parser.add_argument("--multiturn",
action="store_true",
help="Enable multi-turn conversation mode.")
parser.add_argument(
"--conversation_turns",
type=int,
default=2,
help="Number of conversation turns for automated testing.")
return parser


@@ -162,6 +171,80 @@ def main():
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"

# If multiturn mode is enabled
if args.multiturn:
# Run predefined multiturn conversation examples
assert args.prompt is not None, "Please provide a prompt for multiturn conversation."
assert args.media is not None, "Please provide media for multiturn conversation."
# Determine how many turns to run
max_turns = min(args.conversation_turns, len(args.prompt))
generated_outputs = [] # Store generated outputs for return

# Initialize conversation history with the first prompt
conversation_history = args.prompt[0] if args.prompt else ""

for i in range(max_turns):
print(f"\n--- Turn {i+1} ---")

try:
# Use multimodal input loader to process input with conversation context
# Use accumulated conversation history instead of just the current prompt
cur_prompt = conversation_history
inputs = default_multimodal_input_loader(
tokenizer=llm.tokenizer,
model_dir=llm._hf_model_dir,
model_type=model_type,
modality=args.modality,
prompts=[cur_prompt],
media=args.media,
image_data_format="pt",
num_frames=8,
device="cpu")

lora_request = None
if args.load_lora:
if model_class is None:
raise ValueError(
"model_class must be provided when load_lora is True"
)
lora_request = model_class.lora_request(
len(inputs), args.modality, llm._hf_model_dir)

# Generate response
outputs = llm.generate(inputs,
sampling_params,
lora_request=lora_request)
assert outputs and len(
outputs) > 0 and outputs[0].outputs and len(
outputs[0].outputs) > 0
response = outputs[0].outputs[0].text.strip()

# Store generated output
generated_outputs.append({
"turn": i + 1,
"user_input": cur_prompt,
"assistant_response": response,
"media": args.media
})

conversation_history = conversation_history + "\n" + response
if i + 1 < len(args.prompt):
conversation_history = conversation_history + "\n" + args.prompt[
i + 1]

except Exception as e:
print(f"Error in turn {i+1}: {e}")
import traceback
traceback.print_exc()
continue

for i, output in enumerate(generated_outputs):
print(
f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
)
return

# Original single-turn processing logic
# set prompts and media to example prompts and images if they are not provided
if args.prompt is None:
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
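For readers skimming the hunk above, here is a minimal self-contained sketch of the conversation-history accumulation that the new --multiturn path performs. The generate argument is a stand-in callable, not the script's llm.generate API, and the helper name is illustrative only:

def run_multiturn(prompts, generate, max_turns=2):
    # Seed the running context with the first user prompt
    history = prompts[0]
    transcript = []
    for i in range(min(max_turns, len(prompts))):
        # The model is always conditioned on the full accumulated context
        response = generate(history)
        transcript.append({
            "turn": i + 1,
            "user_input": history,
            "assistant_response": response,
        })
        # Fold the reply, then the next user prompt, back into the context
        history += "\n" + response
        if i + 1 < len(prompts):
            history += "\n" + prompts[i + 1]
    return transcript


if __name__ == "__main__":
    turns = run_multiturn(
        ["Describe what you see in this image.",
         "How would you describe the atmosphere of this scene?"],
        lambda ctx: f"(response conditioned on {len(ctx)} characters of context)")
    for turn in turns:
        print(f"Turn {turn['turn']}: {turn['assistant_response']}")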
7 changes: 7 additions & 0 deletions tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -19,6 +19,13 @@ meta-llama/Llama-3.3-70B-Instruct:
accuracy: 84.08
meta-llama/Llama-4-Maverick-17B-128E-Instruct:
- accuracy: 92.20
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 92.20
- quant_algo: FP8
kv_cache_quant_algo: FP8
spec_dec_algo: Eagle
accuracy: 92.20
meta-llama/Llama-4-Scout-17B-16E-Instruct:
- accuracy: 89.70
- quant_algo: NVFP4
3 changes: 3 additions & 0 deletions tests/integration/defs/accuracy/references/mmlu.yaml
@@ -73,6 +73,9 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
kv_cache_quant_algo: FP8
spec_dec_algo: Eagle
accuracy: 86.40
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 86.40
meta-llama/Llama-4-Scout-17B-16E-Instruct:
- accuracy: 80.00
- quant_algo: NVFP4
@@ -246,7 +246,7 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
total_gen_gpus = gen_tp * gen_pp * gen_instances
if total_ctx_gpus + total_gen_gpus > get_device_count():
pytest.fail(
pytest.skip(
f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
)

@@ -378,6 +378,7 @@ def test_ngram(self):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_hopper
@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("eagle3_one_model", [True, False])
def test_eagle3(self, overlap_scheduler, eagle3_one_model):
@@ -461,6 +462,7 @@ def test_multi_instance(self, testset):

@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)
@pytest.mark.skip_less_device(4)
class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
@@ -540,6 +542,7 @@ def test_nixl_backend(self):
@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("mtp_nextn",
[0, pytest.param(2, marks=skip_pre_hopper)])
@pytest.mark.skip_less_device(4)
def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
ctx_server_config = {"disable_overlap_scheduler": True}
gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
@@ -671,6 +674,7 @@ def test_nixl_backend(self):
task.evaluate(llm)

@pytest.mark.parametrize("overlap_scheduler", [False, True])
@skip_pre_hopper
def test_auto_dtype(self, overlap_scheduler):
ctx_server_config = {
"disable_overlap_scheduler": True,
5 changes: 3 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -695,6 +695,7 @@ class TestMistralSmall24B(LlmapiAccuracyTestHarness):
MODEL_NAME = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
MODEL_PATH = f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503"

@pytest.mark.skip_less_device_memory(80000)
def test_auto_dtype(self):
with LLM(self.MODEL_PATH) as llm:
task = CnnDailymail(self.MODEL_NAME)
@@ -1033,7 +1034,7 @@ def test_cute_dsl_fp8_block_scales(
max_num_streams=3) if torch_compile else None)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(backend="CUTEDSL"),
)
@@ -1191,7 +1192,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
max_num_streams=3) if torch_compile else None)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(backend="CUTEDSL"),
)
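The two hunks above swap the boolean use_cuda_graph flag for a cuda_graph_config object. A minimal sketch of that pattern, assuming CudaGraphConfig is importable from tensorrt_llm.llmapi (the import path is an assumption; only the keyword usage is taken from the diff):

from tensorrt_llm.llmapi import CudaGraphConfig  # assumed import path

cuda_graph = True  # hypothetical toggle mirroring the test parametrization
pytorch_config = dict(
    disable_overlap_scheduler=False,
    # new style: pass a config object (or None) rather than use_cuda_graph=<bool>
    cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
)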
106 changes: 104 additions & 2 deletions tests/integration/defs/test_e2e.py
@@ -2051,7 +2051,7 @@ def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name,
def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
llm_root, llm_venv, model_name, model_path, cuda_graph):
print(f"Testing {model_name} on 8 GPUs.")
example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
cmd = [
str(example_root / "quickstart_advanced.py"),
"--enable_chunked_prefill",
@@ -2076,10 +2076,12 @@ def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("model_name,model_path", [
("Llama3.1-70B-BF16", "llama-3.1-model/Meta-Llama-3.1-70B"),
('Nemotron-Super-49B-v1-BF16',
'nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1'),
("Mixtral-8x7B-BF16", "Mixtral-8x7B-Instruct-v0.1"),
pytest.param('Llama3.1-70B-BF16',
'llama-3.1-model/Meta-Llama-3.1-70B',
marks=pytest.mark.skip_less_device_memory(95000)),
])
def test_ptp_quickstart_advanced_2gpus_sm120(llm_root, llm_venv, model_name,
model_path):
@@ -2521,6 +2523,106 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
print("All answers are correct!")


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name,model_path", [
("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
])
def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
model_path):
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
test_data_root = Path(
os.path.join(llm_models_root(), "multimodals", "test_data"))

print(f"Accuracy test {model_name} image mode with example inputs.")

# Define accuracy inputs for image modality
accuracy_inputs = {
"image": {
"prompt": [
"Describe what you see in this image.",
"How would you describe the atmosphere of this scene?",
],
"media": [
str(test_data_root / "inpaint.png"),
],
}
}

# Define expected keywords for each model
expected_keywords = {
"gemma-3-27b-it": {
"image": [
["half", "dome", "yosemite", "landmark", "rounded"],
["atmosphere", "peaceful", "majestic", "calm", "quiet"],
],
},
"mistral-small-3.1-24b-instruct": {
"image": [
["depicts", "landscape", "rock", "sky", "high", "altitude"],
["atmosphere", "serene", "majestic", "sense", "tranquility"],
],
},
"Phi-4-multimodal-instruct": {
"image": [
["depicts", "landscape", "mountain", "half", "dome"],
["atmosphere", "serene", "sense", "tranquility", "peace."],
],
},
}
# Build command for image modality
cmd = [
str(example_root / "quickstart_multimodal.py"),
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--modality",
"image",
"--multiturn",
"--prompt",
*accuracy_inputs["image"]["prompt"],
"--media",
*accuracy_inputs["image"]["media"],
]

# Add model-specific configurations
if model_name == "gemma-3-27b-it":
# Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
# Custom mask involves bidirectional masking of image tokens in context phase. To get this
# correct, chunked prefill and kv cache reuse need to be turned off.
cmd.append("--image_format=pil")
cmd.append("--attention_backend=FLASHINFER")
cmd.append("--disable_kv_cache_reuse")
elif model_name == "Phi-4-multimodal-instruct":
# Set max_seq_len to 4096 to use short rope factor.
cmd.append("--max_seq_len=4096")
cmd.append("--load_lora")
cmd.append("--auto_model_name")
cmd.append("Phi4MMForCausalLM")

output = llm_venv.run_cmd(cmd, caller=check_output)
print("output:", output)
# Set match ratio based on model
match_ratio = 4.0 / 5
if model_name == "Phi-4-multimodal-instruct":
match_ratio = 0.6

# Check output accuracy
for prompt_output, prompt_keywords in zip(
parse_output(output), expected_keywords[model_name]["image"]):
matches = [
keyword in prompt_output.lower() for keyword in prompt_keywords
]
obs_match_ratio = 1. * sum(matches) / len(matches)
print("prompt_output:", prompt_output)
print("prompt_keywords:", prompt_keywords)
print("matches:", matches)
print("obs_match_ratio:", obs_match_ratio)
assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"

print("All answers are correct!")


@pytest.mark.parametrize("model_name,model_path", [
("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
])
3 changes: 3 additions & 0 deletions tests/integration/test_lists/qa/llm_function_full.txt
@@ -602,6 +602,9 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[gemma-3-27b-it-gemma/gemma-3-27b-it]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[gemma-3-27b-it-gemma/gemma-3-27b-it]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]