@@ -782,32 +782,46 @@ def inject_fake_data(self, tmpdir, config):
782
782
783
783
annotation_folder = tmpdir / self ._ANNOTATIONS_FOLDER
784
784
os .makedirs (annotation_folder )
785
+
786
+ segmentation_kind = config .pop ("segmentation_kind" , "list" )
785
787
info = self ._create_annotation_file (
786
- annotation_folder , self ._ANNOTATIONS_FILE , file_names , num_annotations_per_image
788
+ annotation_folder ,
789
+ self ._ANNOTATIONS_FILE ,
790
+ file_names ,
791
+ num_annotations_per_image ,
792
+ segmentation_kind = segmentation_kind ,
787
793
)
788
794
789
795
info ["num_examples" ] = num_images
790
796
return info
791
797
792
- def _create_annotation_file (self , root , name , file_names , num_annotations_per_image ):
798
+ def _create_annotation_file (self , root , name , file_names , num_annotations_per_image , segmentation_kind = "list" ):
793
799
image_ids = [int (file_name .stem ) for file_name in file_names ]
794
800
images = [dict (file_name = str (file_name ), id = id ) for file_name , id in zip (file_names , image_ids )]
795
801
796
- annotations , info = self ._create_annotations (image_ids , num_annotations_per_image )
802
+ annotations , info = self ._create_annotations (image_ids , num_annotations_per_image , segmentation_kind )
797
803
self ._create_json (root , name , dict (images = images , annotations = annotations ))
798
804
799
805
return info
800
806
801
- def _create_annotations (self , image_ids , num_annotations_per_image ):
807
+ def _create_annotations (self , image_ids , num_annotations_per_image , segmentation_kind = "list" ):
802
808
annotations = []
803
809
annotion_id = 0
810
+
804
811
for image_id in itertools .islice (itertools .cycle (image_ids ), len (image_ids ) * num_annotations_per_image ):
812
+ segmentation = {
813
+ "list" : [torch .rand (8 ).tolist ()],
814
+ "rle" : {"size" : [10 , 10 ], "counts" : [1 ]},
815
+ "rle_encoded" : {"size" : [2400 , 2400 ], "counts" : "PQRQ2[1\\ Y2f0gNVNRhMg2" },
816
+ "bad" : 123 ,
817
+ }[segmentation_kind ]
818
+
805
819
annotations .append (
806
820
dict (
807
821
image_id = image_id ,
808
822
id = annotion_id ,
809
823
bbox = torch .rand (4 ).tolist (),
810
- segmentation = [ torch . rand ( 8 ). tolist ()] ,
824
+ segmentation = segmentation ,
811
825
category_id = int (torch .randint (91 , ())),
812
826
area = float (torch .rand (1 )),
813
827
iscrowd = int (torch .randint (2 , size = (1 ,))),
@@ -832,11 +846,27 @@ def test_slice_error(self):
832
846
with pytest .raises (ValueError , match = "Index must be of type integer" ):
833
847
dataset [:2 ]
834
848
849
    def test_segmentation_kind(self):
        """Smoke-test the v2 wrapper against every fake segmentation encoding."""
        # CocoCaptions carries no segmentation annotations, so this check only
        # applies to the detection variant of the test case.
        if isinstance(self, CocoCaptionsTestCase):
            return

        # Every supported encoding (polygon list, uncompressed RLE, compressed
        # RLE) must survive iteration through the transforms-v2 wrapper.
        for segmentation_kind in ("list", "rle", "rle_encoded"):
            config = {"segmentation_kind": segmentation_kind}
            with self.create_dataset(config) as (dataset, _):
                dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
                list(dataset)  # force the wrapper to process every sample

        # An unrecognized segmentation payload must be rejected with a clear error.
        config = {"segmentation_kind": "bad"}
        with self.create_dataset(config) as (dataset, _):
            dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
            with pytest.raises(ValueError, match="COCO segmentation expected to be a dict or a list"):
                list(dataset)
835
865
836
866
class CocoCaptionsTestCase (CocoDetectionTestCase ):
837
867
DATASET_CLASS = datasets .CocoCaptions
838
868
839
- def _create_annotations (self , image_ids , num_annotations_per_image ):
869
+ def _create_annotations (self , image_ids , num_annotations_per_image , segmentation_kind = "list" ):
840
870
captions = [str (idx ) for idx in range (num_annotations_per_image )]
841
871
annotations = combinations_grid (image_id = image_ids , caption = captions )
842
872
for id , annotation in enumerate (annotations ):
    def inject_fake_data(self, tmpdir, config):
        """Create a fake FER2013 directory tree and return the sample count.

        Three on-disk layouts are faked, selected via ``config``:
        ``use_icml`` writes ``icml_face_data.csv``, ``use_fer`` writes
        ``fer2013.csv``, otherwise the legacy per-split ``{split}.csv`` file
        is written.
        """
        base_folder = os.path.join(tmpdir, "fer2013")
        os.makedirs(base_folder)

        # Popped (not read) so the dataset constructor under test never sees
        # these test-only keys in its config.
        use_icml = config.pop("use_icml", False)
        use_fer = config.pop("use_fer", False)

        num_samples = 5

        if use_icml or use_fer:
            # The icml_face_data.csv headers carry a leading space (" pixels",
            # " Usage"), and its column order differs from fer2013.csv —
            # presumably mirroring the real files; confirm against the dataset
            # implementation.
            pixels_key, usage_key = (" pixels", " Usage") if use_icml else ("pixels", "Usage")
            fieldnames = ("emotion", usage_key, pixels_key) if use_icml else ("emotion", pixels_key, usage_key)
            filename = "icml_face_data.csv" if use_icml else "fer2013.csv"
            with open(os.path.join(base_folder, filename), "w", newline="") as file:
                writer = csv.DictWriter(
                    file,
                    fieldnames=fieldnames,
                    quoting=csv.QUOTE_NONNUMERIC,
                    quotechar='"',
                )
                writer.writeheader()
                for i in range(num_samples):
                    row = {
                        "emotion": str(int(torch.randint(0, 7, ()))),
                        # Alternate usage values so both splits get samples.
                        usage_key: "Training" if i % 2 else "PublicTest",
                        pixels_key: " ".join(
                            str(pixel)
                            for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
                        ),
                    }

                    writer.writerow(row)
        else:
            # Legacy layout: one CSV per split; only the train file carries an
            # "emotion" label column.
            with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
                writer = csv.DictWriter(
                    file,
                    fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",),
                    quoting=csv.QUOTE_NONNUMERIC,
                    quotechar='"',
                )
                writer.writeheader()
                for _ in range(num_samples):
                    row = dict(
                        pixels=" ".join(
                            str(pixel)
                            for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
                        )
                    )
                    if config["split"] == "train":
                        row["emotion"] = str(int(torch.randint(0, 7, ())))

                    writer.writerow(row)

        return num_samples
    def test_icml_file(self):
        """Check label availability across the three fake CSV layouts."""
        # With the legacy per-split CSVs, the test split has no "emotion"
        # column, so every sample's target is None.
        config = {"split": "test"}
        with self.create_dataset(config=config) as (dataset, _):
            assert all(s[1] is None for s in dataset)

        # Both combined-file layouts (icml_face_data.csv / fer2013.csv) write
        # an emotion value for every row, so labels exist for both splits.
        for split in ("train", "test"):
            for d in ({"use_icml": True}, {"use_fer": True}):
                config = {"split": split, **d}
                with self.create_dataset(config=config) as (dataset, _):
                    assert all(s[1] is not None for s in dataset)
2467
2537
2468
2538
class GTSRBTestCase (datasets_utils .ImageDatasetTestCase ):
2469
2539
DATASET_CLASS = datasets .GTSRB
0 commit comments