def get_shards(json_path):
    """Return the shard file names listed in a weight-map index file.

    Reads the JSON file at ``json_path``, looks up its ``weight_map`` object
    and returns every distinct value exactly once. The result is sorted so
    that callers always see the shards in the same order, regardless of
    set/hash ordering.

    Args:
        json_path: Path to the index JSON file.

    Returns:
        A sorted list of unique file names taken from the values of
        ``weight_map``. An empty list is returned when the file does not
        exist, cannot be parsed as JSON, or contains no usable
        ``weight_map`` object.
    """
    # Function-scope import keeps this helper self-contained.
    import logging

    try:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError:
        logging.getLogger(__name__).error(f"Error: Unable to parse JSON in {json_path}.")
        return []

    # Guard against a top-level value (or weight_map) that is not an object,
    # which would otherwise raise AttributeError on .get()/.values().
    weight_map = data.get("weight_map") if isinstance(data, dict) else None
    if not isinstance(weight_map, dict):
        return []

    # Several tensors usually map to the same shard file: de-duplicate, then
    # sort for a deterministic order. Non-string entries cannot be file names.
    shard_names = {name for name in weight_map.values() if isinstance(name, str)}
    return sorted(shard_names)
+ if os.path.isfile(weight_map_file): + ckpt_paths = [os.path.join(ckpt_path, file) for file in get_shards(weight_map_file)] + else: + ckpt_path = os.path.join(ckpt_path, "diffusion_pytorch_model-00001-of-00003.safetensors") + ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)] + else: ckpt_paths = [ckpt_path] + if len(ckpt_paths) == 0: + raise RuntimeError("Could not find a checkpoint to analyze / 分析するチェックポイントが見つかりませんでした") + keys = [] for ckpt_path in ckpt_paths: with safe_open(ckpt_path, framework="pt") as f: