-
Notifications
You must be signed in to change notification settings - Fork 17
/
inference.py
59 lines (48 loc) · 1.67 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from models.pipeline import VchitectXLPipeline
import random
import numpy as np
import os
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects hash randomisation in interpreters started after this
    # point (e.g. subprocesses); the running interpreter has already fixed it.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one, so multi-GPU runs are covered.
    # Silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
def infer(args):
    """Render every prompt in ``args.test_file`` to MP4 files.

    Each non-blank line of the prompt file is rendered once per seed in
    ``range(5)``. Results are written to
    ``<args.save_dir>/sample_<idx>_seed<seed>.mp4``, where ``idx`` is a
    running counter that advances on every generated video (not once per
    prompt).

    Args:
        args: Object exposing the attributes ``ckpt_path``, ``test_file``
            and ``save_dir``.
    """
    # Project-local helper; still imported lazily, but once per call rather
    # than once per generated sample.
    from utils import save_as_mp4

    # Read the prompts first so a bad prompt path fails before the
    # checkpoint is loaded.
    prompts = []
    with open(args.test_file, 'r', encoding='utf-8') as f:
        for line in f:
            prompt = line.strip('\n')
            # Skip blank lines (e.g. a trailing newline at end of file)
            # instead of generating videos from an empty prompt.
            if prompt.strip():
                prompts.append(prompt)

    pipe = VchitectXLPipeline(args.ckpt_path)

    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): presumably milliseconds per frame (i.e. 8 fps) -- confirm
    # against utils.save_as_mp4.
    duration = 1000 / 8

    idx = 0
    for prompt in prompts:
        for seed in range(5):
            # Re-seed before every sample so each (prompt, seed) pair is
            # generated from the same RNG state regardless of ordering.
            set_seed(seed)
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                video = pipe(
                    prompt,
                    negative_prompt="",
                    num_inference_steps=100,
                    guidance_scale=7.5,
                    width=768,
                    height=432,  # resolutions noted by the author: 480x288 624x352 432x240 768x432
                    frames=40,
                )

            idx += 1
            out_path = os.path.join(save_dir, f"sample_{idx}_seed{seed}") + '.mp4'
            save_as_mp4(video, out_path, duration=duration)
import sys,os
import argparse
def main(argv=None):
    """Parse command-line options and run inference.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default), arguments are taken from ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # All three options are required: without them ``None`` would be passed
    # on to ``infer`` and fail later with a less helpful error.
    parser.add_argument("--test_file", type=str, required=True,
                        help="text file with one prompt per line")
    parser.add_argument("--save_dir", type=str, required=True,
                        help="directory the generated .mp4 files are written to")
    parser.add_argument("--ckpt_path", type=str, required=True,
                        help="checkpoint path passed to VchitectXLPipeline")
    # parse_known_args is kept so unrecognised arguments are ignored rather
    # than rejected, as in the original.
    args = parser.parse_known_args(argv)[0]
    infer(args)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()