From 85ceae2b11834aa7e6f430e7f501ffb1a134f45c Mon Sep 17 00:00:00 2001 From: Ren Tianhe <48727989+rentainhe@users.noreply.github.com> Date: Fri, 2 Jun 2023 17:24:09 +0800 Subject: [PATCH] Bump version to 0.4.0 (#258) * release better dino result * refine * fix deta config * add dino better hyper * add DINO pretrained weights links * bump to v0.4.0 --- README.md | 16 +- changlog.md | 7 + projects/deformable_detr/train_net.py | 274 ++++++++++++++++++ projects/deta/configs/deta_r50_5scale_12ep.py | 1 + projects/dino/README.md | 81 +++++- .../dino_r50_4scale_12ep_better_hyper.py | 38 +++ .../dino_r50_4scale_12ep_no_frozen.py | 12 + 7 files changed, 414 insertions(+), 15 deletions(-) create mode 100644 projects/deformable_detr/train_net.py create mode 100644 projects/dino/configs/dino-resnet/dino_r50_4scale_12ep_better_hyper.py create mode 100644 projects/dino/configs/dino-resnet/dino_r50_4scale_12ep_no_frozen.py diff --git a/README.md b/README.md index 141f7db2..f6b197dd 100644 --- a/README.md +++ b/README.md @@ -60,15 +60,13 @@ The repo name detrex has several interpretations: - de-t.rex : de means 'the' in Dutch. T.rex, also called Tyrannosaurus Rex, means 'king of the tyrant lizards' and connects to our research work 'DINO', which is short for Dinosaur. ## What's New -v0.3.0 was released on 03/17/2023: -- Support new algorithms including `Anchor-DETR` and `DETA`. -- Release more than 10+ pretrained models (including the converted weights): `DETR-R50 & R101`, `DETR-R50 & R101-DC5`, `DAB-DETR-R50 & R101-DC5`, `DAB-DETR-R50-3patterns`, `Conditional-DETR-R50 & R101-DC5`, `DN-DETR-R50-DC5`, `Anchor-DETR` and the `DETA-Swin-o365-finetune` model which can achieve **`62.9AP`** on coco val. -- Support **MaskDINO** on ADE20k semantic segmentation task. -- Support `EMAHook` during training by setting `train.model_ema.enabled=True`, which can enhance the model performance. DINO with EMA can achieve **`49.4AP`** with only 12epoch training. -- Support mixed precision training by setting `train.amp.enabled=True`, which will **reduce 20% to 30% GPU memory usage**. -- Support `train.fast_dev_run=True` for **fast debugging**. -- Support **encoder-decoder checkpoint** in DINO, which may reduce **30% GPU** memory usage. -- Support a great `slurm training scripts` by @rayleizhu, please check this issue for more details [#213](https://github.com/IDEA-Research/detrex/issues/213) +v0.4.0 was released on 02/06/2023: +- Support [CO-MOT](./projects/co_mot/) aims for End-to-End Multi-Object Tracking by [Feng Yan](https://scholar.google.com/citations?user=gO4divAAAAAJ&hl=zh-CN&oi=sra). +- Release `DINO` with optimized hyper-parameters which achieves `50.0 AP` under 1x settings. +- Release pretrained DINO based on `InternImage`, `ConvNeXt-1K pretrained` backbones. +- Release `Deformable-DETR-R50` pretrained weights. +- Release `DETA` and better `H-DETR` pretrained weights: achieving `50.2 AP` and `49.1 AP` respectively. + Please see [changelog.md](./changlog.md) for details and release history. diff --git a/changlog.md b/changlog.md index a6e6f653..2a79e448 100644 --- a/changlog.md +++ b/changlog.md @@ -1,5 +1,12 @@ ## Change Log +### v0.4.0 (02/06/2023): +- Support [CO-MOT](./projects/co_mot/) aims for End-to-End Multi-Object Tracking by [Feng Yan](https://scholar.google.com/citations?user=gO4divAAAAAJ&hl=zh-CN&oi=sra). +- Release `DINO` with optimized hyper-parameters which achieves `50.0 AP` under 1x settings. +- Release pretrained DINO based on `InternImage`, `ConvNeXt-1K pretrained` backbones. +- Release `Deformable-DETR-R50` pretrained weights. +- Release `DETA` and better `H-DETR` pretrained weights: achieving `50.2 AP` and `49.1 AP` respectively. + ### v0.3.0 (17/03/2023) - Support new algorithms including `Anchor-DETR` and `DETA`. - Release more than 10+ pretrained models (including the converted weights): `DETR-R50 & R101`, `DETR-R50 & R101-DC5`, `DAB-DETR-R50 & R101-DC5`, `DAB-DETR-R50-3patterns`, `Conditional-DETR-R50 & R101-DC5`, `DN-DETR-R50-DC5`, `Anchor-DETR` and the `DETA-Swin-o365-finetune` model which can achieve **`62.9AP`** on coco val. diff --git a/projects/deformable_detr/train_net.py b/projects/deformable_detr/train_net.py new file mode 100644 index 00000000..4a69cea3 --- /dev/null +++ b/projects/deformable_detr/train_net.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Training script using the new "LazyConfig" python config files. + +This scripts reads a given python config file and runs the training or evaluation. +It can be used to train any models or dataset as long as they can be +instantiated by the recursive construction defined in the given config file. + +Besides lazy construction of models, dataloader, etc., this scripts expects a +few common configuration parameters currently defined in "configs/common/train.py". +To add more complicated training logic, you can easily add other configs +in the config file and implement a new train_net.py to handle them. +""" +import logging +import os +import sys +import time +import torch +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, instantiate +from detectron2.engine import ( + SimpleTrainer, + default_argument_parser, + default_setup, + default_writers, + hooks, + launch, +) +from detectron2.engine.defaults import create_ddp_model +from detectron2.evaluation import inference_on_dataset, print_csv_format +from detectron2.utils import comm + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + +logger = logging.getLogger("detrex") + + +def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + +class Trainer(SimpleTrainer): + """ + We've combine Simple and AMP Trainer together. + """ + + def __init__( + self, + model, + dataloader, + optimizer, + amp=False, + clip_grad_params=None, + grad_scaler=None, + ): + super().__init__(model=model, data_loader=dataloader, optimizer=optimizer) + + unsupported = "AMPTrainer does not support single-process multi-device training!" + if isinstance(model, DistributedDataParallel): + assert not (model.device_ids and len(model.device_ids) > 1), unsupported + assert not isinstance(model, DataParallel), unsupported + + if amp: + if grad_scaler is None: + from torch.cuda.amp import GradScaler + + grad_scaler = GradScaler() + self.grad_scaler = grad_scaler + + # set True to use amp training + self.amp = amp + + # gradient clip hyper-params + self.clip_grad_params = clip_grad_params + + def run_step(self): + """ + Implement the standard training logic described above. + """ + assert self.model.training, "[Trainer] model was changed to eval mode!" + assert torch.cuda.is_available(), "[Trainer] CUDA is required for AMP training!" + from torch.cuda.amp import autocast + + start = time.perf_counter() + """ + If you want to do something with the data, you can wrap the dataloader. + """ + data = next(self._data_loader_iter) + data_time = time.perf_counter() - start + + """ + If you want to do something with the losses, you can wrap the model. + """ + loss_dict = self.model(data) + with autocast(enabled=self.amp): + if isinstance(loss_dict, torch.Tensor): + losses = loss_dict + loss_dict = {"total_loss": loss_dict} + else: + losses = sum(loss_dict.values()) + + """ + If you need to accumulate gradients or do something similar, you can + wrap the optimizer with your custom `zero_grad()` method. + """ + self.optimizer.zero_grad() + + if self.amp: + self.grad_scaler.scale(losses).backward() + if self.clip_grad_params is not None: + self.grad_scaler.unscale_(self.optimizer) + self.clip_grads(self.model.parameters()) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + losses.backward() + if self.clip_grad_params is not None: + self.clip_grads(self.model.parameters()) + self.optimizer.step() + + self._write_metrics(loss_dict, data_time) + + def clip_grads(self, params): + params = list(filter(lambda p: p.requires_grad and p.grad is not None, params)) + if len(params) > 0: + return torch.nn.utils.clip_grad_norm_( + parameters=params, + **self.clip_grad_params, + ) + + +def do_test(cfg, model): + if "evaluator" in cfg.dataloader: + ret = inference_on_dataset( + model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator) + ) + print_csv_format(ret) + return ret + + +def do_train(args, cfg): + """ + Args: + cfg: an object with the following attributes: + model: instantiate to a module + dataloader.{train,test}: instantiate to dataloaders + dataloader.evaluator: instantiate to evaluator for test set + optimizer: instantaite to an optimizer + lr_multiplier: instantiate to a fvcore scheduler + train: other misc config defined in `configs/common/train.py`, including: + output_dir (str) + init_checkpoint (str) + amp.enabled (bool) + max_iter (int) + eval_period, log_period (int) + device (str) + checkpointer (dict) + ddp (dict) + """ + model = instantiate(cfg.model) + logger = logging.getLogger("detectron2") + logger.info("Model:\n{}".format(model)) + model.to(cfg.train.device) + + # this is an hack of train_net + param_dicts = [ + { + "params": [ + p + for n, p in model.named_parameters() + if not match_name_keywords(n, ["backbone"]) + and not match_name_keywords(n, ["reference_points", "sampling_offsets"]) + and p.requires_grad + ], + "lr": 2e-4, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if match_name_keywords(n, ["backbone"]) and p.requires_grad + ], + "lr": 2e-5, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if match_name_keywords(n, ["reference_points", "sampling_offsets"]) + and p.requires_grad + ], + "lr": 2e-5, + }, + ] + optim = torch.optim.AdamW(param_dicts, 2e-4, weight_decay=1e-4) + + train_loader = instantiate(cfg.dataloader.train) + + model = create_ddp_model(model, **cfg.train.ddp) + + trainer = Trainer( + model=model, + dataloader=train_loader, + optimizer=optim, + amp=cfg.train.amp.enabled, + clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None, + ) + + checkpointer = DetectionCheckpointer( + model, + cfg.train.output_dir, + trainer=trainer, + ) + + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), + hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) + if comm.is_main_process() + else None, + hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), + hooks.PeriodicWriter( + default_writers(cfg.train.output_dir, cfg.train.max_iter), + period=cfg.train.log_period, + ) + if comm.is_main_process() + else None, + ] + ) + + checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) + if args.resume and checkpointer.has_checkpoint(): + # The checkpoint stores the training iteration that just finished, thus we start + # at the next iteration + start_iter = trainer.iter + 1 + else: + start_iter = 0 + trainer.train(start_iter, cfg.train.max_iter) + + +def main(args): + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + + if args.eval_only: + model = instantiate(cfg.model) + model.to(cfg.train.device) + model = create_ddp_model(model) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + print(do_test(cfg, model)) + else: + do_train(args, cfg) + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/projects/deta/configs/deta_r50_5scale_12ep.py b/projects/deta/configs/deta_r50_5scale_12ep.py index 327db20d..b7b5296e 100644 --- a/projects/deta/configs/deta_r50_5scale_12ep.py +++ b/projects/deta/configs/deta_r50_5scale_12ep.py @@ -4,6 +4,7 @@ # using the default optimizer and dataloader dataloader = get_config("common/data/coco_detr.py").dataloader +optimizer = get_config("common/optim.py").AdamW train = get_config("common/train.py").train # modify training config diff --git a/projects/dino/README.md b/projects/dino/README.md index 15f80f3b..006c5bff 100644 --- a/projects/dino/README.md +++ b/projects/dino/README.md @@ -8,6 +8,19 @@ Hao Zhang, Feng Li, Shilong Liu, Lei Zhang, Hang Su, Jun Zhu, Lionel M. Ni, Heun
+## Table of Contents +- [DINO with modified training engine](#dino-with-modified-training-engine) +- [Main Results with Pretrained Models](#main-results-with-pretrained-models) + - [DINO with ResNet Backbone](#pretrained-dino-with-resnet-backbone) + - [DINO with Swin-Transformer Backbone](#pretrained-dino-with-swin-transformer-backbone) + - [DINO with ViT Backbone](#pretrained-dino-with-vit-backbone) + - [DINO with ConvNeXt Backbone](#pretrained-dino-with-convnext-backbone) + - [DINO with FocalNet Backbone](#pretrained-dino-with-focalnet-backbone) + - [DINO with InternImage Backbone](#pretrained-dino-with-internimage-backbone) +- [Training DINO](#training) +- [Evaluate DINO](#evaluation) +- [Citation](#citing-dino) + ## DINO with modified training engine We've provide a hacked [train_net.py](./train_net.py) which aligns the optimizer params with Deformable-DETR that can achieve a better result on DINO models. @@ -31,6 +44,22 @@ We've provide a hacked [train_net.py](./train_net.py) which aligns the optimizer 49.4 model +DINO-R50-4scale (hacked trainer) +R-50 +IN1k +12 +100 +49.8 + model + + DINO-R50-4scale (hacked trainer) +R-50 +IN1k +12 +300 +50.0 + model + - Training model with hacked trainer @@ -40,7 +69,7 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8 ## Main Results with Pretrained Models -**Pretrained DINO with ResNet Backbone** +##### Pretrained DINO with ResNet Backbone @@ -116,7 +145,7 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8
-**Pretrained DINO with Swin-Transformer Backbone** +##### Pretrained DINO with Swin-Transformer Backbone @@ -151,6 +180,14 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8 + + + + + + + + @@ -206,7 +243,7 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8
Name Backbone100 53.0 model
DINO-Swin-S-224-4scaleSwin-Small-224IN22K to IN1K1210054.5 model
DINO-Swin-B-384-4scale
-**Pretrained DINO with FocalNet Backbone** +##### Pretrained DINO with FocalNet Backbone @@ -250,7 +287,7 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8
Name Backbone
-**Pretrained DINO with ViT Backbone** +##### Pretrained DINO with ViT Backbone @@ -293,7 +330,7 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8
Name Backbone
-**Pretrained DINO with ConvNeXt Backbone** +##### Pretrained DINO with ConvNeXt Backbone @@ -304,6 +341,14 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8 + + + + + + + + @@ -312,6 +357,14 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8 + + + + + + + + @@ -320,6 +373,14 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8 + + + + + + + + @@ -328,6 +389,14 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8 + + + + + + + + @@ -337,7 +406,7 @@ python projects/dino/train_net.py --config-file /path/to/config.py --num-gpus 8
Name Backbonedownload
DINO-ConvNeXt-Tiny-384-4scale ConvNeXt-Tiny-384IN1K1210051.4 model
DINO-ConvNeXt-Tiny-384-4scaleConvNeXt-Tiny-384 IN22k 12 100
DINO-ConvNeXt-Small-384-4scale ConvNeXt-Small-384IN1K1210052.0 model
DINO-ConvNeXt-Small-384-4scaleConvNeXt-Small-384 IN22k 12 100
DINO-ConvNeXt-Base-384-4scale ConvNeXt-Base-384IN1K1210052.6 model
DINO-ConvNeXt-Base-384-4scaleConvNeXt-Base-384 IN22k 12 100
DINO-ConvNeXt-Large-384-4scale ConvNeXt-Large-384IN1K1210053.4 model
DINO-ConvNeXt-Large-384-4scaleConvNeXt-Large-384 IN22k 12 100
-**Pretrained DINO with InternImage Backbone** +##### Pretrained DINO with InternImage Backbone diff --git a/projects/dino/configs/dino-resnet/dino_r50_4scale_12ep_better_hyper.py b/projects/dino/configs/dino-resnet/dino_r50_4scale_12ep_better_hyper.py new file mode 100644 index 00000000..8b8ad85a --- /dev/null +++ b/projects/dino/configs/dino-resnet/dino_r50_4scale_12ep_better_hyper.py @@ -0,0 +1,38 @@ +import copy +from .dino_r50_4scale_12ep import ( + train, + dataloader, + optimizer, + lr_multiplier, + model, +) + +# no frozen backbone get better results +model.backbone.freeze_at = -1 + +# more dn queries, set 300 here +model.dn_number = 300 + +# use 2.0 for class weight +model.criterion.weight_dict = { + "loss_class": 2.0, + "loss_bbox": 5.0, + "loss_giou": 2.0, + "loss_class_dn": 1, + "loss_bbox_dn": 5.0, + "loss_giou_dn": 2.0, +} + +# set aux loss weight dict +base_weight_dict = copy.deepcopy(model.criterion.weight_dict) +if model.aux_loss: + weight_dict = model.criterion.weight_dict + aux_weight_dict = {} + aux_weight_dict.update({k + "_enc": v for k, v in base_weight_dict.items()}) + for i in range(model.transformer.decoder.num_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in base_weight_dict.items()}) + weight_dict.update(aux_weight_dict) + model.criterion.weight_dict = weight_dict + +# output dir +train.output_dir = "./output/dino_r50_4scale_12ep_better_hyper" \ No newline at end of file diff --git a/projects/dino/configs/dino-resnet/dino_r50_4scale_12ep_no_frozen.py b/projects/dino/configs/dino-resnet/dino_r50_4scale_12ep_no_frozen.py new file mode 100644 index 00000000..0931ee92 --- /dev/null +++ b/projects/dino/configs/dino-resnet/dino_r50_4scale_12ep_no_frozen.py @@ -0,0 +1,12 @@ +from .dino_r50_4scale_12ep import ( + train, + dataloader, + optimizer, + lr_multiplier, + model, +) + +# no frozen backbone get better results +model.backbone.freeze_at = -1 + +train.output_dir = "./output/dino_r50_4scale_12ep_no_frozen_backbone" \ No newline at end of file
Name Backbone