Longer videos and textual inversions and fp16 autocast #25

Open · wants to merge 4 commits into base: main

37 changes: 37 additions & 0 deletions README.md
@@ -162,6 +162,43 @@ Here we demonstrate several best results we found in our experiments.
</table>
<p style="margin-left: 2em; margin-top: -1em">Model:<a href="https://civitai.com/models/33208/filmgirl-film-grain-lora-and-loha">FilmVelvia</a></p>

### Longer generations
You can also generate longer animations by using overlapping sliding windows.
```
python -m scripts.animate --config configs/prompts/{your_config}.yaml --L 64 --context_length 16
```
##### Sliding-window parameters

```L``` - the total length of the generated animation, in frames.

```context_length``` - the length of the sliding window (limited by the motion module's capacity); defaults to ```L```.

```context_overlap``` - how many frames neighbouring windows overlap; defaults to ```context_length``` / 2.

```context_stride``` - the maximum stride between two neighbouring frames is 2^```context_stride```; defaults to 0.
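
To make these parameters concrete, here is a minimal, illustrative Python sketch of how overlapping windows could cover ```L``` frames. It is not the code the repository runs; the real scheduling (including ```context_stride``` and the per-step offset) lives in `animatediff/utils/overlap_policy.py`.

```
# Illustrative sketch only; the real scheduler also applies context_stride
# and a per-step offset (see animatediff/utils/overlap_policy.py).
def simple_windows(L=64, context_length=16, context_overlap=8):
    step = context_length - context_overlap
    start = 0
    while start < L:
        yield list(range(start, min(start + context_length, L)))
        start += step

for window in simple_windows():
    print(window[0], "..", window[-1])  # 0..15, 8..23, 16..31, ..., 56..63
```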

##### Gallery examples extended this way

<table class="center">
<tr>
<td><img src="__assets__/animations/model_01_4x/01.gif"></td>
<td><img src="__assets__/animations/model_01_4x/02.gif"></td>
<td><img src="__assets__/animations/model_01_4x/03.gif"></td>
<td><img src="__assets__/animations/model_01_4x/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model:<a href="https://civitai.com/models/30240/toonyou">ToonYou</a></p>

<table>
<tr>
<td><img src="__assets__/animations/model_03_4x/01.gif"></td>
<td><img src="__assets__/animations/model_03_4x/02.gif"></td>
<td><img src="__assets__/animations/model_03_4x/03.gif"></td>
<td><img src="__assets__/animations/model_03_4x/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model:<a href="https://civitai.com/models/4201/realistic-vision-v20">Realistic Vision V2.0</a></p>

#### Community Cases
Here are some samples contributed by community artists. Create a Pull Request if you would like to show your results here 😚.

Binary file added __assets__/animations/model_01_4x/01.gif
Binary file added __assets__/animations/model_01_4x/02.gif
Binary file added __assets__/animations/model_01_4x/03.gif
Binary file added __assets__/animations/model_01_4x/04.gif
Binary file added __assets__/animations/model_03_4x/01.gif
Binary file added __assets__/animations/model_03_4x/02.gif
Binary file added __assets__/animations/model_03_4x/03.gif
Binary file added __assets__/animations/model_03_4x/04.gif
116 changes: 94 additions & 22 deletions animatediff/pipelines/pipeline_animation.py
@@ -1,11 +1,13 @@
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py

import inspect
import os
from typing import Callable, List, Optional, Union
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

from diffusers.utils import is_accelerate_available
@@ -29,6 +31,8 @@

from ..models.unet import UNet3DConditionModel

from ..utils import overlap_policy
from ..utils.path import get_absolute_path

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -55,6 +59,7 @@ def __init__(
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scan_inversions: bool = True,
):
super().__init__()

@@ -114,6 +119,36 @@ def __init__(
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.embeddings_dir = get_absolute_path('models', 'embeddings')
self.embeddings_dict = {}
self.default_tokens = len(self.tokenizer)
self.scan_inversions = scan_inversions

def update_embeddings(self):
if not self.scan_inversions:
return
names = [p for p in os.listdir(self.embeddings_dir) if p.endswith('.pt')]
weight = self.text_encoder.text_model.embeddings.token_embedding.weight
added_embeddings = []
for name in names:
embedding_path = os.path.join(self.embeddings_dir, name)
embedding = torch.load(embedding_path)
key = os.path.splitext(name)[0]
if key in self.tokenizer.encoder:
idx = self.tokenizer.encoder[key]
else:
idx = len(self.tokenizer)
self.tokenizer.add_tokens([key])
embedding = embedding['string_to_param']['*']
if idx not in self.embeddings_dict:
added_embeddings.append(name)
self.embeddings_dict[idx] = torch.arange(weight.shape[0], weight.shape[0] + embedding.shape[0])
weight = torch.cat([weight, embedding.to(weight.device, weight.dtype)], dim=0)
self.tokenizer.add_tokens([key])
if added_embeddings:
self.text_encoder.text_model.embeddings.token_embedding = nn.Embedding(
weight.shape[0], weight.shape[1], _weight=weight)
logger.info(f'Added {len(added_embeddings)} embeddings: {added_embeddings}')

def enable_vae_slicing(self):
self.vae.enable_slicing()
@@ -147,9 +182,32 @@ def _execution_device(self):
return torch.device(module._hf_hook.execution_device)
return self.device

def insert_inversions(self, ids, attention_mask):
larger = ids >= self.default_tokens
for idx in reversed(torch.where(larger)[1]):
ids = torch.cat([
ids[:, :idx],
self.embeddings_dict[ids[:, idx].item()].unsqueeze(0),
ids[:, idx + 1:],
], 1)
if attention_mask is not None:
attention_mask = torch.cat([
attention_mask[:, :idx],
torch.ones(1, 1, dtype=attention_mask.dtype, device=attention_mask.device),
attention_mask[:, idx + 1:],
], 1)
if ids.shape[1] > self.tokenizer.model_max_length:
logger.warning(f"After inserting inversions, the sequence length is larger than the max length. Cutting off"
f" {ids.shape[1] - self.tokenizer.model_max_length} tokens.")
ids = torch.cat([ids[:, :self.tokenizer.model_max_length - 1], ids[:, -1:]], 1)
if attention_mask is not None:
attention_mask = attention_mask[:, :self.tokenizer.model_max_length]
return ids, attention_mask

def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
batch_size = len(prompt) if isinstance(prompt, list) else 1

self.update_embeddings()
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -172,6 +230,7 @@ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_fr
else:
attention_mask = None

text_input_ids, attention_mask = self.insert_inversions(text_input_ids, attention_mask)
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
@@ -218,8 +277,10 @@ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_fr
else:
attention_mask = None

uncond_input_ids = uncond_input.input_ids
uncond_input_ids, attention_mask = self.insert_inversions(uncond_input_ids, attention_mask)
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
uncond_input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
@@ -242,8 +303,9 @@ def decode_latents(self, latents):
latents = rearrange(latents, "b c f h w -> (b f) c h w")
# video = self.vae.decode(latents).sample
video = []
device = self._execution_device
for frame_idx in tqdm(range(latents.shape[0])):
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
video.append(self.vae.decode(latents[frame_idx:frame_idx+1].to(device)).sample)
video = torch.cat(video)
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
video = (video / 2 + 0.5).clamp(0, 1)
@@ -317,6 +379,9 @@ def __call__(
self,
prompt: Union[str, List[str]],
video_length: Optional[int],
temporal_context: Optional[int] = None,
strides: int = 3,
overlap: int = 4,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
@@ -330,6 +395,8 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
seq_policy=overlap_policy.uniform,
fp16=False,
**kwargs,
):
# Default height and width to unet
@@ -348,6 +415,7 @@ def __call__(
batch_size = len(prompt)

device = self._execution_device
cpu = torch.device('cpu')
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -356,7 +424,7 @@ def __call__(
# Encode input prompt
prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
if negative_prompt is not None:
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
text_embeddings = self._encode_prompt(
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
)
@@ -373,45 +441,49 @@ def __call__(
video_length,
height,
width,
text_embeddings.dtype,
device,
torch.float32,
cpu, # using cpu to store latents allows generated frame amount not to be limited by vram but by ram
generator,
latents,
)
latents_dtype = latents.dtype

# Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

total = sum(
len(list(seq_policy(i, num_inference_steps, latents.shape[2], temporal_context, strides, overlap)))
for i in range(len(timesteps))
)
# Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=total) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
# noise_pred = []
# import pdb
# pdb.set_trace()
# for batch_idx in range(latent_model_input.shape[0]):
# noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
# noise_pred.append(noise_pred_single)
# noise_pred = torch.cat(noise_pred)
noise_pred = torch.zeros((latents.shape[0] * (2 if do_classifier_free_guidance else 1),
*latents.shape[1:]), device=latents.device, dtype=latents_dtype)
counter = torch.zeros((1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents_dtype)
for seq in seq_policy(i, num_inference_steps, latents.shape[2], temporal_context, strides, overlap):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents[:, :, seq].to(device)\
.repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
with torch.autocast('cuda', enabled=fp16, dtype=torch.float16):
pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
noise_pred[:, :, seq] += pred.sample.to(dtype=latents_dtype, device=cpu)
counter[:, :, seq] += 1
progress_bar.update()

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

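
For orientation, a hypothetical call of the extended pipeline might look like the sketch below. The argument names (`temporal_context`, `strides`, `overlap`, `fp16`) come from the new `__call__` signature above; `pipeline` is assumed to be an `AnimationPipeline` constructed as in `scripts/animate.py`, and the prompt text and concrete values are made up for illustration.

```
# Hypothetical usage sketch; the prompt and values are illustrative only.
sample = pipeline(
    prompt="a girl walking on the beach, film grain",
    negative_prompt="low quality, blurry",
    video_length=64,       # --L: total number of frames
    temporal_context=16,   # --context_length: sliding-window length
    strides=1,             # --context_stride + 1
    overlap=8,             # --context_overlap
    width=512,
    height=512,
    num_inference_steps=25,
    fp16=True,             # autocast the UNet to float16 (i.e. not passing --fp32)
).videos
```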
20 changes: 20 additions & 0 deletions animatediff/utils/overlap_policy.py
@@ -0,0 +1,20 @@
import numpy as np


def ordered_halving(i):
return int('{:064b}'.format(i)[::-1], 2) / (1 << 64)


def uniform(step, steps, n, context_size, strides, overlap, closed_loop=True):
if n <= context_size:
yield list(range(n))
return
strides = min(strides, int(np.ceil(np.log2(n / context_size))) + 1)
for stride in 1 << np.arange(strides):
pad = int(round(n * ordered_halving(step)))
for j in range(
int(ordered_halving(step) * stride) + pad,
n + pad + (0 if closed_loop else -overlap),
(context_size * stride - overlap)
):
yield [e % n for e in range(j, j + context_size * stride, stride)]
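
As a small usage sketch (assuming the module is importable as `animatediff.utils.overlap_policy`, matching the import added in the pipeline above), the following prints the frame windows visited at denoising step 0 of a 64-frame generation with a 16-frame context:

```
from animatediff.utils import overlap_policy

# Each yielded seq is a list of frame indices; indices wrap modulo n because
# the policy treats the clip as a closed loop by default.
for seq in overlap_policy.uniform(step=0, steps=25, n=64, context_size=16,
                                  strides=3, overlap=4):
    print(seq)
```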
16 changes: 16 additions & 0 deletions animatediff/utils/path.py
@@ -0,0 +1,16 @@
import os

project_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def get_absolute_path(*relative):
if relative[0].startswith('/'):
return os.path.join(*relative) # absolute path
return os.path.join(project_path, *relative)


if __name__ == '__main__':
print(get_absolute_path('test'))
print(get_absolute_path('/test'))
print(get_absolute_path('test', 'test'))
print(get_absolute_path('/test', 'test'))
Empty file.
25 changes: 23 additions & 2 deletions scripts/animate.py
@@ -30,7 +30,12 @@
def main(args):
*_, func_args = inspect.getargvalues(inspect.currentframe())
func_args = dict(func_args)


if args.context_length == 0:
args.context_length = args.L
if args.context_overlap == -1:
args.context_overlap = args.context_length // 2

time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
savedir = f"samples/{Path(args.config).stem}-{time_str}"
os.makedirs(savedir)
@@ -58,6 +63,7 @@ def main(args):
pipeline = AnimationPipeline(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
scan_inversions=not args.disable_inversions,
).to("cuda")

# 1. unet ckpt
@@ -130,6 +136,10 @@ def main(args):
width = args.W,
height = args.H,
video_length = args.L,
temporal_context = args.context_length,
strides = args.context_stride + 1,
overlap = args.context_overlap,
fp16 = not args.fp32,
).videos
samples.append(sample)

@@ -150,7 +160,18 @@ def main(args):
parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
parser.add_argument("--inference_config", type=str, default="configs/inference/inference.yaml")
parser.add_argument("--config", type=str, required=True)


parser.add_argument("--fp32", action="store_true")
parser.add_argument("--disable_inversions", action="store_true",
help="do not scan for downloaded textual inversions")

parser.add_argument("--context_length", type=int, default=0,
help="temporal transformer context length (0 for same as -L)")
parser.add_argument("--context_stride", type=int, default=0,
help="max stride of motion is 2^context_stride")
parser.add_argument("--context_overlap", type=int, default=-1,
help="overlap between chunks of context (-1 for half of context length)")

parser.add_argument("--L", type=int, default=16 )
parser.add_argument("--W", type=int, default=512)
parser.add_argument("--H", type=int, default=512)
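
The `--disable_inversions` flag above turns off the embedding scan that `AnimationPipeline.update_embeddings` performs. As a rough, simplified sketch of that scan (see the pipeline diff for the real logic; the directory layout is the one used there, and file names are just examples): every `*.pt` file under `models/embeddings/` is loaded, its vectors are appended to the text encoder's token embedding table, and the file stem becomes usable as a token in prompts.

```
# Simplified sketch of the textual-inversion scan; not the exact pipeline code.
import os
import torch

embeddings_dir = "models/embeddings"
for name in os.listdir(embeddings_dir):
    if not name.endswith(".pt"):
        continue
    data = torch.load(os.path.join(embeddings_dir, name))
    vectors = data["string_to_param"]["*"]  # shape: (num_vectors, hidden_dim)
    token = os.path.splitext(name)[0]       # file stem becomes the prompt token
    print(f"{token}: {vectors.shape[0]} vector(s) appended to the embedding table")
```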