diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/lora_format_adapter.py b/python/sglang/multimodal_gen/runtime/pipelines_core/lora_format_adapter.py new file mode 100644 index 000000000000..656d795abf01 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/lora_format_adapter.py @@ -0,0 +1,418 @@ +from __future__ import annotations + +import logging +from enum import Enum +from typing import Dict, Iterable, Mapping, Optional + +import torch +from diffusers.loaders import lora_conversion_utils as lcu + +logger = logging.getLogger("LoRAFormatAdapter") + + +class LoRAFormat(str, Enum): + """Supported external LoRA formats before normalization.""" + + STANDARD = "standard" + NON_DIFFUSERS_SD = "non-diffusers-sd" + QWEN_IMAGE_STANDARD = "qwen-image-standard" + XLABS_FLUX = "xlabs-ai" + KOHYA_FLUX = "kohya-flux" + WAN = "wan" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _sample_keys(keys: Iterable[str], k: int = 20) -> list[str]: + out = [] + for i, key in enumerate(keys): + if i >= k: + break + out.append(key) + return out + + +def _has_substring_key(keys: Iterable[str], substr: str) -> bool: + return any(substr in k for k in keys) + + +def _has_prefix_key(keys: Iterable[str], prefix: str) -> bool: + return any(k.startswith(prefix) for k in keys) + + +# --------------------------------------------------------------------------- +# Format-specific heuristics +# --------------------------------------------------------------------------- + + +def _looks_like_xlabs_flux_key(k: str) -> bool: + """XLabs FLUX-style keys under double_blocks/single_blocks with lora down/up.""" + if not (k.endswith(".down.weight") or k.endswith(".up.weight")): + return False + + if not k.startswith( + ( + "double_blocks.", + "single_blocks.", + "diffusion_model.double_blocks", + "diffusion_model.single_blocks", + ) + ): + return False + + return ".processor." in k or ".proj_lora" in k or ".qkv_lora" in k + + +def _looks_like_kohya_flux(state_dict: Mapping[str, torch.Tensor]) -> bool: + """Kohya FLUX LoRA (flux_lora.py) under lora_unet_double/single_blocks_ prefixes.""" + if not state_dict: + return False + keys = state_dict.keys() + return any( + k.startswith("lora_unet_double_blocks_") + or k.startswith("lora_unet_single_blocks_") + for k in keys + ) + + +def _looks_like_non_diffusers_sd(state_dict: Mapping[str, torch.Tensor]) -> bool: + """Classic non-diffusers SD LoRA (Kohya/A1111/sd-scripts).""" + if not state_dict: + return False + keys = state_dict.keys() + return all( + k.startswith(("lora_unet_", "lora_te_", "lora_te1_", "lora_te2_")) for k in keys + ) + + +def _looks_like_wan_lora(state_dict: Mapping[str, torch.Tensor]) -> bool: + """Wan2.2 distill LoRAs (Wan-AI / Wan2.2-Distill-Loras style).""" + if not state_dict: + return False + + for k in state_dict.keys(): + if not k.startswith("diffusion_model.blocks."): + continue + if ".lora_down" not in k and ".lora_up" not in k: + continue + if ".cross_attn." in k or ".self_attn." in k or ".ffn." in k or ".norm3." 
in k: + return True + + return False + + +def _looks_like_qwen_image(state_dict: Mapping[str, torch.Tensor]) -> bool: + keys = list(state_dict.keys()) + if not keys: + return False + return _has_prefix_key(keys, "transformer.transformer_blocks.") and ( + _has_substring_key(keys, ".lora.down.weight") + or _has_substring_key(keys, ".lora.up.weight") + ) + + +# --------------------------------------------------------------------------- +# Format detection +# --------------------------------------------------------------------------- + + +def detect_lora_format_from_state_dict( + state_dict: Mapping[str, torch.Tensor], +) -> LoRAFormat: + """Classify LoRA format by key patterns only.""" + keys = list(state_dict.keys()) + if not keys: + return LoRAFormat.STANDARD + + if _has_substring_key(keys, ".lora_A") or _has_substring_key(keys, ".lora_B"): + return LoRAFormat.STANDARD + + if any(_looks_like_xlabs_flux_key(k) for k in keys): + return LoRAFormat.XLABS_FLUX + if _looks_like_kohya_flux(state_dict): + return LoRAFormat.KOHYA_FLUX + + if _looks_like_wan_lora(state_dict): + return LoRAFormat.WAN + + if _looks_like_qwen_image(state_dict): + return LoRAFormat.STANDARD + + if _looks_like_non_diffusers_sd(state_dict): + return LoRAFormat.NON_DIFFUSERS_SD + + if _has_substring_key(keys, ".lora.down") or _has_substring_key(keys, ".lora_up"): + return LoRAFormat.NON_DIFFUSERS_SD + + return LoRAFormat.STANDARD + + +# --------------------------------------------------------------------------- +# Converters +# --------------------------------------------------------------------------- + + +def _convert_qwen_image_standard( + state_dict: Mapping[str, torch.Tensor], + log: logging.Logger, +) -> Dict[str, torch.Tensor]: + """Qwen-Image: transformer.*.lora.down/up -> transformer_blocks.*.lora_A/B.""" + out: Dict[str, torch.Tensor] = {} + + for name, tensor in state_dict.items(): + new_name = name + + if new_name.startswith("transformer."): + new_name = new_name[len("transformer.") :] + + if new_name.endswith(".lora.down.weight"): + new_name = new_name.replace(".lora.down.weight", ".lora_A.weight") + elif new_name.endswith(".lora.up.weight"): + new_name = new_name.replace(".lora.up.weight", ".lora_B.weight") + + out[new_name] = tensor + + sample = _sample_keys(out.keys(), 20) + return out + + +def _convert_non_diffusers_sd_simple( + state_dict: Mapping[str, torch.Tensor], + log: logging.Logger, +) -> Dict[str, torch.Tensor]: + """Generic down/up -> A/B conversion for non-diffusers SD-like formats.""" + out: Dict[str, torch.Tensor] = {} + + for name, tensor in state_dict.items(): + new_name = name + + if "lora_down.weight" in new_name: + new_name = new_name.replace("lora_down.weight", "lora_A.weight") + elif "lora_up.weight" in new_name: + new_name = new_name.replace("lora_up.weight", "lora_B.weight") + elif new_name.endswith(".lora_down"): + new_name = new_name.replace(".lora_down", ".lora_A") + elif new_name.endswith(".lora_up"): + new_name = new_name.replace(".lora_up", ".lora_B") + + out[new_name] = tensor + + sample = _sample_keys(out.keys(), 20) + log.info( + "[LoRAFormatAdapter] after NON_DIFFUSERS_SD simple conversion, " + "sample keys (<=20): %s", + ", ".join(sample), + ) + return out + + +def _convert_with_diffusers_utils_if_available( + state_dict: Mapping[str, torch.Tensor], + log: logging.Logger, +) -> Optional[Dict[str, torch.Tensor]]: + """Use diffusers.lora_conversion_utils if available.""" + try: + if hasattr(lcu, "maybe_convert_state_dict"): + converted = lcu.maybe_convert_state_dict( # type: 
ignore[attr-defined] + state_dict + ) + else: + converted = dict(state_dict) + + if not isinstance(converted, dict): + converted = dict(converted) + + sample = _sample_keys(converted.keys(), 20) + log.info( + "[LoRAFormatAdapter] diffusers.lora_conversion_utils converted keys, " + "sample keys (<=20): %s", + ", ".join(sample), + ) + return converted + except Exception as exc: # pragma: no cover + log.warning( + "[LoRAFormatAdapter] diffusers lora_conversion_utils failed, " + "falling back to internal converters. Error: %s", + exc, + ) + return None + + +def _convert_via_diffusers_candidates( + state_dict: Mapping[str, torch.Tensor], + candidate_names: tuple[str, ...], + log: logging.Logger, + unavailable_warning: str, + no_converter_warning: str, + success_info: str, + all_failed_warning: str, +) -> Dict[str, torch.Tensor]: + """Try multiple named converters in lora_conversion_utils, use the first that works.""" + converters = [ + (n, getattr(lcu, n)) for n in candidate_names if callable(getattr(lcu, n, None)) + ] + if not converters: + log.warning(no_converter_warning) + return dict(state_dict) + + last_err: Optional[Exception] = None + + for name, fn in converters: + try: + sd_copy = dict(state_dict) + out = fn(sd_copy) + if isinstance(out, tuple) and isinstance(out[0], dict): + out = out[0] + if not isinstance(out, dict): + raise TypeError(f"Converter {name} returned {type(out)}") + log.info(success_info.format(name=name)) + return out + except Exception as exc: + last_err = exc + + log.warning(all_failed_warning.format(last_err=last_err)) + return dict(state_dict) + + +def _convert_xlabs_ai_via_diffusers( + state_dict: Mapping[str, torch.Tensor], + log: logging.Logger, +) -> Dict[str, torch.Tensor]: + """Convert XLabs FLUX LoRA via diffusers helpers.""" + return _convert_via_diffusers_candidates( + state_dict, + ( + "_convert_xlabs_flux_lora_to_diffusers", + "convert_xlabs_lora_state_dict_to_diffusers", + "convert_xlabs_lora_to_diffusers", + "convert_xlabs_flux_lora_to_diffusers", + ), + log=log, + unavailable_warning=( + "[LoRAFormatAdapter] XLabs FLUX detected but diffusers is unavailable." + ), + no_converter_warning=( + "[LoRAFormatAdapter] No XLabs FLUX converter found in diffusers." + ), + success_info="[LoRAFormatAdapter] Converted XLabs FLUX LoRA using {name}", + all_failed_warning=( + "[LoRAFormatAdapter] All XLabs FLUX converters failed; " + "last error: {last_err}" + ), + ) + + +def _convert_kohya_flux_via_diffusers( + state_dict: Mapping[str, torch.Tensor], + log: logging.Logger, +) -> Dict[str, torch.Tensor]: + """Convert Kohya FLUX LoRA via diffusers helpers.""" + return _convert_via_diffusers_candidates( + state_dict, + ( + "_convert_kohya_flux_lora_to_diffusers", + "convert_kohya_flux_lora_to_diffusers", + ), + log=log, + unavailable_warning=( + "[LoRAFormatAdapter] Kohya FLUX detected but diffusers is unavailable." 
+ ), + no_converter_warning="[LoRAFormatAdapter] No Kohya FLUX converter found.", + success_info="[LoRAFormatAdapter] Converted Kohya FLUX LoRA using {name}", + all_failed_warning=( + "[LoRAFormatAdapter] Kohya FLUX conversion failed; " + "last error: {last_err}" + ), + ) + + +# --------------------------------------------------------------------------- +# Conversion dispatcher +# --------------------------------------------------------------------------- + + +def convert_lora_state_dict_by_format( + state_dict: Mapping[str, torch.Tensor], + fmt: LoRAFormat, + log: logging.Logger, +) -> Dict[str, torch.Tensor]: + """Normalize a raw LoRA state_dict into A/B + .weight naming.""" + if fmt == LoRAFormat.QWEN_IMAGE_STANDARD: + return _convert_qwen_image_standard(state_dict, log) + + if fmt == LoRAFormat.XLABS_FLUX: + converted = _convert_xlabs_ai_via_diffusers(state_dict, log) + return _convert_non_diffusers_sd_simple(converted, log) + + if fmt == LoRAFormat.KOHYA_FLUX: + converted = _convert_kohya_flux_via_diffusers(state_dict, log) + return _convert_non_diffusers_sd_simple(converted, log) + + if fmt == LoRAFormat.WAN: + maybe = _convert_with_diffusers_utils_if_available(state_dict, log) + if maybe is None: + maybe = dict(state_dict) + return _convert_non_diffusers_sd_simple(maybe, log) + + if fmt == LoRAFormat.STANDARD: + maybe = _convert_with_diffusers_utils_if_available(state_dict, log) + if maybe is None: + maybe = dict(state_dict) + + if _looks_like_qwen_image(maybe): + return _convert_qwen_image_standard(maybe, log) + + return maybe + + if fmt == LoRAFormat.NON_DIFFUSERS_SD: + maybe = _convert_with_diffusers_utils_if_available(state_dict, log) + if maybe is None: + maybe = dict(state_dict) + return _convert_non_diffusers_sd_simple(maybe, log) + + log.info( + "[LoRAFormatAdapter] format %s not handled specially, returning as-is", + fmt, + ) + return dict(state_dict) + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + + +def normalize_lora_state_dict( + state_dict: Mapping[str, torch.Tensor], + logger: Optional[logging.Logger] = None, +) -> Dict[str, torch.Tensor]: + """Normalize any supported LoRA format into a single canonical layout.""" + log = logger or globals()["logger"] + + keys = list(state_dict.keys()) + log.info( + "[LoRAFormatAdapter] normalize_lora_state_dict called, #keys=%d", + len(keys), + ) + if keys: + log.info( + "[LoRAFormatAdapter] before convert, sample keys (<=20): %s", + ", ".join(_sample_keys(keys, 20)), + ) + + fmt = detect_lora_format_from_state_dict(state_dict) + log.info("[LoRAFormatAdapter] detected format: %s", fmt) + + normalized = convert_lora_state_dict_by_format(state_dict, fmt, log) + + norm_keys = list(normalized.keys()) + if norm_keys: + log.info( + "[LoRAFormatAdapter] after convert, sample keys (<=20): %s", + ", ".join(_sample_keys(norm_keys, 20)), + ) + + return normalized diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py index a9270ae8a916..c61cb793f82f 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py @@ -20,6 +20,9 @@ from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import ( ComposedPipelineBase, ) +from sglang.multimodal_gen.runtime.pipelines_core.lora_format_adapter import ( + 
normalize_lora_state_dict, +) from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_lora from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -298,7 +301,9 @@ def load_lora_adapter(self, lora_path: str, lora_nickname: str, rank: int): """ assert lora_path is not None lora_local_path = maybe_download_lora(lora_path) - lora_state_dict = load_file(lora_local_path) + + raw_state_dict = load_file(lora_local_path) + lora_state_dict = normalize_lora_state_dict(raw_state_dict, logger=logger) if lora_nickname in self.lora_adapters: self.lora_adapters[lora_nickname].clear() diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py index cef91e34c56b..579a9493b400 100644 --- a/python/sglang/multimodal_gen/test/run_suite.py +++ b/python/sglang/multimodal_gen/test/run_suite.py @@ -22,6 +22,7 @@ "1-gpu": [ "test_server_a.py", "test_server_b.py", + "test_lora_format_adapter.py", # add new 1-gpu test files here ], "2-gpu": [ diff --git a/python/sglang/multimodal_gen/test/server/test_lora_format_adapter.py b/python/sglang/multimodal_gen/test/server/test_lora_format_adapter.py new file mode 100644 index 000000000000..48bf29af3b6f --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/test_lora_format_adapter.py @@ -0,0 +1,324 @@ +""" +test_lora_format_adapter.py + +Small regression test for the LoRA format adapter. + +It downloads several public LoRA checkpoints from Hugging Face, runs +format detection and normalization, and prints a compact summary table. +""" + +import logging +import os +import tempfile +from typing import Dict, List + +import torch +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + +from sglang.multimodal_gen.runtime.pipelines_core.lora_format_adapter import ( + LoRAFormat, + detect_lora_format_from_state_dict, + normalize_lora_state_dict, +) + +logging.basicConfig(level=logging.INFO, force=True) +logger = logging.getLogger("lora_test") + +ROOT_DIR = os.path.join(tempfile.gettempdir(), "sglang_lora_tests") +os.makedirs(ROOT_DIR, exist_ok=True) + + +def download_lora( + repo_id: str, + filename: str, + local_name: str, +) -> str: + """ + Download a LoRA safetensors file into ROOT_DIR and return its local path. + """ + print(f"=== Downloading LoRA from {repo_id} ({filename}) ===") + path = hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=ROOT_DIR, + local_dir_use_symlinks=False, + ) + dst = os.path.join(ROOT_DIR, local_name) + if os.path.abspath(path) != os.path.abspath(dst): + try: + import shutil + + shutil.copy2(path, dst) + except Exception: + dst = path + print(f"Saved to: {dst}") + return dst + + +def is_diffusers_style_keys( + sd: Dict[str, torch.Tensor], + debug_name: str = "", +) -> bool: + """ + Relaxed structural check that a state_dict looks like diffusers-style LoRA. + + The check verifies: + 1) No known non-diffusers prefixes. + 2) No non-diffusers suffixes such as alpha / dora_scale / magnitude vectors. + 3) Most top-level roots match common diffusers module namespaces. 
+ """ + if not sd: + print(f"[{debug_name}] diffusers-style check: EMPTY state_dict") + return False + + keys: List[str] = list(sd.keys()) + total = len(keys) + + banned_prefixes = ( + "lora_unet_", + "lora_te_", + "lora_te1_", + "lora_te2_", + "lora_unet_double_blocks_", + "lora_unet_single_blocks_", + ) + bad_prefix_keys = [k for k in keys if k.startswith(banned_prefixes)] + cond1 = len(bad_prefix_keys) == 0 + + banned_suffixes = ( + ".alpha", + ".dora_scale", + ".lora_magnitude_vector", + ) + bad_suffix_keys = [k for k in keys if k.endswith(banned_suffixes)] + cond2 = len(bad_suffix_keys) == 0 + + allowed_roots = { + "unet", + "text_encoder", + "text_encoder_2", + "transformer", + "prior", + "image_encoder", + "vae", + "diffusion_model", + } + root_names = [k.split(".", 1)[0] for k in keys] + root_ok_count = sum(r in allowed_roots for r in root_names) + cond3 = root_ok_count >= 0.6 * total + + ok = cond1 and cond2 and cond3 + + if not ok: + print(f"[{debug_name}] diffusers-style check FAILED (relaxed):") + print(f" total keys = {total}") + print( + f" cond1(no banned prefixes) = {cond1}, bad_prefix_keys={len(bad_prefix_keys)}" + ) + if not cond1 and bad_prefix_keys: + print(" example bad prefix key:", bad_prefix_keys[0]) + print( + f" cond2(no banned suffixes) = {cond2}, bad_suffix_keys={len(bad_suffix_keys)}" + ) + if not cond2 and bad_suffix_keys: + print(" example bad suffix key:", bad_suffix_keys[0]) + print(f" cond3(allowed roots>=60%) = {cond3}, root_ok_count={root_ok_count}") + return ok + + +def run_single_test( + name: str, + repo_id: str, + filename: str, + local_name: str, + expected_before: LoRAFormat, + expected_after: LoRAFormat = LoRAFormat.STANDARD, +): + """ + Run a single end-to-end test for one LoRA checkpoint. + + Steps: + 1) Download. + 2) Detect format on raw keys. + 3) Normalize via lora_format_adapter. + 4) Detect again on the normalized dict. + 5) Optionally check for diffusers-style key structure. + """ + logger.info(f"=== Running test: {name} ===") + local_path = download_lora(repo_id, filename, local_name) + raw_state = load_file(local_path) + + detected_before = detect_lora_format_from_state_dict(raw_state) + norm_state = normalize_lora_state_dict(raw_state, logger=logger) + detected_after = detect_lora_format_from_state_dict(norm_state) + standard_like = is_diffusers_style_keys(norm_state, debug_name=name) + + passed = detected_before == expected_before and detected_after == expected_after + + return { + "name": name, + "expected_before": expected_before.value, + "detected_before": detected_before.value, + "expected_after": expected_after.value, + "detected_after": detected_after.value, + "standard_like_keys": standard_like, + "pass": passed, + "num_keys_raw": len(raw_state), + "num_keys_norm": len(norm_state), + } + + +def _run_all_tests() -> List[Dict]: + results: List[Dict] = [] + + # SDXL LoRA that is already in diffusers/PEFT format. + results.append( + run_single_test( + name="HF standard SDXL LoRA", + repo_id="jbilcke-hf/sdxl-cinematic-1", + filename="pytorch_lora_weights.safetensors", + local_name="sdxl_cinematic1_pytorch_lora_weights.safetensors", + expected_before=LoRAFormat.STANDARD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # XLabs FLUX LoRA (non-diffusers → diffusers). 
+ results.append( + run_single_test( + name="XLabs FLUX Realism LoRA", + repo_id="XLabs-AI/flux-RealismLora", + filename="lora.safetensors", + local_name="flux_realism_lora.safetensors", + expected_before=LoRAFormat.XLABS_FLUX, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Kohya-style FLUX LoRA (sd-scripts flux_lora.py → diffusers). + results.append( + run_single_test( + name="Kohya-style Flux LoRA", + repo_id="kohya-ss/misc-models", + filename="flux-hasui-lora-d4-sigmoid-raw-gs1.0.safetensors", + local_name="flux_hasui_lora_d4_sigmoid_raw_gs1_0.safetensors", + expected_before=LoRAFormat.KOHYA_FLUX, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Classic Kohya/A1111 SD LoRA (non-diffusers SD → diffusers). + results.append( + run_single_test( + name="Kohya-style SD LoRA", + repo_id="kohya-ss/misc-models", + filename="fp-1f-chibi-1024.safetensors", + local_name="fp_1f_chibi_1024.safetensors", + expected_before=LoRAFormat.NON_DIFFUSERS_SD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Wan2.1 Fun Reward LoRA (ComfyUI format → diffusers). + results.append( + run_single_test( + name="Wan2.1 Fun Reward LoRA (Comfy)", + repo_id="alibaba-pai/Wan2.1-Fun-Reward-LoRAs", + filename="Wan2.1-Fun-1.3B-InP-MPS.safetensors", + local_name="wan21_fun_1_3b_inp_mps.safetensors", + expected_before=LoRAFormat.NON_DIFFUSERS_SD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Qwen-Image EVA LoRA (already diffusers/PEFT-style). + results.append( + run_single_test( + name="Qwen-Image EVA LoRA", + repo_id="starsfriday/Qwen-Image-EVA-LoRA", + filename="qwen_image_eva.safetensors", + local_name="qwen_image_eva.safetensors", + expected_before=LoRAFormat.STANDARD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Qwen-Image Lightning LoRA (non-diffusers Qwen → diffusers). + results.append( + run_single_test( + name="Qwen-Image Lightning LoRA", + repo_id="lightx2v/Qwen-Image-Lightning", + filename="Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors", + local_name="qwen_image_lightning_4steps_v1_bf16.safetensors", + expected_before=LoRAFormat.NON_DIFFUSERS_SD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Classic Painting Z-Image Turbo LoRA (Z-Image family). 
+ results.append( + run_single_test( + name="Classic Painting Z-Image LoRA", + repo_id="renderartist/Classic-Painting-Z-Image-Turbo-LoRA", + filename="Classic_Painting_Z_Image_Turbo_v1_renderartist_1750.safetensors", + local_name="classic_painting_z_image_turbo_v1_renderartist_1750.safetensors", + expected_before=LoRAFormat.STANDARD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + return results + + +def _print_summary(results: List[Dict]) -> None: + print("\n================ LoRA format adapter test ================") + + header = ( + f"{'Test Name':30} " + f"{'Exp(b)':12} " + f"{'Act(b)':12} " + f"{'Exp(a)':12} " + f"{'Act(a)':12} " + f"{'StdLike':8} " + f"{'#Raw':7} " + f"{'#Norm':7} " + f"{'PASS':5}" + ) + print(header) + print("-" * len(header)) + + for r in results: + print( + f"{r['name'][:30]:30} " + f"{r['expected_before'][:12]:12} " + f"{r['detected_before'][:12]:12} " + f"{r['expected_after'][:12]:12} " + f"{r['detected_after'][:12]:12} " + f"{str(r['standard_like_keys']):8} " + f"{r['num_keys_raw']:7d} " + f"{r['num_keys_norm']:7d} " + f"{str(r['pass']):5}" + ) + + print("=========================================================\n") + + +def main() -> None: + results = _run_all_tests() + _print_summary(results) + + if not all(r["pass"] for r in results): + raise SystemExit(1) + + +class TestLoRAFormatAdapter: + def test_lora_format_adapter_all_formats(self): + results = _run_all_tests() + assert all( + r["pass"] for r in results + ), "At least one LoRA format adapter case failed" + + +if __name__ == "__main__": + main()
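A minimal usage sketch of the new adapter, mirroring the call pattern added to load_lora_adapter above. The checkpoint path is an illustrative placeholder and the logger name is just the module's default; only detect_lora_format_from_state_dict and normalize_lora_state_dict come from this change.

import logging

from safetensors.torch import load_file

from sglang.multimodal_gen.runtime.pipelines_core.lora_format_adapter import (
    detect_lora_format_from_state_dict,
    normalize_lora_state_dict,
)

logger = logging.getLogger("LoRAFormatAdapter")

# Load the raw checkpoint the same way load_lora_adapter does.
raw_state_dict = load_file("/path/to/lora.safetensors")  # illustrative path

# Optional: classify the checkpoint from its key patterns alone
# (normalize_lora_state_dict also does this internally).
fmt = detect_lora_format_from_state_dict(raw_state_dict)
logger.info("detected LoRA format: %s", fmt)

# Normalize into the canonical lora_A/lora_B ".weight" layout consumed by
# the rest of the LoRA pipeline.
lora_state_dict = normalize_lora_state_dict(raw_state_dict, logger=logger)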