comfy/sd1_clip.py (7 changes: 4 additions & 3 deletions)

@@ -90,7 +90,6 @@ def __init__(self, device="cpu", max_length=77,
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                  return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
         super().__init__()
-        assert layer in self.LAYERS
 
         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@@ -164,7 +163,7 @@ def freeze(self):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
-        if self.layer == "all":
+        if isinstance(self.layer, list) or self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
             self.layer = "last"
@@ -266,7 +265,9 @@ def forward(self, tokens):
         if self.enable_attention_masks:
             attention_mask_model = attention_mask
 
-        if self.layer == "all":
+        if isinstance(self.layer, list):
+            intermediate_output = self.layer
+        elif self.layer == "all":
             intermediate_output = "all"
         else:
             intermediate_output = self.layer_idx
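Net effect of this file's change: self.layer may now be a plain Python list of layer indices in addition to the existing "all" and single-layer values, and a list is forwarded unchanged to the text model as intermediate_output. Below is a minimal, standalone sketch of that three-way dispatch (the helper name resolve_intermediate_output is hypothetical, not part of the patch):

def resolve_intermediate_output(layer, layer_idx):
    # A list of indices means "return exactly these intermediate hidden states".
    if isinstance(layer, list):
        return layer
    # "all" keeps the old behaviour of collecting every layer.
    elif layer == "all":
        return "all"
    # Any other value falls back to the single configured layer index.
    else:
        return layer_idx

print(resolve_intermediate_output([10, 20, 30], None))  # [10, 20, 30]
print(resolve_intermediate_output("all", None))         # all
print(resolve_intermediate_output("hidden", -2))        # -2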
comfy/text_encoders/flux.py (4 changes: 2 additions & 2 deletions)

@@ -138,7 +138,7 @@ def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None
         return tokens
 
 class Mistral3_24BModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+    def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
         textmodel_json_config = {}
         num_layers = model_options.get("num_layers", None)
         if num_layers is not None:
@@ -154,7 +154,7 @@ def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24
     def encode_token_weights(self, token_weight_pairs):
         out, pooled, extra = super().encode_token_weights(token_weight_pairs)
 
-        out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1)
+        out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
         out = out.movedim(1, 2)
         out = out.reshape(out.shape[0], out.shape[1], -1)
         return out, pooled, extra
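With layer=[10, 20, 30], the underlying model now returns only those three hidden states, stacked along dim 1, so encode_token_weights indexes positions 0, 1 and 2 rather than the original layer numbers before flattening them into one feature dimension. A shape-only sketch with dummy values (the batch, token and hidden sizes here are illustrative, not taken from the model config):

import torch

batch, tokens, hidden = 2, 77, 5120
out = torch.randn(batch, 3, tokens, hidden)                  # only the 3 requested layers are present

out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)  # [batch, 3, tokens, hidden]
out = out.movedim(1, 2)                                      # [batch, tokens, 3, hidden]
out = out.reshape(out.shape[0], out.shape[1], -1)            # [batch, tokens, 3 * hidden]
print(out.shape)                                             # torch.Size([2, 77, 15360])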
comfy/text_encoders/llama.py (12 changes: 9 additions & 3 deletions)

@@ -434,16 +434,21 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
 
         intermediate = None
         all_intermediate = None
+        only_layers = None
         if intermediate_output is not None:
-            if intermediate_output == "all":
+            if isinstance(intermediate_output, list):
+                all_intermediate = []
+                only_layers = set(intermediate_output)
+            elif intermediate_output == "all":
                 all_intermediate = []
                 intermediate_output = None
             elif intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output
 
         for i, layer in enumerate(self.layers):
             if all_intermediate is not None:
-                all_intermediate.append(x.unsqueeze(1).clone())
+                if only_layers is None or (i in only_layers):
+                    all_intermediate.append(x.unsqueeze(1).clone())
             x = layer(
                 x=x,
                 attention_mask=mask,
@@ -457,7 +462,8 @@
         x = self.norm(x)
 
         if all_intermediate is not None:
-            all_intermediate.append(x.unsqueeze(1).clone())
+            if only_layers is None or ((i + 1) in only_layers):
+                all_intermediate.append(x.unsqueeze(1).clone())
 
         if all_intermediate is not None:
             intermediate = torch.cat(all_intermediate, dim=1)
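only_layers restricts which hidden states are collected when intermediate_output is a list: index i is the hidden state after i layers have run (index 0 is the embeddings) and index len(self.layers) is the final normed output. A toy sketch of the same filtering pattern, using arithmetic stand-ins for the transformer blocks:

layers = [lambda v: v + 1 for _ in range(4)]  # stand-ins for transformer layers
only_layers = {1, 3, 4}                       # hypothetical requested indices
x = 0
collected = []
for i, layer in enumerate(layers):
    # Index i is the state after i layers; 0 would be the raw input/embeddings.
    if i in only_layers:
        collected.append(x)
    x = layer(x)
# Index len(layers) corresponds to the final output appended after the loop.
if len(layers) in only_layers:
    collected.append(x)
print(collected)  # [1, 3, 4]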