Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
DEQDON committed Jan 29, 2024
1 parent d9b03aa commit fe4afa9
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 34 deletions.
4 changes: 2 additions & 2 deletions run_batch_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import shutil
import csv

from vivid123 import generation_vivid123, prepare_pipelines
from vivid123 import generation_vivid123, prepare_vivid123_pipeline


ZERO123_MODEL_ID = "bennyguo/zero123-xl-diffusers"
Expand All @@ -23,7 +23,7 @@
parser.add_argument('--run_to_obj_index', type=int, default=999, help='The index of object to end with')
args = parser.parse_args()

vivid123_pipe, xl_pipe = prepare_pipelines(
vivid123_pipe, xl_pipe = prepare_vivid123_pipeline(
ZERO123_MODEL_ID=ZERO123_MODEL_ID,
VIDEO_MODEL_ID=VIDEO_MODEL_ID,
VIDEO_XL_MODEL_ID=VIDEO_XL_MODEL_ID
Expand Down
48 changes: 31 additions & 17 deletions scripts/job_config_yaml_generation.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
import os
import shutil
import yaml
import csv
import argparse
from vivid123.configs import ViVid123BaseSchema


parser = argparse.ArgumentParser(description='ViVid123 Generation')
parser.add_argument('--task_yamls_output_dir', type=str, default="tasks_gso", help='The directory for all configs')
parser.add_argument('--run_on_slurm', action='store_true', help="whether to run on a slurm cluster")
args = parser.parse_args()
SLURM_TMPDIR = os.getenv("SLURM_TMPDIR") if os.getenv("SLURM_TMPDIR") else "/home/erqun/vivid123/tmp"

my_model = ViVid123BaseSchema()
job_specs = [
# {"num_frames": 24, "delta_azimuth_start": 15, "delta_azimuth_end": 360, "exp_name": "num_frames_24"},
{} # default job specified by default schema in vivid123/configs/base_schema.py
]

os.makedirs(args.task_yaml_output_dir, exist_ok=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='ViVid123 Generation')
parser.add_argument('--run_on_slurm', action='store_true', help="whether to run on a slurm cluster")
args = parser.parse_args()

with open("scripts/gso_metadata_object_prompt_100.csv", 'r') as f_metadata:
csv_lines = csv.reader(f_metadata, delimiter=',', quotechar='"')
for i, csv_line in enumerate(csv_lines):
obj_name = csv_line[0]
my_model.name = obj_name
if args.run_on_slurm:
my_model.input_image_path = r"${SLURM_TMPDIR}/" + f"{obj_name}/img/012.png"
else:
my_model.input_image_path = f"./tmp/{obj_name}/img/012.png"
with open(f"{args.task_yaml_output_dir}/{obj_name}.yaml", "w") as f_job:
yaml.dump(my_model.model_dump(), f_job)
for job_spec in job_specs:
with open("scripts/gso_metadata_object_prompt_100.csv", 'r') as f_metadata:
csv_lines = csv.reader(f_metadata, delimiter=',', quotechar='"')
my_model = ViVid123BaseSchema()
for fieldname, value in job_spec.items():
if hasattr(my_model, fieldname):
setattr(my_model, fieldname, value)
else:
raise ValueError(f"No field {fieldname}")

task_yamls_output_dir = f"exps/task_yamls/{my_model.exp_name}"
os.makedirs(task_yamls_output_dir, exist_ok=True)
for i, csv_line in enumerate(csv_lines):
my_model.obj_name = csv_line[0]
if args.run_on_slurm:
my_model.input_image_path = r"${SLURM_TMPDIR}/" + f"{my_model.obj_name}/img/012.png"
else:
my_model.input_image_path = f"./tmp/{my_model.obj_name}/img/012.png"
with open(os.path.join(task_yamls_output_dir, f"{my_model.obj_name}.yaml"), "w") as f_job:
print(f"dumping yaml to ", os.path.join(task_yamls_output_dir, f"{my_model.obj_name}.yaml"))
yaml.dump(my_model.model_dump(), f_job)
5 changes: 3 additions & 2 deletions vivid123/configs/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Config:
num_inference_steps: int = 50
guidance_scale_zero123: float = 3.0
guidance_scale_video: float = 1.0
eta: float = 1.0
eta: float = 0.0 # 0.0 for purely deterministic, 1.0 for purely stochastic
noise_identical_accross_frames: bool = False
prompt: str = ""

Expand All @@ -35,5 +35,6 @@ class Config:
refiner_strength: float = 0.3
refiner_guidance_scale: float = 12.0

name: str = "new_balance_used"
obj_name: str = "new_balance_used"
input_image_path: str = "tmp/new_balance_used/012.png"
exp_name: str = "test_exp"
25 changes: 12 additions & 13 deletions vivid123/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,15 @@ def replace_fn(match):
def load_yaml(filename: str) -> dict:
yaml.add_implicit_resolver("!envvar", _tag_matcher, None, yaml.SafeLoader)
yaml.add_constructor("!envvar", _path_constructor, yaml.SafeLoader)
try:
with open(filename, "r") as f:
return yaml.safe_load(f.read())
except (FileNotFoundError, PermissionError, ParserError):
return dict()
with open(filename, "r") as f:
return yaml.safe_load(f.read())

yaml_loaded = load_yaml(config_path)
print(f"input_image_path is: ", yaml_loaded["input_image_path"])
cfg = ViVid123BaseSchema.model_validate(yaml_loaded)

# get reference image
print(f"input_image_path is: {cfg.input_image_path}")
input_image = Image.open(cfg.input_image_path)
input_image = conver_rgba_to_rgb_white_bg(input_image, H=cfg.height, W=cfg.width)

Expand Down Expand Up @@ -286,13 +285,13 @@ def load_yaml(filename: str) -> dict:
).frames

# save imgs
os.makedirs(os.path.join(output_root_dir, cfg.name), exist_ok=True)
input_image.save(f"{output_root_dir}/{cfg.name}/input.png")
os.makedirs(os.path.join(output_root_dir, cfg.name, "base_frames"), exist_ok=True)
os.makedirs(os.path.join(output_root_dir, cfg.obj_name), exist_ok=True)
input_image.save(f"{output_root_dir}/{cfg.obj_name}/input.png")
os.makedirs(os.path.join(output_root_dir, cfg.obj_name, "base_frames"), exist_ok=True)
for i in range(len(vid_base_frames)):
Image.fromarray(vid_base_frames[i]).save(f"{output_root_dir}/{cfg.name}/base_frames/{i}.png")
Image.fromarray(vid_base_frames[i]).save(f"{output_root_dir}/{cfg.obj_name}/base_frames/{i}.png")

save_videos_grid_zeroscope_nplist(vid_base_frames, f"{output_root_dir}/{cfg.name}/base.mp4")
save_videos_grid_zeroscope_nplist(vid_base_frames, f"{output_root_dir}/{cfg.obj_name}/base.mp4")

if cfg.skip_refiner:
return
Expand All @@ -303,10 +302,10 @@ def load_yaml(filename: str) -> dict:
prompt=cfg.prompt, video=video_xl_input, strength=cfg.refiner_strength, guidance_scale=cfg.refiner_guidance_scale
).frames

os.makedirs(os.path.join(output_root_dir, cfg.name, "xl_frames"), exist_ok=True)
os.makedirs(os.path.join(output_root_dir, cfg.obj_name, "xl_frames"), exist_ok=True)
for i in range(len(vid_base_frames)):
Image.fromarray(vid_base_frames[i]).save(f"{output_root_dir}/{cfg.name}/xl_frames/{i}.png")
save_videos_grid_zeroscope_nplist(video_xl_frames, f"{output_root_dir}/{cfg.name}/xl.mp4")
Image.fromarray(vid_base_frames[i]).save(f"{output_root_dir}/{cfg.obj_name}/xl_frames/{i}.png")
save_videos_grid_zeroscope_nplist(video_xl_frames, f"{output_root_dir}/{cfg.obj_name}/xl.mp4")


def prepare_zero123_pipeline(
Expand Down

0 comments on commit fe4afa9

Please sign in to comment.