Commit 17027f2

Add a way to disable the final norm in the llama based TE models. (#10794)
1 parent b5c8be8 commit 17027f2

File tree

1 file changed: +15, -3 lines

comfy/text_encoders/llama.py

Lines changed: 15 additions & 3 deletions
@@ -32,6 +32,7 @@ class Llama2Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Qwen25_3BConfig:
@@ -53,6 +54,7 @@ class Qwen25_3BConfig:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Qwen25_7BVLI_Config:
@@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Gemma2_2B_Config:
@@ -96,6 +99,7 @@ class Gemma2_2B_Config:
     k_norm = None
     sliding_attention = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Gemma3_4B_Config:
@@ -118,6 +122,7 @@ class Gemma3_4B_Config:
     k_norm = "gemma3"
     sliding_attention = [False, False, False, False, False, 1024]
     rope_scale = [1.0, 8.0]
+    final_norm: bool = True
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -366,7 +371,12 @@ def __init__(self, config, device=None, dtype=None, ops=None):
             transformer(config, index=i, device=device, dtype=dtype, ops=ops)
             for i in range(config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+        if config.final_norm:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        else:
+            self.norm = None
+
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
 
     def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@@ -421,14 +431,16 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
             if i == intermediate_output:
                 intermediate = x.clone()
 
-        x = self.norm(x)
+        if self.norm is not None:
+            x = self.norm(x)
+
         if all_intermediate is not None:
             all_intermediate.append(x.unsqueeze(1).clone())
 
         if all_intermediate is not None:
             intermediate = torch.cat(all_intermediate, dim=1)
 
-        if intermediate is not None and final_layer_norm_intermediate:
+        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
             intermediate = self.norm(intermediate)
 
         return x, intermediate
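
The new final_norm flag defaults to True in every config touched here, so existing text encoders keep applying the final RMSNorm exactly as before; only a config that explicitly sets it to False builds the model with self.norm = None and returns the raw last-layer hidden states. A minimal sketch of how a downstream config could opt out, following the dataclass pattern already used in this file (the subclass name is hypothetical, not part of this commit):

from dataclasses import dataclass

from comfy.text_encoders.llama import Llama2Config


@dataclass
class Llama2NoFinalNormConfig(Llama2Config):
    # Hypothetical variant for illustration: with final_norm=False the
    # constructor above sets self.norm = None, and forward() skips the
    # final RMSNorm on both the output and any intermediate states.
    final_norm: bool = False

Because the default stays True, existing model definitions pick up this commit without any changes.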
