Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions models/kandinsky5/kandinsky/models/parallelize.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
PrepareModuleOutput,
RowwiseParallel,
SequenceParallel,
parallelize_module,
)
try:
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
PrepareModuleOutput,
RowwiseParallel,
SequenceParallel,
parallelize_module,
)
except (ImportError, ModuleNotFoundError):
Replicate = Shard = None
ColwiseParallel = PrepareModuleInput = PrepareModuleOutput = None
RowwiseParallel = SequenceParallel = parallelize_module = None


def parallelize_dit(model, tp_mesh):
Expand Down
10 changes: 8 additions & 2 deletions models/wan/any2video.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
try:
import torch.distributed as dist
except ImportError:
dist = None
import numpy as np
from tqdm import tqdm
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
try:
from .distributed.fsdp import shard_model
except (ImportError, ModuleNotFoundError):
shard_model = None
from .modules.model import WanModel
from mmgp.offload import get_cache, clear_caches
from .modules.t5 import T5EncoderModel
Expand Down
5 changes: 4 additions & 1 deletion models/wan/multitalk/multitalk.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import random
import os
import torch
import torch.distributed as dist
try:
import torch.distributed as dist
except ImportError:
dist = None
from PIL import Image
import subprocess
import torchvision.transforms as transforms
Expand Down