diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 8945408793a7..77e642a94476 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -629,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat): class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 + +class ChromaRadiance(LatentFormat): + latent_channels = 3 + + def __init__(self): + self.latent_rgb_factors = [ + # R G B + [ 1.0, 0.0, 0.0 ], + [ 0.0, 1.0, 0.0 ], + [ 0.0, 0.0, 1.0 ] + ] + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 4f709f87d7c9..ad1c523fe1d3 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -151,8 +151,6 @@ def forward_orig( attn_mask: Tensor = None, ) -> Tensor: patches_replace = transformer_options.get("patches_replace", {}) - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) @@ -254,8 +252,9 @@ def block_wrap(args): img[:, txt.shape[1] :, ...] += add img = img[:, txt.shape[1] :, ...] - final_mod = self.get_modulations(mod_vectors, "final") - img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + if hasattr(self, "final_layer"): + final_mod = self.get_modulations(mod_vectors, "final") + img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): @@ -271,6 +270,9 @@ def _forward(self, x, timestep, context, guidance, control=None, transformer_opt img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) + if img.ndim != 3 or context.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + h_len = ((h + (self.patch_size // 2)) // self.patch_size) w_len = ((w + (self.patch_size // 2)) // self.patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py new file mode 100644 index 000000000000..3c7bc9b6b985 --- /dev/null +++ b/comfy/ldm/chroma_radiance/layers.py @@ -0,0 +1,206 @@ +# Adapted from https://github.com/lodestone-rock/flow +from functools import lru_cache + +import torch +from torch import nn + +from comfy.ldm.flux.layers import RMSNorm + + +class NerfEmbedder(nn.Module): + """ + An embedder module that combines input features with a 2D positional + encoding that mimics the Discrete Cosine Transform (DCT). + + This module takes an input tensor of shape (B, P^2, C), where P is the + patch size, and enriches it with positional information before projecting + it to a new hidden size. + """ + def __init__( + self, + in_channels: int, + hidden_size_input: int, + max_freqs: int, + dtype=None, + device=None, + operations=None, + ): + """ + Initializes the NerfEmbedder. + + Args: + in_channels (int): The number of channels in the input tensor. + hidden_size_input (int): The desired dimension of the output embedding. + max_freqs (int): The number of frequency components to use for both + the x and y dimensions of the positional encoding. + The total number of positional features will be max_freqs^2. + """ + super().__init__() + self.dtype = dtype + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + + # A linear layer to project the concatenated input features and + # positional encodings to the final output dimension. + self.embedder = nn.Sequential( + operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """ + Generates and caches 2D DCT-like positional embeddings for a given patch size. + + The LRU cache is a performance optimization that avoids recomputing the + same positional grid on every forward pass. + + Args: + patch_size (int): The side length of the square input patch. + device: The torch device to create the tensors on. + dtype: The torch dtype for the tensors. + + Returns: + A tensor of shape (1, patch_size^2, max_freqs^2) containing the + positional embeddings. + """ + # Create normalized 1D coordinate grids from 0 to 1. + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + + # Create a 2D meshgrid of coordinates. + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + # Reshape positions to be broadcastable with frequencies. + # Shape becomes (patch_size^2, 1, 1). + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + # Create a 1D tensor of frequency values from 0 to max_freqs-1. + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + + # Reshape frequencies to be broadcastable for creating 2D basis functions. + # freqs_x shape: (1, max_freqs, 1) + # freqs_y shape: (1, 1, max_freqs) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + # A custom weighting coefficient, not part of standard DCT. + # This seems to down-weight the contribution of higher-frequency interactions. + coeffs = (1 + freqs_x * freqs_y) ** -1 + + # Calculate the 1D cosine basis functions for x and y coordinates. + # This is the core of the DCT formulation. + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + + # Combine the 1D basis functions to create 2D basis functions by element-wise + # multiplication, and apply the custom coefficients. Broadcasting handles the + # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y). + # The result is flattened into a feature vector for each position. + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + + return dct + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the embedder. + + Args: + inputs (Tensor): The input tensor of shape (B, P^2, C). + + Returns: + Tensor: The output tensor of shape (B, P^2, hidden_size_input). + """ + # Get the batch size, number of pixels, and number of channels. + B, P2, C = inputs.shape + + # Infer the patch side length from the number of pixels (P^2). + patch_size = int(P2 ** 0.5) + + input_dtype = inputs.dtype + inputs = inputs.to(dtype=self.dtype) + + # Fetch the pre-computed or cached positional embeddings. + dct = self.fetch_pos(patch_size, inputs.device, self.dtype) + + # Repeat the positional embeddings for each item in the batch. + dct = dct.repeat(B, 1, 1) + + # Concatenate the original input features with the positional embeddings + # along the feature dimension. + inputs = torch.cat((inputs, dct), dim=-1) + + # Project the combined tensor to the target hidden size. + return self.embedder(inputs).to(dtype=input_dtype) + + +class NerfGLUBlock(nn.Module): + """ + A NerfBlock using a Gated Linear Unit (GLU) like MLP. + """ + def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None): + super().__init__() + # The total number of parameters for the MLP is increased to accommodate + # the gate, value, and output projection matrices. + # We now need to generate parameters for 3 matrices. + total_params = 3 * hidden_size_x**2 * mlp_ratio + self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) + self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) + self.mlp_ratio = mlp_ratio + + + def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + batch_size, num_x, hidden_size_x = x.shape + mlp_params = self.param_generator(s) + + # Split the generated parameters into three parts for the gate, value, and output projection. + fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1) + + # Reshape the parameters into matrices for batch matrix multiplication. + fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x) + + # Normalize the generated weight matrices as in the original implementation. + fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2) + fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2) + fc2 = torch.nn.functional.normalize(fc2, dim=-2) + + res_x = x + x = self.norm(x) + + # Apply the final output projection. + x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) + + return x + res_x + + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So we temporarily move the channel dimension to the end for the norm operation. + return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1) + + +class NerfFinalLayerConv(nn.Module): + def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.conv = operations.Conv2d( + in_channels=hidden_size, + out_channels=out_channels, + kernel_size=3, + padding=1, + dtype=dtype, + device=device, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So we temporarily move the channel dimension to the end for the norm operation. + return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1)) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py new file mode 100644 index 000000000000..f7eb7a22eaa8 --- /dev/null +++ b/comfy/ldm/chroma_radiance/model.py @@ -0,0 +1,328 @@ +# Credits: +# Original Flux code can be found on: https://github.com/black-forest-labs/flux +# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor, nn +from einops import repeat +import comfy.ldm.common_dit + +from comfy.ldm.flux.layers import EmbedND + +from comfy.ldm.chroma.model import Chroma, ChromaParams +from comfy.ldm.chroma.layers import ( + DoubleStreamBlock, + SingleStreamBlock, + Approximator, +) +from .layers import ( + NerfEmbedder, + NerfGLUBlock, + NerfFinalLayer, + NerfFinalLayerConv, +) + + +@dataclass +class ChromaRadianceParams(ChromaParams): + patch_size: int + nerf_hidden_size: int + nerf_mlp_ratio: int + nerf_depth: int + nerf_max_freqs: int + # Setting nerf_tile_size to 0 disables tiling. + nerf_tile_size: int + # Currently one of linear (legacy) or conv. + nerf_final_head_type: str + # None means use the same dtype as the model. + nerf_embedder_dtype: Optional[torch.dtype] + + +class ChromaRadiance(Chroma): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): + if operations is None: + raise RuntimeError("Attempt to create ChromaRadiance object without setting operations") + nn.Module.__init__(self) + self.dtype = dtype + params = ChromaRadianceParams(**kwargs) + self.params = params + self.patch_size = params.patch_size + self.in_channels = params.in_channels + self.out_channels = params.out_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.in_dim = params.in_dim + self.out_dim = params.out_dim + self.hidden_dim = params.hidden_dim + self.n_layers = params.n_layers + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in_patch = operations.Conv2d( + params.in_channels, + params.hidden_size, + kernel_size=params.patch_size, + stride=params.patch_size, + bias=True, + dtype=dtype, + device=device, + ) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + # set as nn identity for now, will overwrite it later. + self.distilled_guidance_layer = Approximator( + in_dim=self.in_dim, + hidden_dim=self.hidden_dim, + out_dim=self.out_dim, + n_layers=self.n_layers, + dtype=dtype, device=device, operations=operations + ) + + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + dtype=dtype, device=device, operations=operations, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + # pixel channel concat with DCT + self.nerf_image_embedder = NerfEmbedder( + in_channels=params.in_channels, + hidden_size_input=params.nerf_hidden_size, + max_freqs=params.nerf_max_freqs, + dtype=params.nerf_embedder_dtype or dtype, + device=device, + operations=operations, + ) + + self.nerf_blocks = nn.ModuleList([ + NerfGLUBlock( + hidden_size_s=params.hidden_size, + hidden_size_x=params.nerf_hidden_size, + mlp_ratio=params.nerf_mlp_ratio, + dtype=dtype, + device=device, + operations=operations, + ) for _ in range(params.nerf_depth) + ]) + + if params.nerf_final_head_type == "linear": + self.nerf_final_layer = NerfFinalLayer( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + elif params.nerf_final_head_type == "conv": + self.nerf_final_layer_conv = NerfFinalLayerConv( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + else: + errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}" + raise ValueError(errstr) + + self.skip_mmdit = [] + self.skip_dit = [] + self.lite = False + + @property + def _nerf_final_layer(self) -> nn.Module: + if self.params.nerf_final_head_type == "linear": + return self.nerf_final_layer + if self.params.nerf_final_head_type == "conv": + return self.nerf_final_layer_conv + # Impossible to get here as we raise an error on unexpected types on initialization. + raise NotImplementedError + + def img_in(self, img: Tensor) -> Tensor: + img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P] + # flatten into a sequence for the transformer. + return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] + + def forward_nerf( + self, + img_orig: Tensor, + img_out: Tensor, + params: ChromaRadianceParams, + ) -> Tensor: + B, C, H, W = img_orig.shape + num_patches = img_out.shape[1] + patch_size = params.patch_size + + # Store the raw pixel values of each patch for the NeRF head later. + # unfold creates patches: [B, C * P * P, NumPatches] + nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) + nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + + if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size: + # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than + # the tile size. + img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params) + else: + # Reshape for per-patch processing + nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) + + # Get DCT-encoded pixel embeddings [pixel-dct] + img_dct = self.nerf_image_embedder(nerf_pixels) + + # Pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct = block(img_dct, nerf_hidden) + + # Reassemble the patches into the final image. + img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] + # Reshape to combine with batch dimension for fold + img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] + img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] + img_dct = nn.functional.fold( + img_dct, + output_size=(H, W), + kernel_size=patch_size, + stride=patch_size, + ) + return self._nerf_final_layer(img_dct) + + def forward_tiled_nerf( + self, + nerf_hidden: Tensor, + nerf_pixels: Tensor, + batch: int, + channels: int, + num_patches: int, + patch_size: int, + params: ChromaRadianceParams, + ) -> Tensor: + """ + Processes the NeRF head in tiles to save memory. + nerf_hidden has shape [B, L, D] + nerf_pixels has shape [B, L, C * P * P] + """ + tile_size = params.nerf_tile_size + output_tiles = [] + # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1. + for i in range(0, num_patches, tile_size): + end = min(i + tile_size, num_patches) + + # Slice the current tile from the input tensors + nerf_hidden_tile = nerf_hidden[:, i:end, :] + nerf_pixels_tile = nerf_pixels[:, i:end, :] + + # Get the actual number of patches in this tile (can be smaller for the last tile) + num_patches_tile = nerf_hidden_tile.shape[1] + + # Reshape the tile for per-patch processing + # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] + nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size) + # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] + nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2) + + # get DCT-encoded pixel embeddings [pixel-dct] + img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) + + # pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct_tile = block(img_dct_tile, nerf_hidden_tile) + + output_tiles.append(img_dct_tile) + + # Concatenate the processed tiles along the patch dimension + return torch.cat(output_tiles, dim=0) + + def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams: + params = self.params + if not overrides: + return params + params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__} + nullable_keys = frozenset(("nerf_embedder_dtype",)) + bad_keys = tuple(k for k in overrides if k not in params_dict) + if bad_keys: + e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + bad_keys = tuple( + k + for k, v in overrides.items() + if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + ) + if bad_keys: + e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + # At this point it's all valid keys and values so we can merge with the existing params. + params_dict |= overrides + return params.__class__(**params_dict) + + def _forward( + self, + x: Tensor, + timestep: Tensor, + context: Tensor, + guidance: Optional[Tensor], + control: Optional[dict]=None, + transformer_options: dict={}, + **kwargs: dict, + ) -> Tensor: + bs, c, h, w = x.shape + img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + if img.ndim != 4: + raise ValueError("Input img tensor must be in [B, C, H, W] format.") + if context.ndim != 3: + raise ValueError("Input txt tensors must have 3 dimensions.") + + params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {})) + + h_len = ((h + (self.patch_size // 2)) // self.patch_size) + w_len = ((w + (self.patch_size // 2)) // self.patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + + img_out = self.forward_orig( + img, + img_ids, + context, + txt_ids, + timestep, + guidance, + control, + transformer_options, + attn_mask=kwargs.get("attention_mask", None), + ) + return self.forward_nerf(img, img_out, params) diff --git a/comfy/model_base.py b/comfy/model_base.py index 324d89cffe82..252dfcf69d6d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model +import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model @@ -1320,8 +1321,8 @@ def extra_conds(self, **kwargs): return out class Chroma(Flux): - def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) + def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma): + super().__init__(model_config, model_type, device=device, unet_model=unet_model) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1331,6 +1332,10 @@ def extra_conds(self, **kwargs): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class ChromaRadiance(Chroma): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance) + class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index fe983cede3f9..03d44f65e2a9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -204,6 +204,18 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 + if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance + dit_config["image_model"] = "chroma_radiance" + dit_config["in_channels"] = 3 + dit_config["out_channels"] = 3 + dit_config["patch_size"] = 16 + dit_config["nerf_hidden_size"] = 64 + dit_config["nerf_mlp_ratio"] = 4 + dit_config["nerf_depth"] = 4 + dit_config["nerf_max_freqs"] = 8 + dit_config["nerf_tile_size"] = 32 + dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" + dit_config["nerf_embedder_dtype"] = torch.float32 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config diff --git a/comfy/sd.py b/comfy/sd.py index f8f1a89e8fff..cb92802e9168 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -785,6 +785,66 @@ def temporal_compression_decode(self): except: return None +# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1 +# to LATENT B, C, H, W and values on the scale of -1..1. +class PixelspaceConversionVAE: + def __init__(self, size_increment: int=16): + self.intermediate_device = comfy.model_management.intermediate_device() + self.size_increment = size_increment + + def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor: + if self.size_increment == 1: + return pixels + dims = pixels.shape[1:-1] + for d in range(len(dims)): + d_adj = (dims[d] // self.size_increment) * self.size_increment + if d_adj == d: + continue + d_offset = (dims[d] % self.size_increment) // 2 + pixels = pixels.narrow(d + 1, d_offset, d_adj) + return pixels + + def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + if pixels.ndim == 3: + pixels = pixels.unsqueeze(0) + elif pixels.ndim != 4: + raise ValueError("Unexpected input image shape") + # Ensure the image has spatial dimensions that are multiples of 16. + pixels = self.vae_encode_crop_pixels(pixels) + h, w, c = pixels.shape[1:] + if h < self.size_increment or w < self.size_increment: + raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).") + pixels= pixels[..., :3] + if c == 1: + pixels = pixels.expand(-1, -1, -1, 3) + elif c != 3: + raise ValueError("Unexpected number of channels in input image") + # Rescale to -1..1 and move the channel dimension to position 1. + latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True) + latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous() + latent -= 0.5 + latent *= 2 + return latent.clamp_(-1, 1) + + def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + # Rescale to 0..1 and move the channel dimension to the end. + img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True) + img = img.clamp_(-1, 1).movedim(1, -1).contiguous() + img += 1.0 + img *= 0.5 + return img.clamp_(0, 1) + + encode_tiled = encode + decode_tiled = decode + + @classmethod + def spacial_compression_decode(cls) -> int: + # This just exists so the tiled VAE nodes don't crash. + return 1 + + spacial_compression_encode = spacial_compression_decode + temporal_compression_decode = spacial_compression_decode + class StyleModel: def __init__(self, model, device="cpu"): self.model = model diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 472ea0ae95ef..be36b5dfe8ef 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1205,6 +1205,19 @@ def clip_target(self, state_dict={}): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) +class ChromaRadiance(Chroma): + unet_config = { + "image_model": "chroma_radiance", + } + + latent_format = comfy.latent_formats.ChromaRadiance + + # Pixel-space model, no spatial compression for model input. + memory_usage_factor = 0.0325 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ChromaRadiance(self, device=device) + class ACEStep(supported_models_base.BASE): unet_config = { "audio_model": "ace", @@ -1338,6 +1351,6 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py new file mode 100644 index 000000000000..381989818b53 --- /dev/null +++ b/comfy_extras/nodes_chroma_radiance.py @@ -0,0 +1,114 @@ +from typing_extensions import override +from typing import Callable + +import torch + +import comfy.model_management +from comfy_api.latest import ComfyExtension, io + +import nodes + +class EmptyChromaRadianceLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyChromaRadianceLatentImage", + category="latent/chroma_radiance", + inputs=[ + io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input(id="batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent().Output()], + ) + + @classmethod + def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput: + latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) + + +class ChromaRadianceOptions(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ChromaRadianceOptions", + category="model_patches/chroma_radiance", + description="Allows setting advanced options for the Chroma Radiance model.", + inputs=[ + io.Model.Input(id="model"), + io.Boolean.Input( + id="preserve_wrapper", + default=True, + tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.", + ), + io.Float.Input( + id="start_sigma", + default=1.0, + min=0.0, + max=1.0, + tooltip="First sigma that these options will be in effect.", + ), + io.Float.Input( + id="end_sigma", + default=0.0, + min=0.0, + max=1.0, + tooltip="Last sigma that these options will be in effect.", + ), + io.Int.Input( + id="nerf_tile_size", + default=-1, + min=-1, + tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute( + cls, + *, + model: io.Model.Type, + preserve_wrapper: bool, + start_sigma: float, + end_sigma: float, + nerf_tile_size: int, + ) -> io.NodeOutput: + radiance_options = {} + if nerf_tile_size >= 0: + radiance_options["nerf_tile_size"] = nerf_tile_size + + if not radiance_options: + return io.NodeOutput(model) + + old_wrapper = model.model_options.get("model_function_wrapper") + + def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor: + c = args["c"].copy() + sigma = args["timestep"].max().detach().cpu().item() + if end_sigma <= sigma <= start_sigma: + transformer_options = c.get("transformer_options", {}).copy() + transformer_options["chroma_radiance_options"] = radiance_options.copy() + c["transformer_options"] = transformer_options + if not (preserve_wrapper and old_wrapper): + return apply_model(args["input"], args["timestep"], **c) + return old_wrapper(apply_model, args | {"c": c}) + + model = model.clone() + model.set_model_unet_function_wrapper(model_function_wrapper) + return io.NodeOutput(model) + + +class ChromaRadianceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyChromaRadianceLatentImage, + ChromaRadianceOptions, + ] + + +async def comfy_entrypoint() -> ChromaRadianceExtension: + return ChromaRadianceExtension() diff --git a/nodes.py b/nodes.py index 2befb4b754a0..76b8cbac801f 100644 --- a/nodes.py +++ b/nodes.py @@ -730,6 +730,7 @@ def vae_list(): vaes.append("taesd3") if f1_taesd_dec and f1_taesd_enc: vaes.append("taef1") + vaes.append("chroma_radiance") return vaes @staticmethod @@ -772,7 +773,9 @@ def INPUT_TYPES(s): #TODO: scale factor? def load_vae(self, vae_name): - if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: + if vae_name == "chroma_radiance": + return (comfy.sd.PixelspaceConversionVAE(),) + elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: sd = self.load_taesd(vae_name) else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) @@ -2322,6 +2325,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_chroma_radiance.py", "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py",