Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = 2304
dit_config["n_layers"] = 26
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True
Expand Down
7 changes: 7 additions & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,7 @@ class TEModel(Enum):
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13

def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
Expand All @@ -912,6 +913,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
Expand Down Expand Up @@ -1016,6 +1019,10 @@ class EmptyClass:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
Expand Down
133 changes: 111 additions & 22 deletions comfy/text_encoders/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Optional, Any
import math
import logging

from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
Expand All @@ -28,6 +29,9 @@ class Llama2Config:
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None

@dataclass
class Qwen25_3BConfig:
Expand All @@ -46,6 +50,9 @@ class Qwen25_3BConfig:
mlp_activation = "silu"
qkv_bias = True
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None

@dataclass
class Qwen25_7BVLI_Config:
Expand All @@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
mlp_activation = "silu"
qkv_bias = True
rope_dims = [16, 24, 24]
q_norm = None
k_norm = None
rope_scale = None

@dataclass
class Gemma2_2B_Config:
Expand All @@ -82,6 +92,32 @@ class Gemma2_2B_Config:
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
sliding_attention = None
rope_scale = None

@dataclass
class Gemma3_4B_Config:
vocab_size: int = 262208
hidden_size: int = 2560
intermediate_size: int = 10240
num_hidden_layers: int = 34
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]

class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
Expand All @@ -106,25 +142,40 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]

out = []
for index, t in enumerate(theta):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (t ** (theta_numerator / head_dim))

if rope_scale is not None:
if isinstance(rope_scale, list):
inv_freq /= rope_scale[index]
else:
inv_freq /= rope_scale

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
if len(out) == 1:
return out[0]

return (cos, sin)
return out


def apply_rope(xq, xk, freqs_cis):
Expand Down Expand Up @@ -152,6 +203,14 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = Non
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

self.q_norm = None
self.k_norm = None

if config.q_norm == "gemma3":
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -168,6 +227,11 @@ def forward(
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

if self.q_norm is not None:
xq = self.q_norm(xq)
if self.k_norm is not None:
xk = self.k_norm(xk)

xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
Expand All @@ -192,7 +256,7 @@ def forward(self, x):
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))

class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
Expand Down Expand Up @@ -226,7 +290,7 @@ def forward(
return x

class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
Expand All @@ -235,13 +299,28 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = Non
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False

self.transformer_type = config.transformer_type

def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]

# Self Attention
residual = x
x = self.input_layernorm(x)
Expand Down Expand Up @@ -276,16 +355,16 @@ def __init__(self, config, device=None, dtype=None, ops=None):
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2":
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
else:
transformer = TransformerBlock
self.normalize_in = False

self.layers = nn.ModuleList([
transformer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
Expand All @@ -305,6 +384,7 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)

Expand Down Expand Up @@ -433,3 +513,12 @@ def __init__(self, config_dict, dtype, device, operations):

self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Gemma3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
self.num_layers = config.num_hidden_layers

self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
26 changes: 22 additions & 4 deletions comfy/text_encoders/lumina2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,47 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)

def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)

class NTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)

class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)


def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel

def te(dtype_llama=None, llama_scaled_fp8=None):
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
return LuminaTEModel_