@@ -295,6 +295,12 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)
 
+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise
+                                or self.o_proj.has_w4a8_nvfp4_fp8)
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -314,12 +320,8 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype
 
         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise
-                               or self.o_proj.has_w4a8_nvfp4_fp8)
-            if has_quant_scale and (self.attn.has_fp8_kv_cache
-                                    or self.attn.has_fp4_kv_cache):
+            if self.has_quant_scale and (self.attn.has_fp8_kv_cache
+                                         or self.attn.has_fp4_kv_cache):
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -350,11 +352,7 @@ def _attn_impl(
 
         out_scale = None
         out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise
-                           or self.o_proj.has_w4a8_nvfp4_fp8)
-        if has_quant_scale:
+        if self.has_quant_scale:
             out_scale = self.o_proj.inv_input_scale
         if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
             out_scale_sf = self.o_proj.input_scale
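Note: the three hunks above all implement the same idea: the o_proj quantization flags are fixed once the quant config is final, so the combined flag is now computed a single time in create_weights() and the per-call helpers simply read self.has_quant_scale. Below is a minimal, self-contained sketch of that pattern; the class and attribute names (FakeOProj, ToyAttention) are illustrative stand-ins, not the real TensorRT-LLM types.

```python
# Sketch only: compute the quantization flag once at weight-creation time,
# then reuse it on every forward-path call. Names here are hypothetical.
from dataclasses import dataclass


@dataclass
class FakeOProj:
    has_fp8_qdq: bool = False
    has_nvfp4: bool = False
    has_fp8_block_scales: bool = False
    has_fp8_rowwise: bool = False
    has_w4a8_nvfp4_fp8: bool = False


class ToyAttention:

    def __init__(self, o_proj: FakeOProj):
        self.o_proj = o_proj
        self.has_quant_scale = False

    def create_weights(self):
        # Evaluated once, after the quant config can no longer change,
        # instead of on every create_output()/_attn_impl() call.
        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
                                or self.o_proj.has_fp8_block_scales
                                or self.o_proj.has_fp8_rowwise
                                or self.o_proj.has_w4a8_nvfp4_fp8)

    def create_output_dtype(self, has_fp8_kv_cache: bool) -> str:
        # Mirrors create_output(): the cached flag gates the FP8 output path.
        return "fp8" if (self.has_quant_scale and has_fp8_kv_cache) else "bf16"


attn = ToyAttention(FakeOProj(has_fp8_qdq=True))
attn.create_weights()
assert attn.create_output_dtype(has_fp8_kv_cache=True) == "fp8"
```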
@@ -847,6 +845,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)
 
+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
         # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
         # which can be modified after __init__
         has_fp8_block_scales = (
@@ -1050,17 +1051,14 @@ def forward_context_default(
                                                    self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1115,9 +1113,6 @@ def forward_context_with_cached_kv(
         full_kv = None
         full_k_nope = None
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1127,7 +1122,7 @@ def forward_context_with_cached_kv(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1217,7 +1212,6 @@ def forward_context_with_chunked_prefill(
                                                              loop_idx]
             attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens
 
-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1228,7 +1222,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 output=temp_attn_output,
@@ -1267,9 +1261,6 @@ def forward_context_with_chunked_prefill(
                                                          num_contexts].sum().item(
                                                          )
 
-            # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-            out_scale = None  # Currently we use BF16 MHA for context phase
-
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
             temp_attn_output = self.mha.forward(
@@ -1279,7 +1270,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 output=temp_attn_output,
             )
@@ -1375,16 +1366,13 @@ def forward_generation(
             self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
         ])
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
             None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
            latent_cache=latent_cache,  # kvcache and k_pe
            q_pe=q_pe,  # used by `invokeMLARopeGeneration`
        )
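The remaining hunks share a second pattern: instead of re-declaring a local `out_scale = None` in every forward_* path of the MLA module, the value is stored once as `self.out_scale` in create_weights() and passed to each mha/mqa.forward call site. A minimal sketch of that pattern, under the assumption that the output stays BF16 so no output scale is needed; the names below (ToyMLA, _fake_attention_forward) are illustrative and not the real TensorRT-LLM API.

```python
# Sketch only: cache out_scale as an attribute and reuse it at every call site.
class ToyMLA:

    def __init__(self):
        self.out_scale = None  # placeholder until create_weights() runs

    def create_weights(self):
        # FP8 MLA is used for context/generation, but the attention output
        # is still BF16, so no output scale is attached.
        self.out_scale = None

    def _fake_attention_forward(self, q, out_scale=None):
        # Stand-in for self.mha.forward / self.mqa.forward.
        return {"q": q, "out_scale": out_scale}

    def forward_context(self, q):
        # Every forward path reads the shared attribute.
        return self._fake_attention_forward(q, out_scale=self.out_scale)

    def forward_generation(self, q):
        return self._fake_attention_forward(q, out_scale=self.out_scale)


mla = ToyMLA()
mla.create_weights()
assert mla.forward_generation("fused_q")["out_scale"] is None
```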