17 changes: 2 additions & 15 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -165,22 +165,9 @@ def _forward_nope(
q, k, v = self.split_qkv(q, k, v)
q = self._attention_scaling(q, position_ids)

out_scale = None
out_scale_sf = None
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
out_scale = self.o_proj.inv_input_scale
if self.o_proj.has_nvfp4 and self.support_nvfp4_output:
out_scale_sf = self.o_proj.input_scale

q, k, v = self.convert_qkv(q, k, v)
attn_output = self.attn.forward(q,
k,
v,
attn_metadata,
out_scale=out_scale,
out_scale_sf=out_scale_sf,
attention_mask=attention_mask,
mrope_config=mrope_config)
attn_output = self.forward_impl(q, k, v, attn_metadata, attention_mask,
None, None, mrope_config)

if isinstance(attn_output, tuple):
attn_output = Fp4QuantizedTensor(attn_output[0], attn_output[1])
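Note: the out-scale selection removed above is not dropped by this change; it is expected to happen once inside the shared attention path instead of at each call site. A minimal standalone sketch of the removed logic, using only the attribute names visible in the deleted lines (not the actual helper the PR relies on):

def _select_out_scales(o_proj, support_nvfp4_output):
    # Mirror of the deleted block: use the inverse input scale whenever the
    # output projection is FP8/NVFP4 quantized, and additionally the input
    # scale factor when NVFP4 output is supported.
    out_scale = None
    out_scale_sf = None
    if o_proj.has_fp8_qdq or o_proj.has_nvfp4 or o_proj.has_fp8_block_scales:
        out_scale = o_proj.inv_input_scale
    if o_proj.has_nvfp4 and support_nvfp4_output:
        out_scale_sf = o_proj.input_scale
    return out_scale, out_scale_sf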
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -362,6 +362,12 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
model_config,
model_config.mapping)

if draft_config is not None:
for key, value in draft_config.extra_attrs.items():
assert key in ('attn_layers', 'mla_layers')
assert key in model_config.extra_attrs
model_config.extra_attrs[key].update(value)

def forward(
self,
attn_metadata: AttentionMetadata,
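A self-contained toy example of the extra_attrs merge added above, with hypothetical registry contents, to show the intent: the draft model's per-layer attention registries are folded into the target model_config so both models resolve layers from a single mapping.

draft_extra_attrs = {"attn_layers": {"30": "draft_attn_layer"}}      # hypothetical entries
model_extra_attrs = {"attn_layers": {"0": "target_attn_layer"},      # hypothetical entries
                     "mla_layers": {}}

for key, value in draft_extra_attrs.items():
    assert key in ("attn_layers", "mla_layers")
    assert key in model_extra_attrs
    model_extra_attrs[key].update(value)

# model_extra_attrs["attn_layers"] now holds both target and draft entries.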
106 changes: 61 additions & 45 deletions tensorrt_llm/_torch/modules/attention.py
@@ -92,7 +92,7 @@ def attn_custom_op_inplace(
mrope_position_deltas,
attention_window_size,
attention_mask_data,
False,
enable_attn_nvfp4_output=False,
output=output)


@@ -372,6 +372,58 @@ def _attn_impl(
return attn_output[0], attn_output[1]
return attn_output, None

def forward_impl(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
attention_mask: AttentionMask,
attention_window_size: Optional[int],
attention_mask_data: Optional[torch.Tensor],
mrope_config: Optional[dict],
):
mrope_rotary_cos_sin = None
mrope_position_deltas = None
if mrope_config is not None:
if "mrope_rotary_cos_sin" in mrope_config:
mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
if "mrope_position_deltas" in mrope_config:
mrope_position_deltas = mrope_config["mrope_position_deltas"]

# Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
# Only enable custom inplace op when torch compiling.
use_custom_inplace_op = (self.register_to_config
and (self.attn_backend == "TRTLLM"
or self.attn_backend == "FLASHINFER")
and is_torch_compiling())

if use_custom_inplace_op:
output = self.create_output(q)
attn_custom_op_inplace(
q,
k,
v,
attention_mask,
mrope_rotary_cos_sin,
mrope_position_deltas,
attention_window_size,
attention_mask_data,
self.layer_idx_str,
output,
)
else:
output, output_sf = self._attn_impl(q, k, v, attn_metadata,
attention_mask,
mrope_rotary_cos_sin,
mrope_position_deltas,
attention_window_size,
attention_mask_data)
if output_sf is not None:
output = Fp4QuantizedTensor(output, output_sf)

return output

def forward(
self,
position_ids: Optional[torch.IntTensor],
@@ -414,54 +466,18 @@ def forward(
if qkv_lora is not None:
qkv = qkv + qkv_lora

mrope_rotary_cos_sin = None
mrope_position_deltas = None
if mrope_config is not None:
if "mrope_rotary_cos_sin" in mrope_config:
mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
if "mrope_position_deltas" in mrope_config:
mrope_position_deltas = mrope_config["mrope_position_deltas"]

output = None

q, k, v = qkv, None, None
q, k, v = self.apply_rope(q, k, v, position_ids)
q, k, v = self.convert_qkv(q, k, v)

# Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
# Only enable custom inplace op when torch compiling.
use_custom_inplace_op = (self.register_to_config
and (self.attn_backend == "TRTLLM"
or self.attn_backend == "FLASHINFER")
and is_torch_compiling())
if use_custom_inplace_op:
output = self.create_output(q)
attn_custom_op_inplace(
q,
k,
v,
attention_mask,
mrope_rotary_cos_sin,
mrope_position_deltas,
attention_window_size,
attention_mask_data,
self.layer_idx_str,
output=output,
)
else:
output, output_sf = self._attn_impl(
q,
k,
v,
attn_metadata,
attention_mask,
mrope_rotary_cos_sin,
mrope_position_deltas,
attention_window_size,
attention_mask_data,
)
if output_sf is not None:
output = Fp4QuantizedTensor(output, output_sf)
output = self.forward_impl(q,
k,
v,
attn_metadata,
attention_mask,
attention_window_size,
attention_mask_data,
mrope_config=mrope_config)

attn_output = self.o_proj(output,
all_reduce_params=all_reduce_params,
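The dispatch rule that forward_impl centralizes can be restated as a small standalone helper; this is a sketch of the decision only, with assumed argument names, not the module's actual code.

def _use_custom_inplace_op(register_to_config: bool, attn_backend: str,
                           is_torch_compiling: bool) -> bool:
    # The custom in-place op is taken only when the model is being
    # torch-compiled and the backend is one of the torch.compile-compatible
    # ones (TRTLLM or FLASHINFER); every other case goes through _attn_impl,
    # which may additionally return an NVFP4 scale factor.
    return (register_to_config
            and attn_backend in ("TRTLLM", "FLASHINFER")
            and is_torch_compiling)

assert _use_custom_inplace_op(True, "TRTLLM", True) is True
assert _use_custom_inplace_op(True, "TRTLLM", False) is False   # eager execution
assert _use_custom_inplace_op(False, "FLASHINFER", True) is False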
7 changes: 5 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -550,11 +550,14 @@ def test_fp8_eagle3(self, tp_size, pp_size, ep_size, torch_compile):
speculative_model_dir=eagle_model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
free_gpu_memory_fraction=0.75)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=True,
max_num_streams=3) if torch_compile else None
pytorch_config = dict(
cuda_graph_config=CudaGraphConfig(max_batch_size=8),
enable_attention_dp=False,
torch_compile_config=TorchCompileConfig(
enable_fullgraph=torch_compile))
torch_compile_config=torch_compile_config)
with LLM(model_path,
kv_cache_config=kv_cache_config,
tensor_parallel_size=tp_size,
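For reference, a hedged sketch of how the updated test wires the piecewise-CUDA-graph compile config into the LLM constructor, assembled only from names visible in this diff; the import path is an assumption, and model_path/tp_size stand in for the test's parametrized values.

from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
                                 TorchCompileConfig)  # assumed import path

torch_compile = True  # parametrized in the real test
torch_compile_config = TorchCompileConfig(
    enable_fullgraph=True,
    enable_piecewise_cuda_graph=True,
    max_num_streams=3) if torch_compile else None

llm_kwargs = dict(
    kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                  free_gpu_memory_fraction=0.75),
    cuda_graph_config=CudaGraphConfig(max_batch_size=8),
    enable_attention_dp=False,
    torch_compile_config=torch_compile_config)
# llm = LLM(model_path, tensor_parallel_size=tp_size, **llm_kwargs)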
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -260,7 +260,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5430124)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5431139)
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5427801)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)