Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 1 addition & 26 deletions modules/ModularDiffusers/controlnet.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import importlib
import logging

from diffusers.modular_pipelines import SequentialPipelineBlocks

from mellon.NodeBase import NodeBase

from . import components
from .utils import combine_multi_inputs
from .modular_utils import pipeline_class_to_mellon_node_config


logger = logging.getLogger("mellon")
Expand Down Expand Up @@ -227,30 +226,6 @@ def execute(
return {"controlnet": controlnet}


def pipeline_class_to_mellon_node_config(pipeline_class, node_type=None):
print(f" inside pipeline_class_to_mellon_node_config: {pipeline_class}")

try:
from diffusers.modular_pipelines.mellon_node_utils import ModularMellonNodeRegistry

registry = ModularMellonNodeRegistry()
node_type_config = registry.get(pipeline_class)[node_type]
except Exception as e:
logger.debug(f" Failed to load the node from {pipeline_class}: {e}")
return None, None

node_type_blocks = None
pipeline = pipeline_class()

if pipeline is not None and node_type_config is not None and node_type_config.blocks_names:
blocks_dict = {
name: block for name, block in pipeline.blocks.sub_blocks.items() if name in node_type_config.blocks_names
}
node_type_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)

return node_type_blocks, node_type_config


class DynamicControlnet(NodeBase):
label = "Dynamic ControlNet"
category = "adapters"
Expand Down
169 changes: 79 additions & 90 deletions modules/ModularDiffusers/denoise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from typing import Any, List, Tuple
import importlib
from .modular_utils import pipeline_class_to_mellon_node_config, get_model_type_signal_data

import torch
from diffusers import ComponentsManager
Expand Down Expand Up @@ -49,117 +51,103 @@ def insert_preview_block_recursive(blocks, blocks_name, preview_block):

insert_preview_block_recursive(pipeline.blocks.sub_blocks["denoise"], "", preview_block)

# SIGNAL_DATA = get_model_type_signal_data()

class Denoise(NodeBase):
label = "Denoise"
category = "sampler"
resizable = True
skipParamsCheck = True
node_type = "denoise"
params = {
"model_type": {
"label": "Model Type",
"type": "string",
"default": "",
"hidden": True # Hidden field to receive signal data
},
"unet": {
"label": "Denoise Model",
"display": "input",
"type": "diffusers_auto_model",
"onSignal": [
{"action": "signal", "target": "guider"},
{"action": "signal", "target": "controlnet"},
{
"StableDiffusionXLModularPipeline": [
"width",
"height",
"ip_adapter",
"controlnet",
"latents_preview",
],
"QwenImageModularPipeline": ["width", "height", "controlnet"],
"FluxModularPipeline": [
"width",
"height",
"controlnet",
"ip_adapter",
],
"": [],
},
{
"action": "value",
"target": "skip_image_size",
"target": "model_type",
# "data": SIGNAL_DATA, # YiYi Notes: not working
"data": {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not able to use a variable for data, here, I created a function to dynamically create this map but will run into an issue if I put the function here, we need to not have to hard-code this

get_model_type_signal_data()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is one of the petitions I made about the signals, we should be able to do it soon.

"QwenImageEditModularPipeline": True,
"QwenImageEditPlusModularPipeline": True,
"StableDiffusionXLModularPipeline": "StableDiffusionXLModularPipeline",
"QwenImageModularPipeline": "QwenImageModularPipeline",
"QwenImageEditModularPipeline": "QwenImageEditModularPipeline",
"QwenImageEditPlusModularPipeline": "QwenImageEditPlusModularPipeline",
"FluxModularPipeline": "FluxModularPipeline",
"FluxKontextModularPipeline": "FluxKontextModularPipeline",
},
},
],
},
"skip_image_size": {
"label": "Skip Image Size",
"type": "boolean",
"default": False,
"value": False,
"hidden": True,
},
"scheduler": {"label": "Scheduler", "display": "input", "type": "diffusers_auto_model"},
"embeddings": {"label": "Text Embeddings", "display": "input", "type": "embeddings"},
"latents": {"label": "Latents", "type": "latents", "display": "output"},
"width": {"label": "Width", "type": "int", "default": 1024, "min": 64, "step": 8},
"height": {"label": "Height", "type": "int", "default": 1024, "min": 64, "step": 8},
"seed": {"label": "Seed", "type": "int", "display": "random", "default": 0, "min": 0, "max": 4294967295},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"display": "slider",
"default": 25,
"min": 1,
"max": 100,
},
"guidance_scale": {
"label": "Guidance Scale",
"type": "float",
"display": "slider",
"default": 5,
"min": 1.0,
"max": 30.0,
"step": 0.1,
},
"latents_preview": {"label": "Latents Preview", "display": "output", "type": "latent"},
"guider": {
"label": "Guider",
"display": "input",
"type": "custom_guider",
"onChange": {False: ["guidance_scale"], True: []},
},
"image_latents": {
"label": "Image Latents",
"type": "latents",
"display": "input",
"onChange": {False: ["height", "width"], True: ["strength"]},
},
"strength": {
"label": "Strength",
"type": "float",
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01,
},
"controlnet": {
"label": "Controlnet",
"type": "custom_controlnet",
"display": "input",
},
"ip_adapter": {
"label": "IP Adapter",
"type": "custom_ip_adapter",
"display": "input",
},
"doc": {
"label": "Doc",
"display": "output",
"type": "string",
{"action": "exec", "data": "update_node"},
{"action": "signal", "target": "guider"},
{"action": "signal", "target": "controlnet"},
]
},
}

