
Commit 9527553
Author: Shu Zhang
Commit message: update training code and data
Parent: 54c8353

4 files changed (+190, −70 lines)

Diff for: configs/generate.yaml (+1, −1)

@@ -84,7 +84,7 @@ data:
     validation:
       target: edit_dataset.EditDataset
       params:
-        path: data/clip-filtered-dataset
+        path: ./data/training/instructpix2pix
         cache_dir: data/
         cache_name: data_10k
         split: val
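
Note on how these YAML blocks are consumed: each target/params pair is instantiated by importing the dotted path in target and calling it with params, in the style of ldm.util.instantiate_from_config from the stable-diffusion codebase this repo derives from. A minimal sketch of that convention (the helper below is an illustration written for this note, not code from the commit):

    from importlib import import_module

    def instantiate_from_config(config):
        # Import the dotted path in `target` and construct it with `params`;
        # mirrors the ldm.util helper these configs are written against.
        module, cls = config["target"].rsplit(".", 1)
        return getattr(import_module(module), cls)(**config.get("params", {}))

    # Hypothetical usage, building the validation dataset straight from the config:
    # cfg = OmegaConf.load("configs/generate.yaml")
    # val_dataset = instantiate_from_config(cfg["data"]["params"]["validation"])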

Diff for: configs/train_v21_base.yaml (new file, +125)

@@ -0,0 +1,125 @@
+# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
+# See more details in LICENSE.
+
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm_edit_v21.LatentDiffusion
+  params:
+    ckpt_path: ./checkpoints/v2-1_512-ema-pruned.ckpt
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: edited
+    cond_stage_key: edit
+    image_size: 32
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: hybrid
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: true
+    load_ema: false
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 0 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel_v21.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 32
+    num_workers: 2
+    train:
+      target: edit_dataset.EditDataset
+      params:
+        path_instructpix2pix: ./data/training/instructpix2pix
+        path_hive_0: ./data/training
+        path_hive_1: ./data/training/part_0_blip_prompt_new
+        path_hive_2: ./data/training/part_1_blip_prompt_new
+        split: train
+        min_resize_res: 256
+        max_resize_res: 256
+        crop_res: 256
+        flip_prob: 0.5
+    validation:
+      target: edit_dataset.EditDataset
+      params:
+        path_instructpix2pix: ./data/training/instructpix2pix
+        path_hive_0: ./data/training
+        path_hive_1: ./data/training/part_0_blip_prompt_new
+        path_hive_2: ./data/training/part_1_blip_prompt_new
+        split: val
+        min_resize_res: 256
+        max_resize_res: 256
+        crop_res: 256
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 2000
+        max_images: 2
+        increase_log_steps: False
+
+  trainer:
+    max_epochs: 3000
+    benchmark: True
+    accumulate_grad_batches: 4
+    check_val_every_n_epoch: 4
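
Two details of this config are worth spelling out. First, the effective batch size is batch_size 32 × accumulate_grad_batches 4 = 128 samples per optimizer step. Second, with conditioning_key: hybrid, the InstructPix2Pix-style UNet receives the noisy latent of the edited image concatenated channel-wise with the VAE latent of the source image, which is why in_channels is 8 while channels and out_channels stay 4. A minimal shape sketch of that concatenation (an illustration of the convention, not code from this commit):

    import torch

    # Shapes follow image_size: 32 and channels: 4 from the config above.
    z_noisy = torch.randn(1, 4, 32, 32)    # noisy latent of the edited image
    c_concat = torch.randn(1, 4, 32, 32)   # VAE latent of the source image
    unet_in = torch.cat([z_noisy, c_concat], dim=1)
    assert unet_in.shape == (1, 8, 32, 32)  # matches in_channels: 8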

Diff for: edit_dataset.py (+62, −67)

@@ -21,13 +21,16 @@
 from PIL import Image
 from torch.utils.data import Dataset
 import jsonlines
+from collections import deque


 class EditDataset(Dataset):
     def __init__(
         self,
-        path_official: str,
-        path_ours: str,
+        path_instructpix2pix: str,
+        path_hive_0: str,
+        path_hive_1: str,
+        path_hive_2: str,
         split: str = "train",
         splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
         min_resize_res: int = 256,

@@ -37,51 +40,91 @@ def __init__(
     ):
         assert split in ("train", "val", "test")
         assert sum(splits) == 1
-        self.path_official = path_official
-        self.path_ours = path_ours
+        self.path_instructpix2pix = path_instructpix2pix
+        self.path_hive_0 = path_hive_0
+        self.path_hive_1 = path_hive_1
+        self.path_hive_2 = path_hive_2
         self.min_resize_res = min_resize_res
         self.max_resize_res = max_resize_res
         self.crop_res = crop_res
         self.flip_prob = flip_prob
-        # load official dataset
-        with open(Path(self.path_official, "seeds.json")) as f:
-            self.seeds = json.load(f)
+        self.seeds = []
+        self.instructions = []
+        self.source_imgs = []
+        self.edited_imgs = []
+        # load instructpix2pix dataset
+        with open(Path(self.path_instructpix2pix, "seeds.json")) as f:
+            seeds = json.load(f)
         split_0, split_1 = {
             "train": (0.0, splits[0]),
             "val": (splits[0], splits[0] + splits[1]),
             "test": (splits[0] + splits[1], 1.0),
         }[split]

-        idx_0 = math.floor(split_0 * len(self.seeds))
-        idx_1 = math.floor(split_1 * len(self.seeds))
-        self.seeds = self.seeds[idx_0:idx_1]
+        idx_0 = math.floor(split_0 * len(seeds))
+        idx_1 = math.floor(split_1 * len(seeds))
+        seeds = seeds[idx_0:idx_1]
+
+        for seed in seeds:
+            seed = deque(seed)
+            seed.appendleft('')
+            seed.appendleft('instructpix2pix')
+            self.seeds.append(list(seed))
+
+
+        # load HIVE dataset first part

-        # load in-house dataset
-        self.instructions = []
-        self.source_imgs = []
-        self.edited_imgs = []
         cnt = 0
-        with jsonlines.open(Path(self.path_ours, "training_1M.jsonl")) as reader:
+        with jsonlines.open(Path(self.path_hive_0, "training_cycle.jsonl")) as reader:
             for ll in reader:
                 self.instructions.append(ll['instruction'])
                 self.source_imgs.append(ll['source_img'])
                 self.edited_imgs.append(ll['edited_img'])
-                self.seeds.append(['in_house', [cnt]])
+                self.seeds.append(['hive_0', '', '', [cnt]])
                 cnt += 1

+        # load HIVE dataset second part
+        with open(Path(self.path_hive_1, "seeds.json")) as f:
+            seeds = json.load(f)
+        for seed in seeds:
+            seed = deque(seed)
+            seed.appendleft('hive_1')
+            self.seeds.append(list(seed))
+        # load HIVE dataset third part
+        with open(Path(self.path_hive_2, "seeds.json")) as f:
+            seeds = json.load(f)
+        for seed in seeds:
+            seed = deque(seed)
+            seed.appendleft('hive_2')
+            self.seeds.append(list(seed))
+
     def __len__(self) -> int:
         return len(self.seeds)

     def __getitem__(self, i: int) -> dict[str, Any]:

-        name, seeds = self.seeds[i]
-        if name != 'in_house':
-            propt_dir = Path(self.path_official, name)
+        name_0, name_1, name_2, seeds = self.seeds[i]
+        if name_0 == 'instructpix2pix':
+            propt_dir = Path(self.path_instructpix2pix, name_2)
             seed = seeds[torch.randint(0, len(seeds), ()).item()]
             with open(propt_dir.joinpath("prompt.json")) as fp:
                 prompt = json.load(fp)["edit"]
             image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
             image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
+        elif name_0 == 'hive_1':
+            propt_dir = Path(self.path_hive_1, name_1, name_2)
+            seed = seeds[torch.randint(0, len(seeds), ()).item()]
+            with open(propt_dir.joinpath("prompt.json")) as fp:
+                prompt = json.load(fp)["instruction"]
+            image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
+            image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
+        elif name_0 == 'hive_2':
+            propt_dir = Path(self.path_hive_2, name_1, name_2)
+            seed = seeds[torch.randint(0, len(seeds), ()).item()]
+            with open(propt_dir.joinpath("prompt.json")) as fp:
+                prompt = json.load(fp)["instruction"]
+            image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
+            image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
         else:
             j = seeds[0]
             image_0 = Image.open(self.source_imgs[j])

@@ -101,51 +144,3 @@ def __getitem__(self, i: int) -> dict[str, Any]:

         return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))

-
-class EditDatasetEval(Dataset):
-    def __init__(
-        self,
-        path: str,
-        split: str = "train",
-        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
-        res: int = 256,
-    ):
-        assert split in ("train", "val", "test")
-        assert sum(splits) == 1
-        self.path = path
-        self.res = res
-
-        with open(Path(self.path, "seeds.json")) as f:
-            self.seeds = json.load(f)
-
-        split_0, split_1 = {
-            "train": (0.0, splits[0]),
-            "val": (splits[0], splits[0] + splits[1]),
-            "test": (splits[0] + splits[1], 1.0),
-        }[split]
-
-        idx_0 = math.floor(split_0 * len(self.seeds))
-        idx_1 = math.floor(split_1 * len(self.seeds))
-        self.seeds = self.seeds[idx_0:idx_1]
-
-    def __len__(self) -> int:
-        return len(self.seeds)
-
-    def __getitem__(self, i: int) -> dict[str, Any]:
-        name, seeds = self.seeds[i]
-        propt_dir = Path(self.path, name)
-        seed = seeds[torch.randint(0, len(seeds), ()).item()]
-        with open(propt_dir.joinpath("prompt.json")) as fp:
-            prompt = json.load(fp)
-        edit = prompt["edit"]
-        input_prompt = prompt["input"]
-        output_prompt = prompt["output"]
-
-        image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
-
-        reize_res = torch.randint(self.res, self.res + 1, ()).item()
-        image_0 = image_0.resize((reize_res, reize_res), Image.Resampling.LANCZOS)
-
-        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
-
-        return dict(image_0=image_0, input_prompt=input_prompt, edit=edit, output_prompt=output_prompt)
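
To trace the new indexing scheme: __getitem__ now unpacks every entry of self.seeds as (name_0, name_1, name_2, seeds), so the constructor left-pads the shorter entries from each source to four fields via the deque.appendleft calls. A hypothetical illustration of the normalized entries (the field layout follows the diff; directory names and numbers are placeholders):

    # instructpix2pix seeds.json entries are [prompt_dir, [seed, ...]] and get
    # '' plus the source tag prepended; HIVE part 1/2 entries are
    # [part_dir, prompt_dir, [seed, ...]] and only get the tag prepended.
    normalized = [
        ["instructpix2pix", "", "prompt_dir_a", [1234, 5678]],
        ["hive_0", "", "", [0]],  # index into the jsonl-backed image lists
        ["hive_1", "part_dir", "prompt_dir_b", [42]],
    ]
    name_0, name_1, name_2, seed_list = normalized[0]

collections.deque is used purely for its appendleft; list.insert(0, ...) would behave identically here.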

Diff for: main.py (+2, −2)

@@ -111,7 +111,7 @@ def str2bool(v):
         "-s",
         "--seed",
         type=int,
-        default=23,
+        default=100,
         help="seed for seed_everything",
     )
     parser.add_argument(

@@ -125,7 +125,7 @@ def str2bool(v):
         "-l",
         "--logdir",
         type=str,
-        default="/export/laion-aesthetics-v2/instruct_pix2pix/logs",
+        default="./logs",
         help="directory for logging dat shit",
     )
     parser.add_argument(
