@@ -32,6 +32,7 @@ class Llama2Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Qwen25_3BConfig:
@@ -53,6 +54,7 @@ class Qwen25_3BConfig:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Qwen25_7BVLI_Config:
@@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Gemma2_2B_Config:
@@ -96,6 +99,7 @@ class Gemma2_2B_Config:
     k_norm = None
     sliding_attention = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Gemma3_4B_Config:
@@ -118,6 +122,7 @@ class Gemma3_4B_Config:
     k_norm = "gemma3"
     sliding_attention = [False, False, False, False, False, 1024]
     rope_scale = [1.0, 8.0]
+    final_norm: bool = True
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -366,7 +371,12 @@ def __init__(self, config, device=None, dtype=None, ops=None):
             transformer(config, index=i, device=device, dtype=dtype, ops=ops)
             for i in range(config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+        if config.final_norm:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        else:
+            self.norm = None
+
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
 
     def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@@ -421,14 +431,16 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
             if i == intermediate_output:
                 intermediate = x.clone()
 
-        x = self.norm(x)
+        if self.norm is not None:
+            x = self.norm(x)
+
         if all_intermediate is not None:
             all_intermediate.append(x.unsqueeze(1).clone())
 
         if all_intermediate is not None:
             intermediate = torch.cat(all_intermediate, dim=1)
 
-        if intermediate is not None and final_layer_norm_intermediate:
+        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
             intermediate = self.norm(intermediate)
 
         return x, intermediate
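Note: as a usage illustration, below is a minimal, self-contained sketch of the pattern this diff introduces: a `final_norm` config flag that defaults to True (so existing configs keep the final RMSNorm), leaves `self.norm` as None when False, and guards every use of the norm in the forward pass. Everything other than the `final_norm` flag and the `self.norm is None` guard is illustrative: `TinyConfig` and `TinyModel` are hypothetical stand-ins, and `torch.nn.RMSNorm` (PyTorch 2.4+) stands in for this file's own `RMSNorm` class.

# Sketch only: TinyConfig/TinyModel are hypothetical stand-ins for the real
# config dataclasses and model class touched by this diff.
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class TinyConfig:
    hidden_size: int = 64
    num_hidden_layers: int = 2
    rms_norm_eps: float = 1e-5
    final_norm: bool = True  # new flag: False skips the final normalization


class TinyModel(nn.Module):
    def __init__(self, config: TinyConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(config.hidden_size, config.hidden_size)
            for _ in range(config.num_hidden_layers)
        )
        # The final norm is optional; None means "no final normalization".
        if config.final_norm:
            self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = None

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        if self.norm is not None:  # same guard as in the diff above
            x = self.norm(x)
        return x


x = torch.randn(1, 8, 64)
print(TinyModel(TinyConfig())(x).shape)                  # final RMSNorm applied
print(TinyModel(TinyConfig(final_norm=False))(x).shape)  # final RMSNorm skipped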