def update_node(self, values, ref):

node_params = {
"model_type": {
"label": "Model Type",
"type": "string",
"default": "",
"hidden": True # Hidden field to receive signal data
},
"unet": {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Denoise node must have a unet input to receive signals, the rest should all be dynamic

"label": "Denoise Model",
"display": "input",
"type": "diffusers_auto_model",
"onSignal": [
{
"action": "value",
"target": "model_type",
# "data": SIGNAL_DATA, # YiYi Notes: not working
"data": {
"StableDiffusionXLModularPipeline": "StableDiffusionXLModularPipeline",
"QwenImageModularPipeline": "QwenImageModularPipeline",
"QwenImageEditModularPipeline": "QwenImageEditModularPipeline",
"QwenImageEditPlusModularPipeline": "QwenImageEditPlusModularPipeline",
"FluxModularPipeline": "FluxModularPipeline",
"FluxKontextModularPipeline": "FluxKontextModularPipeline",
},
},
{"action": "exec", "data": "update_node"},
{"action": "signal", "target": "guider"},
{"action": "signal", "target": "controlnet"},
]
},
}
model_type = values.get("model_type", "")

if model_type == "" or self._model_type == model_type:
return None

self._model_type = model_type

diffusers_module = importlib.import_module("diffusers")
self._pipeline_class = getattr(diffusers_module, model_type)

_, denoise_mellon_config = pipeline_class_to_mellon_node_config(self._pipeline_class, self.node_type)
# not support this node type
if denoise_mellon_config is None:
self.send_node_definition(node_params)
return

# required params for controlnet
node_params.update(**denoise_mellon_config.to_mellon_dict()["params"])
self.send_node_definition(node_params)

def __init__(self, node_id=None):
super().__init__(node_id)
self._denoise_node = None
self._model_type = ""
self._pipeline_class = None

