@@ -301,6 +301,12 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)
 
+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise
+                                or self.o_proj.has_w4a8_nvfp4_fp8)
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -320,12 +326,8 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype
 
         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise
-                               or self.o_proj.has_w4a8_nvfp4_fp8)
-            if has_quant_scale and (self.attn.has_fp8_kv_cache
-                                    or self.attn.has_fp4_kv_cache):
+            if self.has_quant_scale and (self.attn.has_fp8_kv_cache
+                                         or self.attn.has_fp4_kv_cache):
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -356,11 +358,7 @@ def _attn_impl(
 
         out_scale = None
         out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise
-                           or self.o_proj.has_w4a8_nvfp4_fp8)
-        if has_quant_scale:
+        if self.has_quant_scale:
             out_scale = self.o_proj.inv_input_scale
             if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
                 out_scale_sf = self.o_proj.input_scale
@@ -858,6 +856,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)
 
+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
        # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
        # which can be modified after __init__
        has_fp8_block_scales = (
@@ -1061,17 +1062,14 @@ def forward_context_default(
                                 self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1126,9 +1124,6 @@ def forward_context_with_cached_kv(
         full_kv = None
         full_k_nope = None
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1138,7 +1133,7 @@ def forward_context_with_cached_kv(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1232,7 +1227,6 @@ def forward_context_with_chunked_prefill(
                 loop_idx]
             attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens
 
-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1243,7 +1237,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 chunked_prefill_buffer_batch_size=attn_metadata.
@@ -1284,9 +1278,6 @@ def forward_context_with_chunked_prefill(
                 num_contexts].sum().item(
                 )
 
-            # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-            out_scale = None  # Currently we use BF16 MHA for context phase
-
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
             temp_attn_output = self.mha.forward(
@@ -1296,7 +1287,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 chunked_prefill_buffer_batch_size=attn_metadata.runtime_features.
                 chunked_prefill_buffer_batch_size,
@@ -1394,16 +1385,13 @@ def forward_generation(
             self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
         ])
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
             None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
         )
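
For readers skimming the diff, the change is a caching refactor: the quantization-dependent flag (self.has_quant_scale) and the MLA output scale (self.out_scale) are now resolved once in create_weights(), after the quant config is final, and the forward paths only read the cached attributes instead of rebuilding them on every call. Below is a minimal, self-contained sketch of that pattern; DemoProj and DemoAttention are hypothetical stand-ins for illustration, not the TensorRT-LLM classes touched by this commit.

# Sketch of the caching pattern applied in this diff: quantization-dependent
# flags are computed once when weights are created, then reused on every
# forward call. DemoProj/DemoAttention are hypothetical stand-in classes.
from dataclasses import dataclass


@dataclass
class DemoProj:
    has_fp8_qdq: bool = False
    has_nvfp4: bool = False
    has_fp8_block_scales: bool = False
    has_fp8_rowwise: bool = False
    has_w4a8_nvfp4_fp8: bool = False


class DemoAttention:

    def __init__(self, o_proj: DemoProj):
        self.o_proj = o_proj
        self.has_quant_scale = False  # finalized in create_weights()

    def create_weights(self):
        # Evaluate the per-quant-mode flags once, after the quant config is final.
        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
                                or self.o_proj.has_fp8_block_scales
                                or self.o_proj.has_fp8_rowwise
                                or self.o_proj.has_w4a8_nvfp4_fp8)

    def forward(self):
        # Hot path: a single attribute read instead of re-deriving the flag.
        return "fp8_out" if self.has_quant_scale else "bf16_out"


attn = DemoAttention(DemoProj(has_fp8_qdq=True))
attn.create_weights()
assert attn.forward() == "fp8_out"

The same idea covers out_scale in the MLA path: since the attention output stays in BF16 for both context and generation, a single self.out_scale = None set at weight-creation time replaces the per-call locals that the diff removes.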