 else:
     from tqdm.auto import tqdm

+from doctr import datasets
 from doctr import transforms as T
 from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
 from doctr.models import login_to_hub, push_to_hf_hub, recognition
 from doctr.utils.metrics import TextMatch
 from utils import EarlyStopper, plot_recorder, plot_samples


 def record_lr(
     model: torch.nn.Module,
@@ -220,8 +227,32 @@ def main(args):
             labels_path=os.path.join(args.val_path, "labels.json"),
             img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
         )
+    elif args.val_datasets:
+        # Built-in datasets ship no local labels file, so there is nothing to hash
+        val_hash = None
+        val_datasets = args.val_datasets
+
+        val_set = datasets.__dict__[val_datasets[0]](
+            train=False,
+            download=True,
+            recognition_task=True,
+            use_polygons=True,
+            img_transforms=Compose([
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                # Augmentations
+                T.RandomApply(T.ColorInversion(), 0.1),
+            ]),
+        )
+        # Fold any further datasets into the first one
+        if len(val_datasets) > 1:
+            for dataset_name in val_datasets[1:]:
+                _ds = datasets.__dict__[dataset_name](
+                    train=False,
+                    download=True,
+                    recognition_task=True,
+                    use_polygons=True,
+                )
+                val_set.data.extend((np_img, target) for np_img, target in _ds.data)
+
     else:
         val_hash = None
         # Load synthetic data generator
         val_set = WordGenerator(
             vocab=vocab,
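The merging pattern above (mirrored for training below) resolves each name to a dataset class via datasets.__dict__ and folds additional sets in by extending the underlying data list. A minimal standalone sketch of the same idea (the dataset names are illustrative; it assumes doctr is installed and the archives can be downloaded):

from doctr import datasets

# Illustrative names; any of the --train_datasets/--val_datasets choices would do
names = ["CORD", "FUNSD"]

# The first name builds the base dataset; recognition_task=True yields word crops
merged = datasets.__dict__[names[0]](
    train=False, download=True, recognition_task=True, use_polygons=True
)
# Remaining names are folded in by extending the (image, target) sample list
for name in names[1:]:
    _ds = datasets.__dict__[name](
        train=False, download=True, recognition_task=True, use_polygons=True
    )
    merged.data.extend(_ds.data)

print(f"{len(merged)} samples")  # len() reflects the concatenated data list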
@@ -315,8 +347,32 @@ def main(args):
                 train_set.merge_dataset(
                     RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
                 )
+
+    elif args.train_datasets:
+        # Built-in datasets ship no local labels file, so there is nothing to hash
+        train_hash = None
+        train_datasets = args.train_datasets
+
+        train_set = datasets.__dict__[train_datasets[0]](
+            train=True,
+            download=True,
+            recognition_task=True,
+            use_polygons=True,
+            img_transforms=Compose([
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                # Augmentations
+                T.RandomApply(T.ColorInversion(), 0.1),
+            ]),
+        )
+        # Fold any further datasets into the first one
+        if len(train_datasets) > 1:
+            for dataset_name in train_datasets[1:]:
+                _ds = datasets.__dict__[dataset_name](
+                    train=True,
+                    download=True,
+                    recognition_task=True,
+                    use_polygons=True,
+                )
+                train_set.data.extend((np_img, target) for np_img, target in _ds.data)
+
     else:
         train_hash = None
         # Load synthetic data generator
         train_set = WordGenerator(
             vocab=vocab,
@@ -545,6 +603,50 @@ def parse_args():
         default=20,
         help="Multiplied by the vocab length gets you the number of synthetic validation samples that will be used.",
     )
+    parser.add_argument(
+        "--train_datasets",
+        type=str,
+        nargs="+",
+        choices=[
+            "COCOTEXT",
+            "CORD",
+            "FUNSD",
+            "IC03",
+            "IC13",
+            "IIIT5K",
+            "IMGUR5K",
+            "SROIE",
+            "SVHN",
+            "SVT",
+            "WILDRECEIPT",
+        ],
+        default=None,
+        help="Built-in dataset(s) to train on, merged in the given order",
+    )
+    parser.add_argument(
+        "--val_datasets",
+        type=str,
+        nargs="+",
+        choices=[
+            "COCOTEXT",
+            "CORD",
+            "FUNSD",
+            "IC03",
+            "IC13",
+            "IIIT5K",
+            "IMGUR5K",
+            "SROIE",
+            "SVHN",
+            "SVT",
+            "WILDRECEIPT",
+        ],
+        default=None,
+        help="Built-in dataset(s) to validate on, merged in the given order",
+    )
     parser.add_argument(
         "--font", type=str, default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", help="Font family to be used"
     )
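With these flags, built-in sets can stand in for local data folders. A hypothetical invocation, assuming the usual doctr recognition training entry point and its positional architecture argument:

python references/recognition/train_pytorch.py crnn_vgg16_bn --train_datasets CORD FUNSD --val_datasets FUNSD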