def execute(
self,
Expand All @@ -177,6 +165,7 @@ def execute(
width=None,
height=None,
skip_image_size=False,
**kwargs,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should refactor the execute method too, to take kwarggs, so that we will be parameters dynamically defined in param, similar to how we implemented DynamicCOntrolnet, but that will be next step and i will help

):
logger.debug(f" Denoise ({self.node_id}) received parameters:")
logger.debug(f" - unet: {unet}")
Expand Down
103 changes: 28 additions & 75 deletions modules/ModularDiffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,68 +9,13 @@
from utils.torch_utils import DEFAULT_DEVICE, DEVICE_LIST, str_to_dtype

from . import components
from .modular_utils import get_all_model_types, get_model_type_metadata


logger = logging.getLogger("mellon")
logger.setLevel(logging.DEBUG)


@dataclass(frozen=True)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved ModularMellonNodeRegistry from diffusers to mellon and so this registry gets absorbed into that https://github.com/huggingface/diffusers/blob/main/src/diffusers/modular_pipelines/mellon_node_utils.py#L703

class PipelineRegistration:
label: str
filters: tuple[str, ...]
default_repo: str


PIPELINE_REGISTRY: dict[str, PipelineRegistration] = {
"": PipelineRegistration(label="", filters=(), default_repo=""),
"StableDiffusionXLModularPipeline": PipelineRegistration(
label="Stable Diffusion XL",
filters=("StableDiffusionXLModularPipeline", "StableDiffusionXLPipeline"),
default_repo="stabilityai/stable-diffusion-xl-base-1.0",
),
"QwenImageModularPipeline": PipelineRegistration(
label="Qwen Image",
filters=("QwenImageModularPipeline", "QwenImagePipeline"),
default_repo="Qwen/Qwen-Image",
),
"QwenImageEditModularPipeline": PipelineRegistration(
label="Qwen Image Edit",
filters=("QwenImageEditModularPipeline", "QwenImageEditPipeline"),
default_repo="Qwen/Qwen-Image-Edit",
),
"QwenImageEditPlusModularPipeline": PipelineRegistration(
label="Qwen Image Edit Plus",
filters=("QwenImageEditPlusModularPipeline", "QwenImageEditPlusPipeline"),
default_repo="Qwen/Qwen-Image-Edit-2509",
),
"FluxModularPipeline": PipelineRegistration(
label="Flux",
filters=("FluxModularPipeline", "FluxPipeline"),
default_repo="black-forest-labs/FLUX.1-dev",
),
"FluxKontextModularPipeline": PipelineRegistration(
label="Flux Kontext",
filters=("FluxKontextModularPipeline", "FluxKontextPipeline"),
default_repo="black-forest-labs/FLUX.1-Kontext-dev",
),
}


def get_pipeline_registration(model_type: str) -> PipelineRegistration:
return PIPELINE_REGISTRY.get(model_type, PIPELINE_REGISTRY[""])


# maybe in the future we can maintain the supported models in modular diffusers so we don't have to
# modify this file every time a new pipeline is added
def register_pipeline(
model_type: str, *, label: str = "", filters: tuple[str, ...] = (), default_repo: str = ""
) -> None:
PIPELINE_REGISTRY[model_type] = PipelineRegistration(
label=label or model_type, filters=tuple(filters), default_repo=default_repo
)


def node_get_component_info(node_id=None, manager=None, name=None):
comp_ids = manager._lookup_ids(name=name, collection=node_id)
if len(comp_ids) != 1:
Expand Down Expand Up @@ -226,13 +171,6 @@ def execute(self, model_type, model_id, dtype, variant=None, subfolder=None):
return {"model": components.get_model_info(comp_id)}


def model_type_options() -> dict[str, str]:
options = {"": PIPELINE_REGISTRY[""].label}
for k, reg in PIPELINE_REGISTRY.items():
if k == "":
continue
options[k] = reg.label or k
return options


class ModelsLoader(NodeBase):
Expand Down Expand Up @@ -298,24 +236,39 @@ def __del__(self):
def set_filters(self, values, ref):
# first time dynamically load the model_type options
if not self.model_types_loaded:
self.set_field_params("model_type", {"options": model_type_options()})
self.pipelines_loaded = True
self.set_field_params("model_type", {"options": get_all_model_types()})
self.model_types_loaded = True

model_type = values.get("model_type", "")
reg = get_pipeline_registration(model_type)
metadata = get_model_type_metadata(model_type)

if metadata:
default_repo = metadata["default_repo"]
default_dtype = metadata["default_dtype"]
else:
# Fallback for empty or unknown model types
default_repo = ""
default_dtype = "float16"
filters = [model_type] # YiYi Notes: 1:1 between model_type <-> modular pipeline class

self.set_field_params(
"repo_id",
{
"default": {"source": "hub", "value": reg.default_repo},
"value": {"source": "hub", "value": reg.default_repo},
"fieldOptions": {
"filter": {
"hub": {"className": list(reg.filters)},
"repo_id",
{
"default": {"source": "hub", "value": default_repo},
"value": {"source": "hub", "value": default_repo},
"fieldOptions": {
"filter": {
"hub": {"className": filters},
},
},
},
},
)
)
self.set_field_params(
"dtype",
{
"value": default_dtype,
},
)

def execute(self, model_type, repo_id, device, dtype, unet=None, vae=None, lora_list=None):
logger.debug(f"""
Expand Down
Loading