
Commit f428cbb

bhavya-work authored and Juneja Sarjil committed
mindee#1830 fixed: able to use the builtin dataset
1 parent 22b2260 commit f428cbb

2 files changed: +105 -4

doctr/__init__.py (+1 -2)

@@ -1,3 +1,2 @@
 from . import io, models, datasets, contrib, transforms, utils
-from .file_utils import is_tf_available, is_torch_available
-print('i am trying something with git')
+from .file_utils import is_tf_available, is_torch_available

references/recognition/train_pytorch.py (+104 -2)
@@ -31,12 +31,19 @@
 else:
     from tqdm.auto import tqdm

+from doctr import datasets
 from doctr import transforms as T
 from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
 from doctr.models import login_to_hub, push_to_hf_hub, recognition
 from doctr.utils.metrics import TextMatch
 from utils import EarlyStopper, plot_recorder, plot_samples

+# dataset_map = {
+# "SVHN":SVHN,
+# "CORD":CORD,
+# "FUNSD": FUNSD
+# }
+

 def record_lr(
     model: torch.nn.Module,
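
The commented-out dataset_map above is superseded by a dynamic lookup on the doctr.datasets module namespace, which is how the rest of the diff resolves dataset names. A minimal sketch of that lookup (the "FUNSD" name is only an illustrative choice):

    from doctr import datasets

    # Look up a builtin dataset class by its string name, as the training script does below.
    DatasetClass = datasets.__dict__["FUNSD"]
    assert DatasetClass is datasets.FUNSD  # same object as the direct attribute access
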
@@ -220,8 +227,32 @@ def main(args):
             labels_path=os.path.join(args.val_path, "labels.json"),
             img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
         )
+    elif args.val_datasets:
+        val_datasets = args.val_datasets
+
+        val_set = datasets.__dict__[val_datasets[0]](
+            train=False,
+            download=True,
+            recognition_task=True,
+            use_polygons=True,
+            img_transforms=Compose([
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                # Augmentations
+                T.RandomApply(T.ColorInversion(), 0.1),
+            ]),
+        )
+        if len(val_datasets) > 1:
+            for dataset_name in val_datasets[1:]:
+                _ds = datasets.__dict__[dataset_name](
+                    train=False,
+                    download=True,
+                    recognition_task=True,
+                    use_polygons=True,
+                )
+
+                val_set.data.extend((np_img, target) for np_img, target in _ds.data)
+
     else:
-        val_hash = None
         # Load synthetic data generator
         val_set = WordGenerator(
             vocab=vocab,
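
The new elif branch instantiates the first requested builtin dataset and then folds any further ones into it through the dataset's data list. A standalone sketch of that pattern under the same assumptions as the diff (doctr installed, dataset downloads allowed; the names are illustrative picks from the supported choices):

    from doctr import datasets

    names = ["FUNSD", "CORD"]  # e.g. the values passed via --val_datasets

    # The first dataset becomes the base validation set
    val_set = datasets.__dict__[names[0]](
        train=False,
        download=True,
        recognition_task=True,
        use_polygons=True,
    )
    # Any remaining datasets are appended sample by sample
    for name in names[1:]:
        extra = datasets.__dict__[name](
            train=False,
            download=True,
            recognition_task=True,
            use_polygons=True,
        )
        val_set.data.extend((img, target) for img, target in extra.data)

    print(f"{len(val_set)} validation samples")
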
@@ -235,6 +266,7 @@ def main(args):
                 T.RandomApply(T.ColorInversion(), 0.9),
             ]),
         )
+    val_hash = None

     val_loader = DataLoader(
         val_set,
@@ -315,8 +347,32 @@ def main(args):
                 train_set.merge_dataset(
                     RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
                 )
+
+    elif args.train_datasets:
+        train_datasets = args.train_datasets
+
+        train_set = datasets.__dict__[train_datasets[0]](
+            train=True,
+            download=True,
+            recognition_task=True,
+            use_polygons=True,
+            img_transforms=Compose([
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                # Augmentations
+                T.RandomApply(T.ColorInversion(), 0.1),
+            ]),
+        )
+        if len(train_datasets) > 1:
+            for dataset_name in train_datasets[1:]:
+                _ds = datasets.__dict__[dataset_name](
+                    train=True,
+                    download=True,
+                    recognition_task=True,
+                    use_polygons=True,
+                )
+                train_set.data.extend((np_img, target) for np_img, target in _ds.data)
+
     else:
-        train_hash = None
         # Load synthetic data generator
         train_set = WordGenerator(
             vocab=vocab,
@@ -348,6 +404,8 @@ def main(args):
     )
     pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)")

+    train_hash = None
+
     if args.show_samples:
         x, target = next(iter(train_loader))
         plot_samples(x, target)
@@ -545,6 +603,50 @@ def parse_args():
         default=20,
         help="Multiplied by the vocab length gets you the number of synthetic validation samples that will be used.",
     )
+    (
+        parser.add_argument(
+            "--train_datasets",
+            type=str,
+            nargs="+",
+            choices=[
+                "COCOTEXT",
+                "CORD",
+                "FUNSD",
+                "IC03",
+                "IC13",
+                "IIIT5K",
+                "IMGUR5K",
+                "SROIE",
+                "SVHN",
+                "SVT",
+                "WILDRECEIPT",
+            ],
+            default=None,
+            help="Builtin dataset names (choose from: COCOTEXT, CORD, FUNSD, IC03, IC13, IIIT5K, IMGUR5K, SROIE, SVHN, SVT, WILDRECEIPT)",
+        ),
+    )
+    (
+        parser.add_argument(
+            "--val_datasets",
+            type=str,
+            nargs="+",
+            choices=[
+                "COCOTEXT",
+                "CORD",
+                "FUNSD",
+                "IC03",
+                "IC13",
+                "IIIT5K",
+                "IMGUR5K",
+                "SROIE",
+                "SVHN",
+                "SVT",
+                "WILDRECEIPT",
+            ],
+            default=None,
+            help="Builtin dataset names (choose from: COCOTEXT, CORD, FUNSD, IC03, IC13, IIIT5K, IMGUR5K, SROIE, SVHN, SVT, WILDRECEIPT)",
+        ),
+    )
     parser.add_argument(
         "--font", type=str, default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", help="Font family to be used"
     )
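
Two details about the new arguments are worth noting: nargs="+" combined with choices means each flag accepts one or more dataset names and validates each of them, which is why the code above indexes train_datasets[0] and iterates over the rest; and the extra parentheses wrapped around each parser.add_argument(...) call only build a throwaway one-element tuple, so they have no functional effect. A small sketch of the parsing behaviour (plain argparse; the command-line values are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_datasets",
        type=str,
        nargs="+",  # one or more names, collected into a list
        choices=["COCOTEXT", "CORD", "FUNSD", "SVHN"],  # subset of the real choices
        default=None,
    )

    args = parser.parse_args(["--train_datasets", "FUNSD", "CORD"])
    print(args.train_datasets)  # ['FUNSD', 'CORD'] -> train_datasets[0] is the base set
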
