From f99dd348fd7a3b2a4314012d0b2d4de08d774408 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 4 Mar 2024 23:23:11 +0000 Subject: [PATCH] fix: crash in inference script, errors in documentation --- data/image_folder.py | 13 ++----------- docs/source/inference.rst | 2 +- options/base_options.py | 5 +++-- options/inference_diffusion_options.py | 2 +- scripts/gen_single_image.py | 2 +- scripts/gen_single_image_diffusion.py | 8 +++----- tests/test_run_mask_online_ref.py | 2 +- tests/test_run_mask_ref.py | 2 +- util/visualizer.py | 8 ++++---- 9 files changed, 17 insertions(+), 27 deletions(-) diff --git a/data/image_folder.py b/data/image_folder.py index 972afbbc9..85b3f6505 100644 --- a/data/image_folder.py +++ b/data/image_folder.py @@ -94,18 +94,9 @@ def make_labeled_path_dataset(dir, paths, max_dataset_size=float("inf")): ): # we allow B not having a label images.append(line_split[0]) - elif len(line_split) == 2: + elif len(line_split) >= 2: images.append(line_split[0]) - labels.append(line_split[1]) - - elif len(line_split) > 2: - images.append(line_split[0]) - - label_line = line_split[1] - for i in range(2, len(line_split)): - label_line += " " + line_split[i] - - labels.append(label_line) + labels.append(" ".join(line_split[1:])) return ( images[: min(max_dataset_size, len(images))], diff --git a/docs/source/inference.rst b/docs/source/inference.rst index 8de64308c..150eb1469 100644 --- a/docs/source/inference.rst +++ b/docs/source/inference.rst @@ -276,7 +276,7 @@ Download a pretrained model: Run the inference script ======================== -The ``--cond-in`` parameter specifies the conditioning image to use. +The ``--cond_in`` parameter specifies the conditioning image to use. .. code:: bash diff --git a/options/base_options.py b/options/base_options.py index 99d74840c..2eb371591 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -266,10 +266,11 @@ def _after_parse(self, opt, set_device=True): return self.opt - def parse(self): + def parse(self, save_config=True): """Parse our options, create checkpoints directory suffix, and set up gpu device.""" self.opt = self.gather_options() - self.save_options() + if save_config: + self.save_options() opt = self._after_parse(self.opt) return opt diff --git a/options/inference_diffusion_options.py b/options/inference_diffusion_options.py index 1e5637f30..afd997ef7 100644 --- a/options/inference_diffusion_options.py +++ b/options/inference_diffusion_options.py @@ -98,7 +98,7 @@ def initialize(self, parser): ) parser.add_argument( - "--cls_value", + "--cls", type=int, default=-1, help="override input bbox classe for generation", diff --git a/scripts/gen_single_image.py b/scripts/gen_single_image.py index 9b5832e2d..301f4ecc7 100644 --- a/scripts/gen_single_image.py +++ b/scripts/gen_single_image.py @@ -142,5 +142,5 @@ def inference(args): if __name__ == "__main__": - opt = InferenceGANOptions().parse() + opt = InferenceGANOptions().parse(save_config=False) inference(opt) diff --git a/scripts/gen_single_image_diffusion.py b/scripts/gen_single_image_diffusion.py index 3c171a5de..c50fa171c 100644 --- a/scripts/gen_single_image_diffusion.py +++ b/scripts/gen_single_image_diffusion.py @@ -186,7 +186,7 @@ def generate( cond_persp_vertical, alg_diffusion_cond_image_creation, alg_diffusion_sketch_canny_thresholds, - cls_value, + cls, alg_diffusion_super_resolution_downsample, alg_diffusion_guidance_scale, data_refined_mask, @@ -268,9 +268,7 @@ def generate( elts = line.rstrip().split() bboxes.append([int(elts[1]), int(elts[2]), int(elts[3]), int(elts[4])]) if conditioning: - if cls_value > 0: - cls = cls_value - else: + if cls <= 0: cls = int(elts[0]) else: cls = 1 @@ -760,5 +758,5 @@ def inference(args): if __name__ == "__main__": - args = InferenceDiffusionOptions().parse() + args = InferenceDiffusionOptions().parse(save_config=False) inference(args) diff --git a/tests/test_run_mask_online_ref.py b/tests/test_run_mask_online_ref.py index 4d9b916a1..099960197 100644 --- a/tests/test_run_mask_online_ref.py +++ b/tests/test_run_mask_online_ref.py @@ -36,7 +36,7 @@ ["cut", "unaligned_labeled_mask_online_ref"], ] conditionings = [ - "alg_palette_conditioning", + "alg_diffusion_cond_embed", "alg_palette_cond_image_creation", ] diff --git a/tests/test_run_mask_ref.py b/tests/test_run_mask_ref.py index ebb0568f9..5f2860d06 100644 --- a/tests/test_run_mask_ref.py +++ b/tests/test_run_mask_ref.py @@ -36,7 +36,7 @@ ["cut", "unaligned_labeled_mask_ref"], ] conditionings = [ - "alg_palette_conditioning", + "alg_diffusion_cond_embed", "alg_palette_cond_image_creation", ] diff --git a/util/visualizer.py b/util/visualizer.py index 7a75a2053..119081dd1 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -458,12 +458,12 @@ def plot_metrics_dict( json.dump(self.metrics_dict, fp) def plot_current_metrics(self, epoch, counter_ratio, metrics): - """display the current fid values on visdom display: dictionary of fid labels and values + """display the current metrics values on visdom display Parameters: epoch (int) -- current epoch counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 - fids (OrderedDict) -- training fid values stored in the format of (name, float) pairs + metrics (OrderedDict) -- training metrics values stored in the format of (name, float) pairs """ self.plot_metrics_dict( "metric", @@ -476,7 +476,7 @@ def plot_current_metrics(self, epoch, counter_ratio, metrics): ) def plot_current_D_accuracies(self, epoch, counter_ratio, accuracies): - """display the current fid values on visdom display: dictionary of fid labels and values + """display the current accuracies values on visdom display Parameters: epoch (int) -- current epoch @@ -505,7 +505,7 @@ def plot_current_APA_prob(self, epoch, counter_ratio, p): ) def plot_current_miou(self, epoch, counter_ratio, miou): - """display the current fid values on visdom display: dictionary of fid labels and values + """display the current miou values on visdom display Parameters: epoch (int) -- current epoch