This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Add modifications based on a customized dataset #397

Open. Wants to merge 2 commits into base: main.
27 changes: 27 additions & 0 deletions datasets/ucsd.py
@@ -0,0 +1,27 @@
"""
For UCSD dataset.

"""
from pathlib import Path

from .coco import CocoDetection, make_coco_transforms

def build(image_set, args):
"""
image_set = 'train' / 'val'
category = "all/tflt/"
"""
root = Path(args.coco_path)
assert root.exists(), f'provided path {root} to custom dataset does not exist'

if args.annotation_name is None:
raise ValueError("args doesn't have annotation_name")
else:
PATHS = {
"train": (root / "train2017", root / "annotations" / f'annotations_train_{args.annotation_name}.json'),
"val": (root / "val2017", root / "annotations" / f'annotations_val_{args.annotation_name}.json'),
}

img_folder, ann_file = PATHS[image_set]
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
return dataset
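
Note: the builder above expects a COCO-style layout under `--coco_path` (`train2017/`, `val2017/`, `annotations/`). Below is a minimal sketch of invoking it directly, assuming the repo root is on `PYTHONPATH`; the `Namespace` fields mirror the flags added in `main.py` further down, and the paths are placeholders, not values from this PR:

```python
from argparse import Namespace

from datasets.ucsd import build

# Placeholder arguments mirroring the new CLI flags.
args = Namespace(
    coco_path='/data/ucsd',    # must contain train2017/, val2017/, annotations/
    annotation_name='tflt',    # selects annotations_train_tflt.json / annotations_val_tflt.json
    masks=False,               # no segmentation masks
)

train_set = build('train', args)
val_set = build('val', args)
print(f'{len(train_set)} train images, {len(val_set)} val images')
```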
29 changes: 23 additions & 6 deletions main.py
@@ -12,19 +12,29 @@

 import datasets
 import util.misc as utils
-from datasets import build_dataset, get_coco_api_from_dataset
+from datasets import get_coco_api_from_dataset
+from datasets.ucsd import build as build_dataset
 from engine import evaluate, train_one_epoch
 from models import build_model


 def get_args_parser():
     parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
+    # added for UCSD
+    parser.add_argument('--num_classes', default=None, type=int,
+                        help='number of classes in your dataset; overrides the value hard-coded in models/detr.py')
+    parser.add_argument('--annotation_name', default=None, type=str,
+                        help='annotation-set suffix: all / tfsg / veh / tflt')
+    parser.add_argument('--gpu_id', default=None, type=int,
+                        help='specify GPU ID')
+
+    # original
     parser.add_argument('--lr', default=1e-4, type=float)
     parser.add_argument('--lr_backbone', default=1e-5, type=float)
     parser.add_argument('--batch_size', default=2, type=int)
     parser.add_argument('--weight_decay', default=1e-4, type=float)
     parser.add_argument('--epochs', default=300, type=int)
-    parser.add_argument('--lr_drop', default=200, type=int)
+    parser.add_argument('--lr_drop', default=100, type=int)
     parser.add_argument('--clip_max_norm', default=0.1, type=float,
                         help='gradient clipping max norm')

@@ -111,6 +121,7 @@ def main(args):
     print(args)

     device = torch.device(args.device)
+    torch.cuda.set_device(args.gpu_id)

     # fix the seed for reproducibility
     seed = args.seed + utils.get_rank()
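
Note: with this change the script effectively requires `--gpu_id`, since `torch.cuda.set_device` will fail when the flag is left at its default of `None`. A defensive variant, sketched here as a suggestion rather than as part of the PR:

```python
import torch

def maybe_pin_gpu(gpu_id):
    """Pin the CUDA device only when an id was actually supplied."""
    if gpu_id is not None and torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)
```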
@@ -121,10 +132,14 @@
     model, criterion, postprocessors = build_model(args)
     model.to(device)

     model_without_ddp = model
     if args.distributed:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
+    else:
+        print("Data parallel mode activated!")
+        model = torch.nn.DataParallel(model, device_ids=[args.gpu_id])
+        model = model.to(device)
+        model_without_ddp = model
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
     print('number of params:', n_parameters)
@@ -200,11 +215,11 @@ def main(args):
         if args.output_dir:
             checkpoint_paths = [output_dir / 'checkpoint.pth']
-            # extra checkpoint before LR drop and every 100 epochs
-            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
+            # extra checkpoint before LR drop and every 10 epochs
+            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0:
                 checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
             for checkpoint_path in checkpoint_paths:
                 utils.save_on_master({
-                    'model': model_without_ddp.state_dict(),
+                    'model': model.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     'lr_scheduler': lr_scheduler.state_dict(),
                     'epoch': epoch,
@@ -229,7 +244,7 @@ def main(args):
                     (output_dir / 'eval').mkdir(exist_ok=True)
                     if "bbox" in coco_evaluator.coco_eval:
                         filenames = ['latest.pth']
-                        if epoch % 50 == 0:
+                        if epoch % 10 == 0:
                             filenames.append(f'{epoch:03}.pth')
                         for name in filenames:
                             torch.save(coco_evaluator.coco_eval["bbox"].eval,
@@ -246,3 +261,5 @@ def main(args):
     if args.output_dir:
         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
     main(args)
+
+# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --coco_path /path/to/coco
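
Note: because the checkpoint now stores `model.state_dict()` while the model is wrapped in `DataParallel` (or `DistributedDataParallel`), every key carries a `module.` prefix and will not load directly into an unwrapped model. One common workaround, sketched under the assumption that the checkpoint format is exactly what is saved above:

```python
import torch

def load_wrapped_checkpoint(model, checkpoint_path):
    """Load a checkpoint saved from a (Distributed)DataParallel model into a bare model."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Strip the 'module.' prefix that torch.nn.DataParallel adds to every key.
    state_dict = {
        k[len('module.'):] if k.startswith('module.') else k: v
        for k, v in checkpoint['model'].items()
    }
    model.load_state_dict(state_dict)
    return model
```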
8 changes: 0 additions & 8 deletions models/detr.py
@@ -302,14 +302,6 @@ def forward(self, x):


 def build(args):
-    # the `num_classes` naming here is somewhat misleading.
-    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
-    # is the maximum id for a class in your dataset. For example,
-    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
-    # As another example, for a dataset that has a single class with id 1,
-    # you should pass `num_classes` to be 2 (max_obj_id + 1).
-    # For more details on this, check the following discussion
-    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
     num_classes = 20 if args.dataset_file != 'coco' else 91
     if args.dataset_file == "coco_panoptic":
         # for panoptic, we just add a num_classes that is large enough to hold
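
Note: the deleted comment documented that `num_classes` must be `max_obj_id + 1` for your dataset; that guidance still applies when passing the new `--num_classes` flag. The diff does not show how `build()` consumes the flag, so the wiring below is an assumption, with `resolve_num_classes` as a hypothetical helper:

```python
def resolve_num_classes(args):
    """Return the CLI override when given, else DETR's hard-coded defaults.

    Keep in mind num_classes is max_obj_id + 1, per the removed comment.
    """
    if getattr(args, 'num_classes', None) is not None:
        return args.num_classes
    return 20 if args.dataset_file != 'coco' else 91
```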