diff --git a/models/kandinsky5/kandinsky/models/parallelize.py b/models/kandinsky5/kandinsky/models/parallelize.py index 9dc469b1a..ffe5e185d 100644 --- a/models/kandinsky5/kandinsky/models/parallelize.py +++ b/models/kandinsky5/kandinsky/models/parallelize.py @@ -1,12 +1,17 @@ -from torch.distributed._tensor import Replicate, Shard -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - PrepareModuleInput, - PrepareModuleOutput, - RowwiseParallel, - SequenceParallel, - parallelize_module, -) +try: + from torch.distributed._tensor import Replicate, Shard + from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + PrepareModuleOutput, + RowwiseParallel, + SequenceParallel, + parallelize_module, + ) +except (ImportError, ModuleNotFoundError): + Replicate = Shard = None + ColwiseParallel = PrepareModuleInput = PrepareModuleOutput = None + RowwiseParallel = SequenceParallel = parallelize_module = None def parallelize_dit(model, tp_mesh): diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 65e11d05b..c1af38913 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -14,13 +14,19 @@ import torch import torch.nn as nn import torch.cuda.amp as amp -import torch.distributed as dist +try: + import torch.distributed as dist +except ImportError: + dist = None import numpy as np from tqdm import tqdm from PIL import Image import torchvision.transforms.functional as TF import torch.nn.functional as F -from .distributed.fsdp import shard_model +try: + from .distributed.fsdp import shard_model +except (ImportError, ModuleNotFoundError): + shard_model = None from .modules.model import WanModel from mmgp.offload import get_cache, clear_caches from .modules.t5 import T5EncoderModel diff --git a/models/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py index f0a65ce52..73a873b78 100644 --- a/models/wan/multitalk/multitalk.py +++ b/models/wan/multitalk/multitalk.py @@ -1,7 +1,10 @@ import random import os import torch -import torch.distributed as dist +try: + import torch.distributed as dist +except ImportError: + dist = None from PIL import Image import subprocess import torchvision.transforms as transforms