Skip to content

Commit

Permalink
fix: crash in inference script, errors in documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and beniz committed Mar 7, 2024
1 parent c89ef9b commit f99dd34
Show file tree
Hide file tree
Showing 9 changed files with 17 additions and 27 deletions.
13 changes: 2 additions & 11 deletions data/image_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,9 @@ def make_labeled_path_dataset(dir, paths, max_dataset_size=float("inf")):
): # we allow B not having a label
images.append(line_split[0])

elif len(line_split) == 2:
elif len(line_split) >= 2:
images.append(line_split[0])
labels.append(line_split[1])

elif len(line_split) > 2:
images.append(line_split[0])

label_line = line_split[1]
for i in range(2, len(line_split)):
label_line += " " + line_split[i]

labels.append(label_line)
labels.append(" ".join(line_split[1:]))

return (
images[: min(max_dataset_size, len(images))],
Expand Down
2 changes: 1 addition & 1 deletion docs/source/inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ Download a pretrained model:
Run the inference script
========================

The ``--cond-in`` parameter specifies the conditioning image to use.
The ``--cond_in`` parameter specifies the conditioning image to use.

.. code:: bash
Expand Down
5 changes: 3 additions & 2 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,11 @@ def _after_parse(self, opt, set_device=True):

return self.opt

def parse(self):
def parse(self, save_config=True):
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
self.opt = self.gather_options()
self.save_options()
if save_config:
self.save_options()
opt = self._after_parse(self.opt)
return opt

Expand Down
2 changes: 1 addition & 1 deletion options/inference_diffusion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def initialize(self, parser):
)

parser.add_argument(
"--cls_value",
"--cls",
type=int,
default=-1,
help="override input bbox classe for generation",
Expand Down
2 changes: 1 addition & 1 deletion scripts/gen_single_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,5 @@ def inference(args):


if __name__ == "__main__":
opt = InferenceGANOptions().parse()
opt = InferenceGANOptions().parse(save_config=False)
inference(opt)
8 changes: 3 additions & 5 deletions scripts/gen_single_image_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def generate(
cond_persp_vertical,
alg_diffusion_cond_image_creation,
alg_diffusion_sketch_canny_thresholds,
cls_value,
cls,
alg_diffusion_super_resolution_downsample,
alg_diffusion_guidance_scale,
data_refined_mask,
Expand Down Expand Up @@ -268,9 +268,7 @@ def generate(
elts = line.rstrip().split()
bboxes.append([int(elts[1]), int(elts[2]), int(elts[3]), int(elts[4])])
if conditioning:
if cls_value > 0:
cls = cls_value
else:
if cls <= 0:
cls = int(elts[0])
else:
cls = 1
Expand Down Expand Up @@ -760,5 +758,5 @@ def inference(args):


if __name__ == "__main__":
args = InferenceDiffusionOptions().parse()
args = InferenceDiffusionOptions().parse(save_config=False)
inference(args)
2 changes: 1 addition & 1 deletion tests/test_run_mask_online_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
["cut", "unaligned_labeled_mask_online_ref"],
]
conditionings = [
"alg_palette_conditioning",
"alg_diffusion_cond_embed",
"alg_palette_cond_image_creation",
]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_run_mask_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
["cut", "unaligned_labeled_mask_ref"],
]
conditionings = [
"alg_palette_conditioning",
"alg_diffusion_cond_embed",
"alg_palette_cond_image_creation",
]

Expand Down
8 changes: 4 additions & 4 deletions util/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,12 @@ def plot_metrics_dict(
json.dump(self.metrics_dict, fp)

def plot_current_metrics(self, epoch, counter_ratio, metrics):
"""display the current fid values on visdom display: dictionary of fid labels and values
"""display the current metrics values on visdom display
Parameters:
epoch (int) -- current epoch
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
fids (OrderedDict) -- training fid values stored in the format of (name, float) pairs
metrics (OrderedDict) -- training metrics values stored in the format of (name, float) pairs
"""
self.plot_metrics_dict(
"metric",
Expand All @@ -476,7 +476,7 @@ def plot_current_metrics(self, epoch, counter_ratio, metrics):
)

def plot_current_D_accuracies(self, epoch, counter_ratio, accuracies):
"""display the current fid values on visdom display: dictionary of fid labels and values
"""display the current accuracies values on visdom display
Parameters:
epoch (int) -- current epoch
Expand Down Expand Up @@ -505,7 +505,7 @@ def plot_current_APA_prob(self, epoch, counter_ratio, p):
)

def plot_current_miou(self, epoch, counter_ratio, miou):
"""display the current fid values on visdom display: dictionary of fid labels and values
"""display the current miou values on visdom display
Parameters:
epoch (int) -- current epoch
Expand Down

0 comments on commit f99dd34

Please sign in to comment.