diff --git a/library/train_util.py b/library/train_util.py index d076cf847..b69fb0950 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1554,7 +1554,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) if subset.is_reg: - reg_infos.append(info) + reg_infos.append((info, subset)) else: self.register_image(info, subset) @@ -1575,7 +1575,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): n = 0 first_loop = True while n < num_train_images: - for info in reg_infos: + for info, subset in reg_infos: if first_loop: self.register_image(info, subset) n += info.num_repeats