diff --git a/modules/ModularDiffusers/controlnet.py b/modules/ModularDiffusers/controlnet.py index bb8737b..fb7ec13 100644 --- a/modules/ModularDiffusers/controlnet.py +++ b/modules/ModularDiffusers/controlnet.py @@ -1,12 +1,11 @@ import importlib import logging -from diffusers.modular_pipelines import SequentialPipelineBlocks - from mellon.NodeBase import NodeBase from . import components from .utils import combine_multi_inputs +from .modular_utils import pipeline_class_to_mellon_node_config logger = logging.getLogger("mellon") @@ -227,30 +226,6 @@ def execute( return {"controlnet": controlnet} -def pipeline_class_to_mellon_node_config(pipeline_class, node_type=None): - print(f" inside pipeline_class_to_mellon_node_config: {pipeline_class}") - - try: - from diffusers.modular_pipelines.mellon_node_utils import ModularMellonNodeRegistry - - registry = ModularMellonNodeRegistry() - node_type_config = registry.get(pipeline_class)[node_type] - except Exception as e: - logger.debug(f" Failed to load the node from {pipeline_class}: {e}") - return None, None - - node_type_blocks = None - pipeline = pipeline_class() - - if pipeline is not None and node_type_config is not None and node_type_config.blocks_names: - blocks_dict = { - name: block for name, block in pipeline.blocks.sub_blocks.items() if name in node_type_config.blocks_names - } - node_type_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) - - return node_type_blocks, node_type_config - - class DynamicControlnet(NodeBase): label = "Dynamic ControlNet" category = "adapters" diff --git a/modules/ModularDiffusers/denoise.py b/modules/ModularDiffusers/denoise.py index f9859b4..ddf3f57 100644 --- a/modules/ModularDiffusers/denoise.py +++ b/modules/ModularDiffusers/denoise.py @@ -1,5 +1,7 @@ import logging from typing import Any, List, Tuple +import importlib +from .modular_utils import pipeline_class_to_mellon_node_config, get_model_type_signal_data import torch from diffusers import ComponentsManager @@ -49,117 +51,103 @@ def insert_preview_block_recursive(blocks, blocks_name, preview_block): insert_preview_block_recursive(pipeline.blocks.sub_blocks["denoise"], "", preview_block) +# SIGNAL_DATA = get_model_type_signal_data() class Denoise(NodeBase): label = "Denoise" category = "sampler" resizable = True + skipParamsCheck = True + node_type = "denoise" params = { + "model_type": { + "label": "Model Type", + "type": "string", + "default": "", + "hidden": True # Hidden field to receive signal data + }, "unet": { "label": "Denoise Model", "display": "input", "type": "diffusers_auto_model", "onSignal": [ - {"action": "signal", "target": "guider"}, - {"action": "signal", "target": "controlnet"}, - { - "StableDiffusionXLModularPipeline": [ - "width", - "height", - "ip_adapter", - "controlnet", - "latents_preview", - ], - "QwenImageModularPipeline": ["width", "height", "controlnet"], - "FluxModularPipeline": [ - "width", - "height", - "controlnet", - "ip_adapter", - ], - "": [], - }, { "action": "value", - "target": "skip_image_size", + "target": "model_type", + # "data": SIGNAL_DATA, # YiYi Notes: not working "data": { - "QwenImageEditModularPipeline": True, - "QwenImageEditPlusModularPipeline": True, + "StableDiffusionXLModularPipeline": "StableDiffusionXLModularPipeline", + "QwenImageModularPipeline": "QwenImageModularPipeline", + "QwenImageEditModularPipeline": "QwenImageEditModularPipeline", + "QwenImageEditPlusModularPipeline": "QwenImageEditPlusModularPipeline", + "FluxModularPipeline": "FluxModularPipeline", + "FluxKontextModularPipeline": "FluxKontextModularPipeline", }, }, - ], - }, - "skip_image_size": { - "label": "Skip Image Size", - "type": "boolean", - "default": False, - "value": False, - "hidden": True, - }, - "scheduler": {"label": "Scheduler", "display": "input", "type": "diffusers_auto_model"}, - "embeddings": {"label": "Text Embeddings", "display": "input", "type": "embeddings"}, - "latents": {"label": "Latents", "type": "latents", "display": "output"}, - "width": {"label": "Width", "type": "int", "default": 1024, "min": 64, "step": 8}, - "height": {"label": "Height", "type": "int", "default": 1024, "min": 64, "step": 8}, - "seed": {"label": "Seed", "type": "int", "display": "random", "default": 0, "min": 0, "max": 4294967295}, - "num_inference_steps": { - "label": "Steps", - "type": "int", - "display": "slider", - "default": 25, - "min": 1, - "max": 100, - }, - "guidance_scale": { - "label": "Guidance Scale", - "type": "float", - "display": "slider", - "default": 5, - "min": 1.0, - "max": 30.0, - "step": 0.1, - }, - "latents_preview": {"label": "Latents Preview", "display": "output", "type": "latent"}, - "guider": { - "label": "Guider", - "display": "input", - "type": "custom_guider", - "onChange": {False: ["guidance_scale"], True: []}, - }, - "image_latents": { - "label": "Image Latents", - "type": "latents", - "display": "input", - "onChange": {False: ["height", "width"], True: ["strength"]}, - }, - "strength": { - "label": "Strength", - "type": "float", - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - }, - "controlnet": { - "label": "Controlnet", - "type": "custom_controlnet", - "display": "input", - }, - "ip_adapter": { - "label": "IP Adapter", - "type": "custom_ip_adapter", - "display": "input", - }, - "doc": { - "label": "Doc", - "display": "output", - "type": "string", + {"action": "exec", "data": "update_node"}, + {"action": "signal", "target": "guider"}, + {"action": "signal", "target": "controlnet"}, + ] }, } + def update_node(self, values, ref): + + node_params = { + "model_type": { + "label": "Model Type", + "type": "string", + "default": "", + "hidden": True # Hidden field to receive signal data + }, + "unet": { + "label": "Denoise Model", + "display": "input", + "type": "diffusers_auto_model", + "onSignal": [ + { + "action": "value", + "target": "model_type", + # "data": SIGNAL_DATA, # YiYi Notes: not working + "data": { + "StableDiffusionXLModularPipeline": "StableDiffusionXLModularPipeline", + "QwenImageModularPipeline": "QwenImageModularPipeline", + "QwenImageEditModularPipeline": "QwenImageEditModularPipeline", + "QwenImageEditPlusModularPipeline": "QwenImageEditPlusModularPipeline", + "FluxModularPipeline": "FluxModularPipeline", + "FluxKontextModularPipeline": "FluxKontextModularPipeline", + }, + }, + {"action": "exec", "data": "update_node"}, + {"action": "signal", "target": "guider"}, + {"action": "signal", "target": "controlnet"}, + ] + }, + } + model_type = values.get("model_type", "") + + if model_type == "" or self._model_type == model_type: + return None + + self._model_type = model_type + + diffusers_module = importlib.import_module("diffusers") + self._pipeline_class = getattr(diffusers_module, model_type) + + _, denoise_mellon_config = pipeline_class_to_mellon_node_config(self._pipeline_class, self.node_type) + # not support this node type + if denoise_mellon_config is None: + self.send_node_definition(node_params) + return + + # required params for controlnet + node_params.update(**denoise_mellon_config.to_mellon_dict()["params"]) + self.send_node_definition(node_params) + def __init__(self, node_id=None): super().__init__(node_id) - self._denoise_node = None + self._model_type = "" + self._pipeline_class = None def execute( self, @@ -177,6 +165,7 @@ def execute( width=None, height=None, skip_image_size=False, + **kwargs, ): logger.debug(f" Denoise ({self.node_id}) received parameters:") logger.debug(f" - unet: {unet}") diff --git a/modules/ModularDiffusers/loaders.py b/modules/ModularDiffusers/loaders.py index 84a0e4d..487c483 100644 --- a/modules/ModularDiffusers/loaders.py +++ b/modules/ModularDiffusers/loaders.py @@ -9,68 +9,13 @@ from utils.torch_utils import DEFAULT_DEVICE, DEVICE_LIST, str_to_dtype from . import components +from .modular_utils import get_all_model_types, get_model_type_metadata logger = logging.getLogger("mellon") logger.setLevel(logging.DEBUG) -@dataclass(frozen=True) -class PipelineRegistration: - label: str - filters: tuple[str, ...] - default_repo: str - - -PIPELINE_REGISTRY: dict[str, PipelineRegistration] = { - "": PipelineRegistration(label="", filters=(), default_repo=""), - "StableDiffusionXLModularPipeline": PipelineRegistration( - label="Stable Diffusion XL", - filters=("StableDiffusionXLModularPipeline", "StableDiffusionXLPipeline"), - default_repo="stabilityai/stable-diffusion-xl-base-1.0", - ), - "QwenImageModularPipeline": PipelineRegistration( - label="Qwen Image", - filters=("QwenImageModularPipeline", "QwenImagePipeline"), - default_repo="Qwen/Qwen-Image", - ), - "QwenImageEditModularPipeline": PipelineRegistration( - label="Qwen Image Edit", - filters=("QwenImageEditModularPipeline", "QwenImageEditPipeline"), - default_repo="Qwen/Qwen-Image-Edit", - ), - "QwenImageEditPlusModularPipeline": PipelineRegistration( - label="Qwen Image Edit Plus", - filters=("QwenImageEditPlusModularPipeline", "QwenImageEditPlusPipeline"), - default_repo="Qwen/Qwen-Image-Edit-2509", - ), - "FluxModularPipeline": PipelineRegistration( - label="Flux", - filters=("FluxModularPipeline", "FluxPipeline"), - default_repo="black-forest-labs/FLUX.1-dev", - ), - "FluxKontextModularPipeline": PipelineRegistration( - label="Flux Kontext", - filters=("FluxKontextModularPipeline", "FluxKontextPipeline"), - default_repo="black-forest-labs/FLUX.1-Kontext-dev", - ), -} - - -def get_pipeline_registration(model_type: str) -> PipelineRegistration: - return PIPELINE_REGISTRY.get(model_type, PIPELINE_REGISTRY[""]) - - -# maybe in the future we can maintain the supported models in modular diffusers so we don't have to -# modify this file every time a new pipeline is added -def register_pipeline( - model_type: str, *, label: str = "", filters: tuple[str, ...] = (), default_repo: str = "" -) -> None: - PIPELINE_REGISTRY[model_type] = PipelineRegistration( - label=label or model_type, filters=tuple(filters), default_repo=default_repo - ) - - def node_get_component_info(node_id=None, manager=None, name=None): comp_ids = manager._lookup_ids(name=name, collection=node_id) if len(comp_ids) != 1: @@ -226,13 +171,6 @@ def execute(self, model_type, model_id, dtype, variant=None, subfolder=None): return {"model": components.get_model_info(comp_id)} -def model_type_options() -> dict[str, str]: - options = {"": PIPELINE_REGISTRY[""].label} - for k, reg in PIPELINE_REGISTRY.items(): - if k == "": - continue - options[k] = reg.label or k - return options class ModelsLoader(NodeBase): @@ -298,24 +236,39 @@ def __del__(self): def set_filters(self, values, ref): # first time dynamically load the model_type options if not self.model_types_loaded: - self.set_field_params("model_type", {"options": model_type_options()}) - self.pipelines_loaded = True + self.set_field_params("model_type", {"options": get_all_model_types()}) + self.model_types_loaded = True model_type = values.get("model_type", "") - reg = get_pipeline_registration(model_type) + metadata = get_model_type_metadata(model_type) + + if metadata: + default_repo = metadata["default_repo"] + default_dtype = metadata["default_dtype"] + else: + # Fallback for empty or unknown model types + default_repo = "" + default_dtype = "float16" + filters = [model_type] # YiYi Notes: 1:1 between model_type <-> modular pipeline class self.set_field_params( - "repo_id", - { - "default": {"source": "hub", "value": reg.default_repo}, - "value": {"source": "hub", "value": reg.default_repo}, - "fieldOptions": { - "filter": { - "hub": {"className": list(reg.filters)}, + "repo_id", + { + "default": {"source": "hub", "value": default_repo}, + "value": {"source": "hub", "value": default_repo}, + "fieldOptions": { + "filter": { + "hub": {"className": filters}, + }, }, }, - }, - ) + ) + self.set_field_params( + "dtype", + { + "value": default_dtype, + }, + ) def execute(self, model_type, repo_id, device, dtype, unet=None, vae=None, lora_list=None): logger.debug(f""" diff --git a/modules/ModularDiffusers/modular_utils.py b/modules/ModularDiffusers/modular_utils.py new file mode 100644 index 0000000..bc622a2 --- /dev/null +++ b/modules/ModularDiffusers/modular_utils.py @@ -0,0 +1,680 @@ +from typing import Dict, Any +from diffusers.modular_pipelines.mellon_node_utils import MellonNodeConfig, MellonParam +import logging + +from diffusers.modular_pipelines import SequentialPipelineBlocks + + +logger = logging.getLogger("mellon") + +# mellon nodes +SDXL_NODE_TYPES_PARAMS_MAP = { + "controlnet": { + "inputs": [ + "control_image", + "controlnet_conditioning_scale", + "control_guidance_start", + "control_guidance_end", + "height", + "width", + ], + "model_inputs": [ + "controlnet", + ], + "outputs": [ + "controlnet_out", + ], + "block_names": [None], + }, + "denoise": { + "inputs": [ + "embeddings", + "width", + "height", + "seed", + "num_inference_steps", + "guidance_scale", + "image_latents", + "strength", + # custom adapters coming in as inputs + "controlnet", + # ip_adapter is optional and custom; include if available + "ip_adapter", + ], + "model_inputs": [ + "unet", + "guider", + "scheduler", + ], + "outputs": [ + "latents", + "latents_preview", + "doc" + ], + "block_names": ["denoise"], + }, + "vae_encoder": { + "inputs": [ + "image", + "width", + "height", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "image_latents", + ], + "block_names": ["vae_encoder"], + }, + "text_encoder": { + "inputs": [ + "prompt", + "negative_prompt", + ], + "model_inputs": [ + "text_encoders", + ], + "outputs": [ + "embeddings", + ], + "block_names": ["text_encoder"], + }, + "decoder": { + "inputs": [ + "latents", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "images", + ], + "block_names": ["decode"], + }, +} + +QwenImage_NODE_TYPES_PARAMS_MAP = { + "controlnet": { + "inputs": [ + "control_image", + "controlnet_conditioning_scale", + "control_guidance_start", + "control_guidance_end", + "height", + "width", + ], + "model_inputs": [ + "controlnet", + "vae", + ], + "outputs": [ + "controlnet_out", + ], + "block_names": ["controlnet_vae_encoder"], + }, + "denoise": { + "inputs": [ + "embeddings", + "width", + "height", + "seed", + "num_inference_steps", + "guidance_scale", + "image_latents", + "strength", + "controlnet", + ], + "model_inputs": [ + "unet", + "guider", + "scheduler", + ], + "outputs": [ + "latents", + "doc", + ], + "block_names": ["denoise"], + }, + "vae_encoder": { + "inputs": [ + "image", + "width", + "height", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "image_latents", + ], + }, + "text_encoder": { + "inputs": [ + "prompt", + "negative_prompt", + ], + "model_inputs": [ + "text_encoders", + ], + "outputs": [ + "embeddings", + ], + }, + "decoder": { + "inputs": [ + "latents", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "images", + ], + }, +} + + +QwenImageEdit_NODE_TYPES_PARAMS_MAP = { + "controlnet": None, + "denoise": { + "inputs": [ + "embeddings", + "seed", + "num_inference_steps", + "guidance_scale", + MellonParam(name="image_latents", label="Image Latents", type="latents", display="input"), + ], + "model_inputs": [ + "unet", + "guider", + "scheduler", + ], + "outputs": [ + "latents", + "doc", + ], + "block_names": ["denoise"], + }, + "vae_encoder": { + "inputs": [ + "image", + "width", + "height", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "image_latents", + ], + }, + "text_encoder": { + "inputs": [ + "prompt", + "negative_prompt", + "image", + ], + "model_inputs": [ + "text_encoders", + ], + "outputs": [ + "embeddings", + ], + }, + "decoder": { + "inputs": [ + "latents", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "images", + ], + }, +} + + +QwenImageEditPlus_NODE_TYPES_PARAMS_MAP = { + "controlnet": None, + "denoise": { + "inputs": [ + "embeddings", + "seed", + "num_inference_steps", + "guidance_scale", + MellonParam(name="image_latents", label="Image Latents", type="latents", display="input"), + ], + "model_inputs": [ + "unet", + "guider", + "scheduler", + ], + "outputs": [ + "latents", + "doc", + ], + "block_names": ["denoise"], + }, + "vae_encoder": { + "inputs": [ + "image", + "width", + "height", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "image_latents", + ], + }, + "text_encoder": { + "inputs": [ + "prompt", + "negative_prompt", + "image", + ], + "model_inputs": [ + "text_encoders", + ], + "outputs": [ + "embeddings", + ], + }, + "decoder": { + "inputs": [ + "latents", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "images", + ], + }, +} + + +Flux_NODE_TYPES_PARAMS_MAP = { + "controlnet": { + "inputs": [ + "control_image", + "controlnet_conditioning_scale", + "control_guidance_start", + "control_guidance_end", + "height", + "width", + ], + "model_inputs": [ + "controlnet", + ], + "outputs": [ + "controlnet_out", + ], + "block_names": [None], + }, + "denoise": { + "inputs": [ + "embeddings", + "width", + "height", + "seed", + "num_inference_steps", + "guidance_scale", + "image_latents", + "strength", + "controlnet", + ], + "model_inputs": [ + "unet", + "guider", + "scheduler", + ], + "outputs": [ + "latents", + "doc", + ], + "block_names": ["denoise"], + }, + "vae_encoder": { + "inputs": [ + "image", + "width", + "height", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "image_latents", + ], + }, + "text_encoder": { + "inputs": [ + "prompt", + "negative_prompt", + ], + "model_inputs": [ + "text_encoders", + ], + "outputs": [ + "embeddings", + ], + }, + "decoder": { + "inputs": [ + "latents", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "images", + ], + }, +} + + +FluxKontext_NODE_TYPES_PARAMS_MAP = { + "controlnet": None, + "denoise": { + "inputs": [ + "embeddings", + "width", + "height", + "seed", + "num_inference_steps", + "guidance_scale", + MellonParam(name="image_latents", label="Image Latents", type="latents", display="input"), + ], + "model_inputs": [ + "unet", + "guider", + "scheduler", + ], + "outputs": [ + "latents", + "doc", + ], + "block_names": ["denoise"], + }, + "vae_encoder": { + "inputs": [ + "image", + "width", + "height", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "image_latents", + ], + }, + "text_encoder": { + "inputs": [ + "prompt", + "negative_prompt", + ], + "model_inputs": [ + "text_encoders", + ], + "outputs": [ + "embeddings", + ], + }, + "decoder": { + "inputs": [ + "latents", + ], + "model_inputs": [ + "vae", + ], + "outputs": [ + "images", + ], + }, +} + +# Minimal modular registry for Mellon node configs +class ModularMellonNodeRegistry: + """Registry mapping pipeline class to its node configs and metadata.""" + + def __init__(self): + self._registry = {} + self._initialized = False + + def register( + self, + pipeline_cls: type, + node_params: Dict[str, MellonNodeConfig], + label: str = "", + default_repo: str = "", + default_dtype: str = "" + ): + if not self._initialized: + _initialize_registry(self) + + model_type = pipeline_cls.__name__ + + _meta_data = { + "node_params": node_params, + "label": label, + "default_repo": default_repo, + "model_type": model_type, + "default_dtype": default_dtype, + } + + self._registry[pipeline_cls] = _meta_data + + def get(self, pipeline_cls: type) -> MellonNodeConfig: + if not self._initialized: + _initialize_registry(self) + return self._registry.get(pipeline_cls, None) + + def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]: + if not self._initialized: + _initialize_registry(self) + return self._registry + + +def _register_pipeline( + pipeline_cls, + registry: ModularMellonNodeRegistry, + params_map: Dict[str, Dict[str, Any]], + label: str = "", + default_repo: str = "", + default_dtype: str = "", +): + """Register all node-type presets for a given pipeline class from a params map.""" + node_configs = {} + for node_type, spec in params_map.items(): + if spec is None: + node_config = None + else: + node_config = MellonNodeConfig( + inputs=spec.get("inputs", []), + model_inputs=spec.get("model_inputs", []), + outputs=spec.get("outputs", []), + blocks_names=spec.get("block_names", []), + node_type=node_type, + ) + node_configs[node_type] = node_config + registry.register( + pipeline_cls, + node_configs, + label=label, + default_repo=default_repo, + default_dtype=default_dtype, + ) + + +def _initialize_registry(registry: ModularMellonNodeRegistry): + """Initialize the registry and register all available pipeline configs.""" + print("Initializing registry") + + registry._initialized = True + + try: + from diffusers import QwenImageModularPipeline + _register_pipeline( + QwenImageModularPipeline, + registry, + QwenImage_NODE_TYPES_PARAMS_MAP, + label="Qwen Image", + default_repo="Qwen/Qwen-Image", + default_dtype="bfloat16", + ) + except Exception as e: + raise Exception(f"Failed to register QwenImageModularPipeline :{e}") + + try: + from diffusers import QwenImageEditModularPipeline + _register_pipeline( + QwenImageEditModularPipeline, + registry, + QwenImageEdit_NODE_TYPES_PARAMS_MAP, + label="Qwen Image Edit", + default_repo="Qwen/Qwen-Image-Edit", + default_dtype="bfloat16", + ) + except Exception as e: + raise Exception(f"Failed to register QwenImageEditModularPipeline :{e}") + + try: + from diffusers import QwenImageEditPlusModularPipeline + _register_pipeline( + QwenImageEditPlusModularPipeline, + registry, + QwenImageEditPlus_NODE_TYPES_PARAMS_MAP, + label="Qwen Image Edit Plus", + default_repo="Qwen/Qwen-Image-Edit-2509", + default_dtype="bfloat16", + ) + except Exception as e: + raise Exception(f"Failed to register QwenImageEditPlusModularPipeline :{e}") + + try: + from diffusers import FluxModularPipeline + _register_pipeline( + FluxModularPipeline, + registry, + Flux_NODE_TYPES_PARAMS_MAP, + label="Flux", + default_repo="black-forest-labs/FLUX.1-dev", + default_dtype="bfloat16", + ) + except Exception as e: + raise Exception(f"Failed to register FluxModularPipeline :{e}") + + try: + from diffusers import FluxKontextModularPipeline + _register_pipeline( + FluxKontextModularPipeline, + registry, + FluxKontext_NODE_TYPES_PARAMS_MAP, + label="Flux Kontext", + default_repo="black-forest-labs/FLUX.1-Kontext-dev", + default_dtype="bfloat16", + ) + except Exception as e: + raise Exception(f"Failed to register FluxKontextModularPipeline :{e}") + + try: + from diffusers import StableDiffusionXLModularPipeline + _register_pipeline( + StableDiffusionXLModularPipeline, + registry, + SDXL_NODE_TYPES_PARAMS_MAP, + label="Stable Diffusion XL", + default_repo="stabilityai/stable-diffusion-xl-base-1.0", + default_dtype="float16", + ) + except Exception as e: + raise Exception(f"Failed to register StableDiffusionXLModularPipeline :{e}") + + +# Global singleton registry instance +MODULAR_REGISTRY = ModularMellonNodeRegistry() + + +def get_all_model_types() -> Dict[str, str]: + + """Get all registered model types with their labels for UI dropdowns. + + Returns: + Dict mapping model type names (keys) to human-readable labels (values). + + Example output: + { + "": "", + "StableDiffusionXLModularPipeline": "Stable Diffusion XL", + "QwenImageModularPipeline": "Qwen Image", + "QwenImageEditModularPipeline": "Qwen Image Edit", + "FluxModularPipeline": "Flux", + "FluxKontextModularPipeline": "Flux Kontext", + } + """ + + registry = MODULAR_REGISTRY.get_all() + all_labels = {} + for _, meta_data in registry.items(): + all_labels[meta_data["model_type"]] = meta_data["label"] + all_labels[""] = "" + return all_labels + + +def get_model_type_signal_data() -> Dict[str, str]: + """Get model type mapping for onSignal value actions. + + Returns a dict mapping model type names to themselves, used in onSignal + to pass model type through from upstream nodes. + + Example: + { + "StableDiffusionXLModularPipeline": "StableDiffusionXLModularPipeline", + "QwenImageModularPipeline": "QwenImageModularPipeline", + "": "", + } + """ + registry = MODULAR_REGISTRY.get_all() + # Get all registered model types and map them to themselves + model_types = {} + for _, meta_data in registry.items(): + model_type = meta_data["model_type"] + model_types[model_type] = model_type + + # Add empty default + model_types[""] = "" + return model_types + + +def get_model_type_metadata(model_type: str) -> Dict[str, Any]: + registry = MODULAR_REGISTRY.get_all() + + for _, meta_data in registry.items(): + if meta_data["model_type"] == model_type: + return meta_data + return None + + +def pipeline_class_to_mellon_node_config(pipeline_class, node_type=None): + + try: + node_type_config = MODULAR_REGISTRY.get(pipeline_class)["node_params"][node_type] + except Exception as e: + logger.debug(f" Failed to load the node from {pipeline_class}: {e}") + return None, None + + node_type_blocks = None + pipeline = pipeline_class() + + if node_type_config is not None and node_type_config.blocks_names: + blocks_dict = { + name: block for name, block in pipeline.blocks.sub_blocks.items() if name in node_type_config.blocks_names + } + node_type_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) + + return node_type_blocks, node_type_config \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py index 8ae039e..a69e553 100644 --- a/modules/__init__.py +++ b/modules/__init__.py @@ -216,6 +216,6 @@ def parse_module_map(base_path: str) -> None: logger.warning(f"Module '{module_name}' could not be parsed or has no NodeBase classes.") parse_module_map("modules") -parse_module_map("custom") +# parse_module_map("custom") logger.info(f"Loaded {total_nodes} nodes from {len(MODULE_MAP)} modules.")