@@ -285,6 +285,7 @@ def __init__(
 
         self.support_fused_qkv = self.attn.support_fused_qkv()
         self.support_nvfp4_output = self.attn.support_nvfp4_output()
+        self.enable_attn_nvfp4_output = True
 
         if not config.skip_create_weights_in_init:
             self.create_weights()
@@ -294,6 +295,17 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)
 
+        self.out_scale = None
+        self.out_scale_sf = None
+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise)
+        if self.has_quant_scale:
+            self.out_scale = self.o_proj.inv_input_scale.data
+        if self.o_proj.has_nvfp4 and self.support_nvfp4_output and self.enable_attn_nvfp4_output:
+            self.out_scale_sf = self.o_proj.input_scale.data
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -313,10 +325,7 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype
 
         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise)
-            if has_quant_scale and self.attn.has_fp8_kv_cache:
+            if self.has_quant_scale and self.attn.has_fp8_kv_cache:
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -351,16 +360,6 @@ def _attn_impl(
             assert v.shape[0] == padded_num_tokens
             v = v[:num_tokens, :]
 
-        out_scale = None
-        out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise)
-        if has_quant_scale:
-            out_scale = self.o_proj.inv_input_scale
-        if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
-            out_scale_sf = self.o_proj.input_scale
-
         mrope_config = None
         if mrope_rotary_cos_sin is not None or mrope_position_deltas is not None:
             mrope_config = dict()
@@ -374,8 +373,8 @@ def _attn_impl(
             k,
             v,
             attn_metadata,
-            out_scale=out_scale,
-            out_scale_sf=out_scale_sf,
+            out_scale=self.out_scale,
+            out_scale_sf=self.out_scale_sf,
             attention_mask=attention_mask,
             mrope_config=mrope_config,
             attention_window_size=attention_window_size,
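
The hunks above hoist the per-call quantization-scale lookup out of the attention hot path and into `create_weights()`. Below is a minimal, standalone sketch of that caching pattern; `FakeOProj`, `AttentionSketch`, and the plain-float scales are illustrative stand-ins and not part of the actual TensorRT-LLM classes:

```python
class FakeOProj:
    """Illustrative stand-in for the o_proj layer with quantization flags."""
    has_fp8_qdq = False
    has_nvfp4 = True
    has_fp8_block_scales = False
    has_fp8_rowwise = False
    inv_input_scale = 0.5  # would be a torch.Tensor in the real module
    input_scale = 2.0


class AttentionSketch:
    def __init__(self, o_proj, support_nvfp4_output=True):
        self.o_proj = o_proj
        self.support_nvfp4_output = support_nvfp4_output
        self.enable_attn_nvfp4_output = True
        self.create_weights()

    def create_weights(self):
        # Compute the output scales once, instead of on every forward call.
        self.out_scale = None
        self.out_scale_sf = None
        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
                                or self.o_proj.has_fp8_block_scales
                                or self.o_proj.has_fp8_rowwise)
        if self.has_quant_scale:
            self.out_scale = self.o_proj.inv_input_scale
        if (self.o_proj.has_nvfp4 and self.support_nvfp4_output
                and self.enable_attn_nvfp4_output):
            self.out_scale_sf = self.o_proj.input_scale

    def forward(self):
        # The hot path just reuses the cached values.
        return self.out_scale, self.out_scale_sf


print(AttentionSketch(FakeOProj()).forward())  # (0.5, 2.0)
```
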
@@ -840,6 +839,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)
 
+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
         # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
         # which can be modified after __init__
         has_fp8_block_scales = (
@@ -1045,17 +1047,14 @@ def forward_context_default(
                          self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1110,9 +1109,6 @@ def forward_context_with_cached_kv(
             full_kv = None
             full_k_nope = None
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1122,7 +1118,7 @@ def forward_context_with_cached_kv(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1212,7 +1208,6 @@ def forward_context_with_chunked_prefill(
                     loop_idx]
                 attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens
 
-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1223,7 +1218,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 output=temp_attn_output,
@@ -1262,9 +1257,6 @@ def forward_context_with_chunked_prefill(
                 num_contexts].sum().item(
                 )
 
-            # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-            out_scale = None  # Currently we use BF16 MHA for context phase
-
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
             temp_attn_output = self.mha.forward(
@@ -1274,7 +1266,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 output=temp_attn_output,
             )
@@ -1370,16 +1362,13 @@ def forward_generation(
             self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
         ])
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
             None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
         )
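
In the MLA hunks, the same idea reduces to one shared attribute: `self.out_scale` is set to `None` once in `create_weights()` (the FP8 MLA kernels still produce BF16 output), and every context/generation variant passes it through, so a future output-dtype change touches one place. A hedged sketch of that structure, with illustrative names that do not match the real `MLA` module:

```python
class MLASketch:
    """Illustrative stand-in for the MLA attention module."""

    def create_weights(self):
        # Although FP8 MLA kernels run the context/generation phases,
        # the attention output is still BF16, so no output scale is needed.
        self.out_scale = None

    def _run_attention(self, phase, out_scale):
        # Placeholder for the mha.forward / mqa.forward calls in the real code.
        return f"{phase}: out_scale={out_scale}"

    def forward_context(self):
        return self._run_attention("context", self.out_scale)

    def forward_generation(self):
        return self._run_attention("generation", self.out_scale)


mla = MLASketch()
mla.create_weights()
print(mla.forward_context(), mla.forward_generation())
```
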