
Commit d6b0316

hyukn authored and mikeiovine committed
[https://nvbugs/5536131][fix] Fix illegal access issue when scale is not provided in Llama3/4. (#7960)
Signed-off-by: Yukun He <[email protected]>
Signed-off-by: Mike Iovine <[email protected]>
1 parent 3c961f5 commit d6b0316

File tree

5 files changed, +44 -45 lines

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 13 additions & 12 deletions
@@ -561,13 +561,13 @@ def forward(
         else:
             # The next layernorm exists but it could be the last decoder layer.
             # Adjust the scale and fusion pattern.
-            if self.next_attn is not None and (self.is_nvfp4
-                                               or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale if hasattr(
-                    self.next_attn.qkv_proj, 'input_scale') else None
-            else:
-                self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            if not (self.next_attn is not None and (self.is_nvfp4
+                                                    or self.is_fp8_quant)) \
+                    or not hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 scale = None
+                self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            else:
+                scale = self.next_attn.qkv_proj.input_scale
 
         # TODO: MIN_LATENCY_MODE is hardcoded to False
         if cutlass_min_latency_mode:

@@ -771,13 +771,14 @@ def forward(
         else:
             # The next layernorm exists but it could be the last decoder layer.
             # Adjust the scale and fusion pattern.
-            if self.next_attn is not None and (self.is_nvfp4
-                                               or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale if hasattr(
-                    self.next_attn.qkv_proj, 'input_scale') else None
-            else:
-                self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+
+            if not (self.next_attn is not None and (self.is_nvfp4
+                                                    or self.is_fp8_quant)) \
+                    or not hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 scale = None
+                self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            else:
+                scale = self.next_attn.qkv_proj.input_scale
 
         all_reduce_output = self.all_reduce(
             hidden_states,
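Both hunks in this file apply the same reordering: the quantized-scale path is taken only when the next attention layer exists, the model is NVFP4- or FP8-quantized, and its qkv_proj actually carries an input_scale; every other case falls back to a plain RESIDUAL_RMS_NORM fusion with scale=None, so the missing scale tensor is never dereferenced. A minimal standalone sketch of that decision, using a hypothetical helper and stand-in string constants instead of the AllReduceFusionOp enum:

# Hypothetical helper mirroring the guard added above; not TensorRT-LLM API.
from typing import Optional, Tuple


def pick_scale_and_fusion(next_attn, is_nvfp4: bool, is_fp8_quant: bool,
                          quant_fusion_op: str) -> Tuple[Optional[object], str]:
    """Return (scale, fusion_op) for the post-feed-forward allreduce."""
    quantized = is_nvfp4 or is_fp8_quant
    if not (next_attn is not None and quantized) \
            or not hasattr(next_attn.qkv_proj, 'input_scale'):
        # No next attention layer, unquantized model, or a checkpoint whose
        # qkv_proj has no static input scale: use the plain fused norm.
        return None, "RESIDUAL_RMS_NORM"
    # Safe to read the scale only after the hasattr check above succeeded.
    return next_attn.qkv_proj.input_scale, quant_fusion_op


class _QKVWithoutScale:  # stand-in for a qkv_proj lacking input_scale
    pass


class _Attn:
    qkv_proj = _QKVWithoutScale()


# FP8 model whose checkpoint provides no input scale -> (None, 'RESIDUAL_RMS_NORM')
print(pick_scale_and_fusion(_Attn(), is_nvfp4=False, is_fp8_quant=True,
                            quant_fusion_op="RESIDUAL_RMS_NORM_QUANT_FP8"))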

tensorrt_llm/_torch/models/modeling_llama_min_latency.py

Lines changed: 6 additions & 6 deletions
@@ -800,27 +800,27 @@ def forward(
         needs_post_allreduce = self.fusion_config.POST_MOE_FUSION \
             or self.fusion_config.POST_MLP_FUSION
         if needs_post_allreduce and self.next_layer_layernorm is not None:
-            if use_fp8_allreduce and self.next_attn is not None:
+            if use_fp8_allreduce and self.next_attn is not None \
+                    and hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
                         fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
                         residual=residual,
                         norm_weight=self.next_layer_layernorm.weight,
-                        scale=self.next_attn.qkv_proj.input_scale if hasattr(
-                            self.next_attn.qkv_proj, 'input_scale') else None,
+                        scale=self.next_attn.qkv_proj.input_scale,
                         eps=self.next_layer_layernorm.variance_epsilon,
                     ))
-            elif use_fp4_allreduce and self.next_attn is not None:
+            elif use_fp4_allreduce and self.next_attn is not None \
+                    and hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 act_fp4, act_sf, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
                         fusion_op=AllReduceFusionOp.
                         RESIDUAL_RMS_NORM_QUANT_NVFP4,
                         residual=residual,
                         norm_weight=self.next_layer_layernorm.weight,
-                        scale=self.next_attn.qkv_proj.input_scale if hasattr(
-                            self.next_attn.qkv_proj, 'input_scale') else None,
+                        scale=self.next_attn.qkv_proj.input_scale,
                         eps=self.next_layer_layernorm.variance_epsilon,
                     ))
             else:
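In the min-latency path the availability check moves into the branch condition itself: the FP8 or NVFP4 fused allreduce is chosen only when the next attention layer exposes input_scale, and every other case drops into the existing else branch. A hedged sketch of the selection logic, again with stand-in string constants rather than the real enum:

# Hypothetical illustration of the branch selection; not TensorRT-LLM code.
def select_post_allreduce_fusion(use_fp8_allreduce: bool,
                                 use_fp4_allreduce: bool,
                                 next_attn) -> str:
    has_scale = (next_attn is not None
                 and hasattr(next_attn.qkv_proj, 'input_scale'))
    if use_fp8_allreduce and has_scale:
        return "RESIDUAL_RMS_NORM_QUANT_FP8"    # input_scale guaranteed present
    if use_fp4_allreduce and has_scale:
        return "RESIDUAL_RMS_NORM_QUANT_NVFP4"  # input_scale guaranteed present
    return "FALLBACK"  # the unchanged else branch shown in the diff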

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 24 additions & 0 deletions
@@ -643,6 +643,30 @@ def test_nvfp4_tp4(self):
         task.evaluate(llm,
                       extra_evaluator_kwargs=dict(apply_chat_template=True))
 
+    @pytest.mark.skip_less_device(4)
+    @skip_pre_blackwell
+    def test_fp8_tp2pp2(self):
+        model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)
+        with LLM(model_path,
+                 tensor_parallel_size=2,
+                 pipeline_parallel_size=2,
+                 max_batch_size=32,
+                 kv_cache_config=kv_cache_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            sampling_params = SamplingParams(
+                max_tokens=256,
+                temperature=0.0,
+                add_special_tokens=False,
+            )
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GPQADiamond(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=dict(apply_chat_template=True))
+
 
 class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
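The new test_fp8_tp2pp2 case adds accuracy coverage for FP8 Llama 3.3 70B with tensor parallelism 2 and pipeline parallelism 2; per its decorators it needs at least four devices and a Blackwell GPU. One possible way to run just this test locally, assuming the file path from the diff header above:

# Hedged invocation example; path and node id taken from the diff and test list.
import pytest

pytest.main([
    "tests/integration/defs/accuracy/test_llm_api_pytorch.py",
    "-k", "TestLlama3_3_70BInstruct and test_fp8_tp2pp2",
    "-v",
])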

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 1 deletion
@@ -52,6 +52,7 @@ l0_dgx_b200:
   - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp2pp2
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
   - condition:
@@ -130,7 +131,6 @@ l0_dgx_b200:
       orchestrator: mpi
     tests:
       - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
-      - unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3
       - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
       - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
       - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]

tests/unittest/_torch/multi_gpu_modeling/test_llama3.py

Lines changed: 0 additions & 26 deletions
This file was deleted.
