From 40c1fe0c28364b243b5944b3569000611ddf2b7d Mon Sep 17 00:00:00 2001 From: Arjun Suresh Date: Thu, 7 Nov 2024 21:20:52 +0530 Subject: [PATCH] Added an option to pass in sample_ids.txt for SDXL accuracy check --- text_to_image/tools/accuracy_coco.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/text_to_image/tools/accuracy_coco.py b/text_to_image/tools/accuracy_coco.py index 2d7c36506..8740ee172 100644 --- a/text_to_image/tools/accuracy_coco.py +++ b/text_to_image/tools/accuracy_coco.py @@ -51,6 +51,10 @@ def get_args(): required=False, help="path to dump 10 stable diffusion xl compliance images", ) + #Do not use for official MLPerf inference submissions as only the default one is valid + parser.add_argument( + "--ids-path", help="Path to 10 caption ids to dump as compliance images", default="os.path.join(os.path.dirname(__file__), 'sample_ids.txt')" + ) parser.add_argument("--device", default="cpu", choices=["gpu", "cpu"]) parser.add_argument( "--low_memory", @@ -97,8 +101,9 @@ def main(): os.makedirs(args.compliance_images_path) dump_compliance_images = True compliance_images_idx_list = [] + sample_ids_file_path = args.ids_path if args.ids_path else os.path.join(os.path.dirname(__file__), "sample_ids.txt") with open( - os.path.join(os.path.dirname(__file__), "sample_ids.txt"), "r" + os.path.join(sample_ids_file_path, "r" ) as compliance_id_file: for line in compliance_id_file: idx = int(line.strip())