Skip to content

Commit

Permalink
Support ProMax union model (#2998)
Browse files Browse the repository at this point in the history
* Support ProMax union model

* nit
  • Loading branch information
huchenlei authored Jul 15, 2024
1 parent 19ec5ea commit 3ff69b9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
6 changes: 3 additions & 3 deletions scripts/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
union_controlnet=False,
union_controlnet_num_control_type=None,
device=None,
global_average_pooling=False,
):
Expand Down Expand Up @@ -282,8 +282,8 @@ def __init__(
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch

if union_controlnet:
self.num_control_type = 6
if union_controlnet_num_control_type is not None:
self.num_control_type = union_controlnet_num_control_type
num_trans_channel = 320
num_trans_head = 8
num_trans_layer = 1
Expand Down
2 changes: 1 addition & 1 deletion scripts/controlnet_model_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel:
state_dict = final_state_dict

if "control_add_embedding.linear_1.bias" in state_dict: # Controlnet Union
config["union_controlnet"] = True
config["union_controlnet_num_control_type"] = state_dict["task_embedding"].shape[0]
final_state_dict = {}
for k in list(state_dict.keys()):
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
Expand Down
8 changes: 8 additions & 0 deletions scripts/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ class ControlNetUnionControlType(Enum):
HARD_EDGE = "Hard Edge"
NORMAL_MAP = "Normal Map"
SEGMENTATION = "Segmentation"
TILE = "Tile"
INPAINT = "Inpaint"

UNKNOWN = "Unknown"

Expand All @@ -308,6 +310,8 @@ def all_tags() -> List[str]:
"mlsd",
"normalmap",
"segmentation",
"inpaint",
"tile",
]

@staticmethod
Expand All @@ -326,6 +330,10 @@ def from_str(s: str) -> ControlNetUnionControlType:
return ControlNetUnionControlType.NORMAL_MAP
elif s == "segmentation":
return ControlNetUnionControlType.SEGMENTATION
elif s in ["tile", "blur"]:
return ControlNetUnionControlType.TILE
elif s == "inpaint":
return ControlNetUnionControlType.INPAINT

return ControlNetUnionControlType.UNKNOWN

Expand Down

0 comments on commit 3ff69b9

Please sign in to comment.