21
21
from PIL import Image
22
22
from torch .utils .data import Dataset
23
23
import jsonlines
24
+ from collections import deque
24
25
25
26
26
27
class EditDataset (Dataset ):
27
28
def __init__ (
28
29
self ,
29
- path_official : str ,
30
- path_ours : str ,
30
+ path_instructpix2pix : str ,
31
+ path_hive_0 : str ,
32
+ path_hive_1 : str ,
33
+ path_hive_2 : str ,
31
34
split : str = "train" ,
32
35
splits : tuple [float , float , float ] = (0.9 , 0.05 , 0.05 ),
33
36
min_resize_res : int = 256 ,
@@ -37,51 +40,91 @@ def __init__(
37
40
):
38
41
assert split in ("train" , "val" , "test" )
39
42
assert sum (splits ) == 1
40
- self .path_official = path_official
41
- self .path_ours = path_ours
43
+ self .path_instructpix2pix = path_instructpix2pix
44
+ self .path_hive_0 = path_hive_0
45
+ self .path_hive_1 = path_hive_1
46
+ self .path_hive_2 = path_hive_2
42
47
self .min_resize_res = min_resize_res
43
48
self .max_resize_res = max_resize_res
44
49
self .crop_res = crop_res
45
50
self .flip_prob = flip_prob
46
- # load official dataset
47
- with open (Path (self .path_official , "seeds.json" )) as f :
48
- self .seeds = json .load (f )
51
+ self .seeds = []
52
+ self .instructions = []
53
+ self .source_imgs = []
54
+ self .edited_imgs = []
55
+ # load instructpix2pix dataset
56
+ with open (Path (self .path_instructpix2pix , "seeds.json" )) as f :
57
+ seeds = json .load (f )
49
58
split_0 , split_1 = {
50
59
"train" : (0.0 , splits [0 ]),
51
60
"val" : (splits [0 ], splits [0 ] + splits [1 ]),
52
61
"test" : (splits [0 ] + splits [1 ], 1.0 ),
53
62
}[split ]
54
63
55
- idx_0 = math .floor (split_0 * len (self .seeds ))
56
- idx_1 = math .floor (split_1 * len (self .seeds ))
57
- self .seeds = self .seeds [idx_0 :idx_1 ]
64
+ idx_0 = math .floor (split_0 * len (seeds ))
65
+ idx_1 = math .floor (split_1 * len (seeds ))
66
+ seeds = seeds [idx_0 :idx_1 ]
67
+
68
+ for seed in seeds :
69
+ seed = deque (seed )
70
+ seed .appendleft ('' )
71
+ seed .appendleft ('instructpix2pix' )
72
+ self .seeds .append (list (seed ))
73
+
74
+
75
+ # load HIVE dataset first part
58
76
59
- # load in-house dataset
60
- self .instructions = []
61
- self .source_imgs = []
62
- self .edited_imgs = []
63
77
cnt = 0
64
- with jsonlines .open (Path (self .path_ours , "training_1M .jsonl" )) as reader :
78
+ with jsonlines .open (Path (self .path_hive_0 , "training_cycle .jsonl" )) as reader :
65
79
for ll in reader :
66
80
self .instructions .append (ll ['instruction' ])
67
81
self .source_imgs .append (ll ['source_img' ])
68
82
self .edited_imgs .append (ll ['edited_img' ])
69
- self .seeds .append (['in_house ' , [cnt ]])
83
+ self .seeds .append (['hive_0' , '' , ' ' , [cnt ]])
70
84
cnt += 1
71
85
86
+ # load HIVE dataset second part
87
+ with open (Path (self .path_hive_1 , "seeds.json" )) as f :
88
+ seeds = json .load (f )
89
+ for seed in seeds :
90
+ seed = deque (seed )
91
+ seed .appendleft ('hive_1' )
92
+ self .seeds .append (list (seed ))
93
+ # load HIVE dataset third part
94
+ with open (Path (self .path_hive_2 , "seeds.json" )) as f :
95
+ seeds = json .load (f )
96
+ for seed in seeds :
97
+ seed = deque (seed )
98
+ seed .appendleft ('hive_2' )
99
+ self .seeds .append (list (seed ))
100
+
72
101
def __len__ (self ) -> int :
73
102
return len (self .seeds )
74
103
75
104
def __getitem__ (self , i : int ) -> dict [str , Any ]:
76
105
77
- name , seeds = self .seeds [i ]
78
- if name != 'in_house ' :
79
- propt_dir = Path (self .path_official , name )
106
+ name_0 , name_1 , name_2 , seeds = self .seeds [i ]
107
+ if name_0 == 'instructpix2pix ' :
108
+ propt_dir = Path (self .path_instructpix2pix , name_2 )
80
109
seed = seeds [torch .randint (0 , len (seeds ), ()).item ()]
81
110
with open (propt_dir .joinpath ("prompt.json" )) as fp :
82
111
prompt = json .load (fp )["edit" ]
83
112
image_0 = Image .open (propt_dir .joinpath (f"{ seed } _0.jpg" ))
84
113
image_1 = Image .open (propt_dir .joinpath (f"{ seed } _1.jpg" ))
114
+ elif name_0 == 'hive_1' :
115
+ propt_dir = Path (self .path_hive_1 , name_1 , name_2 )
116
+ seed = seeds [torch .randint (0 , len (seeds ), ()).item ()]
117
+ with open (propt_dir .joinpath ("prompt.json" )) as fp :
118
+ prompt = json .load (fp )["instruction" ]
119
+ image_0 = Image .open (propt_dir .joinpath (f"{ seed } _0.jpg" ))
120
+ image_1 = Image .open (propt_dir .joinpath (f"{ seed } _1.jpg" ))
121
+ elif name_0 == 'hive_2' :
122
+ propt_dir = Path (self .path_hive_2 , name_1 , name_2 )
123
+ seed = seeds [torch .randint (0 , len (seeds ), ()).item ()]
124
+ with open (propt_dir .joinpath ("prompt.json" )) as fp :
125
+ prompt = json .load (fp )["instruction" ]
126
+ image_0 = Image .open (propt_dir .joinpath (f"{ seed } _0.jpg" ))
127
+ image_1 = Image .open (propt_dir .joinpath (f"{ seed } _1.jpg" ))
85
128
else :
86
129
j = seeds [0 ]
87
130
image_0 = Image .open (self .source_imgs [j ])
@@ -101,51 +144,3 @@ def __getitem__(self, i: int) -> dict[str, Any]:
101
144
102
145
return dict (edited = image_1 , edit = dict (c_concat = image_0 , c_crossattn = prompt ))
103
146
104
-
105
- class EditDatasetEval (Dataset ):
106
- def __init__ (
107
- self ,
108
- path : str ,
109
- split : str = "train" ,
110
- splits : tuple [float , float , float ] = (0.9 , 0.05 , 0.05 ),
111
- res : int = 256 ,
112
- ):
113
- assert split in ("train" , "val" , "test" )
114
- assert sum (splits ) == 1
115
- self .path = path
116
- self .res = res
117
-
118
- with open (Path (self .path , "seeds.json" )) as f :
119
- self .seeds = json .load (f )
120
-
121
- split_0 , split_1 = {
122
- "train" : (0.0 , splits [0 ]),
123
- "val" : (splits [0 ], splits [0 ] + splits [1 ]),
124
- "test" : (splits [0 ] + splits [1 ], 1.0 ),
125
- }[split ]
126
-
127
- idx_0 = math .floor (split_0 * len (self .seeds ))
128
- idx_1 = math .floor (split_1 * len (self .seeds ))
129
- self .seeds = self .seeds [idx_0 :idx_1 ]
130
-
131
- def __len__ (self ) -> int :
132
- return len (self .seeds )
133
-
134
- def __getitem__ (self , i : int ) -> dict [str , Any ]:
135
- name , seeds = self .seeds [i ]
136
- propt_dir = Path (self .path , name )
137
- seed = seeds [torch .randint (0 , len (seeds ), ()).item ()]
138
- with open (propt_dir .joinpath ("prompt.json" )) as fp :
139
- prompt = json .load (fp )
140
- edit = prompt ["edit" ]
141
- input_prompt = prompt ["input" ]
142
- output_prompt = prompt ["output" ]
143
-
144
- image_0 = Image .open (propt_dir .joinpath (f"{ seed } _0.jpg" ))
145
-
146
- reize_res = torch .randint (self .res , self .res + 1 , ()).item ()
147
- image_0 = image_0 .resize ((reize_res , reize_res ), Image .Resampling .LANCZOS )
148
-
149
- image_0 = rearrange (2 * torch .tensor (np .array (image_0 )).float () / 255 - 1 , "h w c -> c h w" )
150
-
151
- return dict (image_0 = image_0 , input_prompt = input_prompt , edit = edit , output_prompt = output_prompt )
0 commit comments