diff --git a/finetrainers/patches/__init__.py b/finetrainers/patches/__init__.py index 1deb0729..26ce0ff4 100644 --- a/finetrainers/patches/__init__.py +++ b/finetrainers/patches/__init__.py @@ -17,7 +17,7 @@ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBa if parallel_backend.tensor_parallel_enabled: patch.patch_apply_rotary_emb_for_tp_compatibility() - if args.model_name == ModelType.WAN: + if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules: from .models.wan import patch patch.patch_time_text_image_embedding_forward()