diff --git a/README.md b/README.md index dc5c6cdeaa1..78d56fc1293 100644 --- a/README.md +++ b/README.md @@ -86,13 +86,10 @@ https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351 ## What's new -🌟 v1.0.2 was released in 15/08/2023 +🌟 v1.1.0 was released in 12/10/2023 -Support [MFF](./configs/mff/) self-supervised algorithm and enhance the codebase. More details can be found in the [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html). - -🌟 v1.0.1 was released in 28/07/2023 - -Fix some bugs and enhance the codebase. Please refer to [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html) for more details. +- Support Mini-GPT4 training and provide a Chinese model (based on Baichuan-7B) +- Support zero-shot classification based on CLIP. 🌟 v1.0.0 was released in 04/07/2023 diff --git a/README_zh-CN.md b/README_zh-CN.md index 801d3183982..06daeb1ce97 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -84,13 +84,10 @@ https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351 ## 更新日志 -🌟 2023/8/15 发布了 v1.0.2 版本 +🌟 2023/10/12 发布了 v1.1.0 版本 -支持了 [MFF](./configs/mff/) 自监督算法,增强算法库功能。细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。 - -🌟 2023/7/28 发布了 v1.0.1 版本 - -修复部分 bug 和增强算法库功能。细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。 +- 支持 Mini-GPT4 训练并提供一个基于 Baichuan-7B 的中文模型 +- 支持基于 CLIP 的零样本分类。 🌟 2023/7/4 发布了 v1.0.0 版本 diff --git a/configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py b/configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py new file mode 100644 index 00000000000..dd684a50a31 --- /dev/null +++ b/configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py @@ -0,0 +1,68 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, +) + +test_pipeline = [ + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='CIFAR100', + data_root='data/cifar100', + split='test', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, + ), + projection=dict(type='CLIPProjection', in_channels=768, out_channels=512), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + text_prototype='cifar100', + text_prompt='openai_cifar100', + context_length=77, +) diff --git a/configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py b/configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py new file mode 100644 index 00000000000..80c4fde82f5 --- /dev/null +++ b/configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py @@ -0,0 +1,69 @@ 
+_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=True, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='ImageNet', + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, + ), + projection=dict(type='CLIPProjection', in_channels=768, out_channels=512), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + text_prototype='imagenet', + text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub + context_length=77, +) diff --git a/configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py b/configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py new file mode 100644 index 00000000000..a6dd7c11412 --- /dev/null +++ b/configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py @@ -0,0 +1,68 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, +) + +test_pipeline = [ + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='CIFAR100', + data_root='data/cifar100', + split='test', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='large', + img_size=224, + patch_size=14, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, + ), + projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768), + text_backbone=dict( + type='CLIPTransformer', + width=768, + layers=12, + heads=12, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-large-patch14', + use_fast=False), + vocab_size=49408, + transformer_width=768, + proj_dim=768, + text_prototype='cifar100', + text_prompt='openai_cifar100', + context_length=77, +) diff --git a/configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py b/configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py new file mode 100644 index 
00000000000..10500017a93 --- /dev/null +++ b/configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py @@ -0,0 +1,69 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=True, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='ImageNet', + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='large', + img_size=224, + patch_size=14, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, + ), + projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768), + text_backbone=dict( + type='CLIPTransformer', + width=768, + layers=12, + heads=12, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-large-patch14', + use_fast=False), + vocab_size=49408, + transformer_width=768, + proj_dim=768, + text_prototype='imagenet', + text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub + context_length=77, +) diff --git a/docker/serve/Dockerfile b/docker/serve/Dockerfile index bff871b722d..86df2926251 100644 --- a/docker/serve/Dockerfile +++ b/docker/serve/Dockerfile @@ -1,9 +1,9 @@ -ARG PYTORCH="1.12.1" -ARG CUDA="11.3" +ARG PYTORCH="2.0.1" +ARG CUDA="11.7" ARG CUDNN="8" FROM pytorch/torchserve:latest-gpu -ARG MMPRE="1.0.2" +ARG MMPRE="1.1.0" ENV PYTHONUNBUFFERED TRUE diff --git a/docs/en/notes/changelog.md b/docs/en/notes/changelog.md index f84d691aae7..7a8ab6808ad 100644 --- a/docs/en/notes/changelog.md +++ b/docs/en/notes/changelog.md @@ -1,5 +1,27 @@ # Changelog (MMPreTrain) +## v1.1.0(12/10/2023) + +### New Features + +- [Feature] Implement of Zero-Shot CLIP Classifier ([#1737](https://github.com/open-mmlab/mmpretrain/pull/1737)) +- [Feature] Add minigpt4 gradio demo and training script. ([#1758](https://github.com/open-mmlab/mmpretrain/pull/1758)) + +### Improvements + +- [Config] New Version of config Adapting MobileNet Algorithm ([#1774](https://github.com/open-mmlab/mmpretrain/pull/1774)) +- [Config] Support DINO self-supervised learning in project ([#1756](https://github.com/open-mmlab/mmpretrain/pull/1756)) +- [Config] New Version of config Adapting Swin Transformer Algorithm ([#1780](https://github.com/open-mmlab/mmpretrain/pull/1780)) +- [Enhance] Add iTPN Supports for Non-three channel image ([#1735](https://github.com/open-mmlab/mmpretrain/pull/1735)) +- [Docs] Update dataset download script from opendatalab to openXlab ([#1765](https://github.com/open-mmlab/mmpretrain/pull/1765)) +- [Docs] Update COCO-Retrieval dataset docs. ([#1806](https://github.com/open-mmlab/mmpretrain/pull/1806)) + +### Bug Fix + +- Update `train.py` to be compatible with the new config. +- Update OFA module to be compatible with the latest HuggingFace. +- Fix pipeline bug in ImageRetrievalInferencer.
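The zero-shot CLIP classifier added above is evaluation-only (its configs define no training loop). The following is a minimal usage sketch, not part of this diff, assuming the standard MMEngine config/Runner workflow; the checkpoint path is a hypothetical placeholder for CLIP ViT-B/16 weights converted to the MMPreTrain format.

```python
# Hedged sketch: evaluate one of the new zero-shot CLIP configs.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py')
cfg.load_from = 'checkpoints/clip-vit-base-p16.pth'  # hypothetical converted checkpoint
cfg.work_dir = 'work_dirs/clip_zeroshot_in1k'

# The config sets train_cfg/val_cfg and train_dataloader to None, so only the
# test loop runs; text prototypes for the ImageNet classes are built from the
# `openai_imagenet_sub` prompts inside the model's predict step.
runner = Runner.from_cfg(cfg)
metrics = runner.test()  # top-1 / top-5 accuracy on the ImageNet val split
print(metrics)
```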
+ ## v1.0.2(15/08/2023) ### New Features diff --git a/docs/en/notes/faq.md b/docs/en/notes/faq.md index 9f78a04846c..dd0591142a3 100644 --- a/docs/en/notes/faq.md +++ b/docs/en/notes/faq.md @@ -16,7 +16,7 @@ and make sure you fill in all required information in the template. | MMPretrain version | MMEngine version | MMCV version | | :----------------: | :---------------: | :--------------: | - | 1.0.2 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 | + | 1.1.0 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 | | 1.0.0 | mmengine >= 0.8.0 | mmcv >= 2.0.0 | | 1.0.0rc8 | mmengine >= 0.7.1 | mmcv >= 2.0.0rc4 | | 1.0.0rc7 | mmengine >= 0.5.0 | mmcv >= 2.0.0rc4 | diff --git a/docs/zh_CN/notes/faq.md b/docs/zh_CN/notes/faq.md index efd2ff5e757..23ec5f50cb0 100644 --- a/docs/zh_CN/notes/faq.md +++ b/docs/zh_CN/notes/faq.md @@ -13,7 +13,7 @@ | MMPretrain 版本 | MMEngine 版本 | MMCV 版本 | | :-------------: | :---------------: | :--------------: | - | 1.0.2 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 | + | 1.1.0 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 | | 1.0.0 | mmengine >= 0.8.0 | mmcv >= 2.0.0 | | 1.0.0rc8 | mmengine >= 0.7.1 | mmcv >= 2.0.0rc4 | | 1.0.0rc7 | mmengine >= 0.5.0 | mmcv >= 2.0.0rc4 | diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py index 0b0f573fe57..69c585bd26f 100644 --- a/mmpretrain/__init__.py +++ b/mmpretrain/__init__.py @@ -7,7 +7,7 @@ from .version import __version__ mmcv_minimum_version = '2.0.0' -mmcv_maximum_version = '2.1.0' +mmcv_maximum_version = '2.2.0' mmcv_version = digit_version(mmcv.__version__) mmengine_minimum_version = '0.8.3' diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py index deae1de7975..27919b20f58 100644 --- a/mmpretrain/apis/image_retrieval.py +++ b/mmpretrain/apis/image_retrieval.py @@ -108,6 +108,7 @@ def build_dataloader(dataset): # A config of dataset from mmpretrain.registry import DATASETS test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + prototype.setdefault('pipeline', test_pipeline) dataset = DATASETS.build(prototype) dataloader = build_dataloader(dataset) elif isinstance(prototype, DataLoader): diff --git a/mmpretrain/configs/_base_/datasets/cifar10_bs16.py b/mmpretrain/configs/_base_/datasets/cifar10_bs16.py new file mode 100644 index 00000000000..3737dbee9a6 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/cifar10_bs16.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import CIFAR10, PackInputs, RandomCrop, RandomFlip +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = CIFAR10 +data_preprocessor = dict( + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type=RandomCrop, crop_size=32, padding=4), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/cifar10', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/cifar10/', + split='test', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py b/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py new file mode 100644 index 00000000000..cf0aa629d72 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (AutoAugment, CenterCrop, ImageNet, + LoadImageFromFile, PackInputs, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in bgr_mean])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git 
a/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py new file mode 100644 index 00000000000..f911bc20ff6 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py b/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py new file mode 100644 index 00000000000..17dbb9fdd88 --- /dev/null +++ b/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, MobileNetV2) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=MobileNetV2, widen_factor=1.0), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1280, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5), + )) diff --git a/mmpretrain/configs/_base_/models/mobilenet_v3_small.py b/mmpretrain/configs/_base_/models/mobilenet_v3_small.py new file mode 100644 index 00000000000..83edab59206 --- /dev/null +++ b/mmpretrain/configs/_base_/models/mobilenet_v3_small.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.model.weight_init import NormalInit +from torch.nn.modules.activation import Hardswish + +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, MobileNetV3, + StackedLinearClsHead) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=MobileNetV3, arch='small'), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=StackedLinearClsHead, + num_classes=1000, + in_channels=576, + mid_channels=[1024], + dropout_rate=0.2, + act_cfg=dict(type=Hardswish), + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + init_cfg=dict( + type=NormalInit, layer='Linear', mean=0., std=0.01, bias=0.), + topk=(1, 5))) diff --git a/mmpretrain/configs/_base_/schedules/cifar10_bs128.py b/mmpretrain/configs/_base_/schedules/cifar10_bs128.py new file mode 100644 index 00000000000..8ab749e8b64 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/cifar10_bs128.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import MultiStepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) +# learning policy +param_scheduler = dict( + type=MultiStepLR, by_epoch=True, milestones=[100, 150], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=128) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py b/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py new file mode 100644 index 00000000000..9d245ebb9c3 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import StepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.045, momentum=0.9, weight_decay=0.00004)) + +# learning policy +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=1, gamma=0.98) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=256) diff --git a/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py b/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py new file mode 100644 index 00000000000..79eec635501 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.mobilenet_v2_1x import * + from .._base_.schedules.imagenet_bs256_epochstep import * diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py new file mode 100644 index 00000000000..3f1bee1c132 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. + +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict(arch='large'), + head=dict(in_channels=960, mid_channels=[1280]), + )) +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py new file mode 100644 index 00000000000..50e1ffc6709 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict( + arch='small_050', + norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)), + head=dict(in_channels=288), + )) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline))) + +val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline))) +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py new file mode 100644 index 00000000000..c8c640cd8a0 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict( + arch='small_075', + norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)), + head=dict(in_channels=432), + )) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline))) +val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline))) +test_dataloader = val_dataloader + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py new file mode 100644 index 00000000000..0c220a01d09 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.optim import RMSprop + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. 
+# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py new file mode 100644 index 00000000000..0f91ee38243 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.cifar10_bs16 import * + from .._base_.schedules.cifar10_bs128 import * + from .._base_.default_runtime import * + +from mmengine.optim import MultiStepLR + +# model settings +model.merge( + dict( + head=dict( + _delete_=True, + type=StackedLinearClsHead, + num_classes=10, + in_channels=576, + mid_channels=[1280], + act_cfg=dict(type=Hardswish), + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5)))) +# schedule settings +param_scheduler.merge( + dict( + type=MultiStepLR, + by_epoch=True, + milestones=[120, 170], + gamma=0.1, + )) + +train_cfg.merge(dict(by_epoch=True, max_epochs=200)) diff --git a/mmpretrain/datasets/categories.py b/mmpretrain/datasets/categories.py index 011ee5c1609..9e75f7953b8 100644 --- a/mmpretrain/datasets/categories.py +++ b/mmpretrain/datasets/categories.py @@ -1438,3 +1438,224 @@ '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒', '桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼', '柳树', '狼', '女人', '蠕虫') + +IMAGENET_SIMPLE_CATEGORIES = ( + 'tench', 'goldfish', 'great white shark', 'tiger shark', + 'hammerhead shark', 'electric ray', 'stingray', 'rooster', 'hen', + 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', + 'indigo bunting', 'American robin', 'bulbul', 'jay', 'magpie', 'chickadee', + 'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture', + 'great grey owl', 'fire salamander', 'smooth newt', 'newt', + 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog', + 'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', + 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'green iguana', + 'Carolina anole', 'desert grassland whiptail lizard', 'agama', + 'frilled-necked lizard', 'alligator lizard', 'Gila monster', + 'European green lizard', 'chameleon', 'Komodo dragon', 'Nile crocodile', + 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', + 'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', + 'garter snake', 'water snake', 'vine snake', 'night snake', + 'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba', + 'sea snake', 'Saharan horned viper', 'eastern diamondback rattlesnake', + 'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion', + 'yellow garden spider', 'barn spider', 'European garden spider', + 'southern black widow', 'tarantula', 'wolf spider', 'tick', 'centipede', + 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl', + 'quail', 'partridge', 'african grey parrot', 'macaw', + 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', 'hornbill', + 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', + 'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala', + 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm', + 'nematode', 'conch', 'snail', 'slug', 'sea 
slug', 'chiton', + 'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab', + 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', + 'hermit crab', 'isopod', 'white stork', 'black stork', 'spoonbill', + 'flamingo', 'little blue heron', 'great egret', 'bittern bird', + 'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard', + 'ruddy turnstone', 'dunlin', 'common redshank', 'dowitcher', + 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale', + 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', + 'Maltese', 'Pekingese', 'Shih Tzu', 'King Charles Spaniel', 'Papillon', + 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', + 'Beagle', 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', + 'Treeing Walker Coonhound', 'English foxhound', 'Redbone Coonhound', + 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', + 'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki', + 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier', + 'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', + 'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier', + 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', + 'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier', + 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier', + 'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', + 'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier', + 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', + 'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever', + 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever', + 'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', + 'English Setter', 'Irish Setter', 'Gordon Setter', 'Brittany dog', + 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel', + 'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', + 'Schipperke', 'Groenendael dog', 'Malinois', 'Briard', 'Australian Kelpie', + 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', + 'Border Collie', 'Bouvier des Flandres dog', 'Rottweiler', + 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher', + 'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', + 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff', + 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. 
Bernard', 'husky', + 'Alaskan Malamute', 'Siberian Husky', 'Dalmatian', 'Affenpinscher', + 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog', 'Great Pyrenees dog', + 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon', + 'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle', + 'Miniature Poodle', 'Standard Poodle', + 'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', + 'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote', 'dingo', + 'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', + 'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat', + 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', + 'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear', + 'polar bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle', + 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', + 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant', + 'grasshopper', 'cricket insect', 'stick insect', 'cockroach', + 'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly', + 'damselfly', 'red admiral butterfly', 'ringlet butterfly', + 'monarch butterfly', 'small white butterfly', 'sulphur butterfly', + 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', + 'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', + 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel horse', + 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', + 'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep', + 'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle', + 'arabian camel', 'llama', 'weasel', 'mink', 'European polecat', + 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', + 'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon', + 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur', + 'black-and-white colobus', 'proboscis monkey', 'marmoset', + 'white-headed capuchin', 'howler monkey', 'titi monkey', + "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', + 'indri', 'Asian elephant', 'African bush elephant', 'red panda', + 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish', + 'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', + 'abaya', 'academic gown', 'accordion', 'acoustic guitar', + 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', + 'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can', + 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon', + 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell', + 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow', + 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', + 'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', + 'military hat (bearskin or shako)', 'beer bottle', 'beer glass', + 'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', + 'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', 'bolo tie', + 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow', + 'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate', + 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', + 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', + 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit', + 'cardboard box / carton', 'car wheel', 'automated teller machine', + 'cassette', 'cassette player', 'castle', 
'catamaran', 'CD player', 'cello', + 'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw', + 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet', + 'Christmas stocking', 'church', 'movie theater', 'cleaver', + 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug', + 'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard', + 'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet', + 'cowboy boot', 'cowboy hat', 'cradle', 'construction crane', + 'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', + 'crutch', 'cuirass', 'dam', 'desk', 'desktop computer', + 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', + 'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock', + 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick', + 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', + 'electric locomotive', 'entertainment center', 'envelope', + 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', + 'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute', + 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen', + 'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', + 'garbage truck', 'gas mask or respirator', 'gas pump', 'goblet', 'go-kart', + 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano', + 'greenhouse', 'radiator grille', 'grocery store', 'guillotine', + 'hair clip', 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer', + 'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica', + 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater', + 'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar', + 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', + 'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw', + 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', + 'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library', + 'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick', + 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass', + 'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights', + 'one-piece bathing suit', 'manhole cover', 'maraca', 'marimba', 'mask', + 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet', + 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can', + 'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', + 'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped', + 'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa', + 'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van', + 'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', + 'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer', + 'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart', + 'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel', + 'padlock', 'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel', + 'parachute', 'parallel bars', 'park bench', 'parking meter', + 'railroad car', 'patio', 'payphone', 'pedestal', 'pencil case', + 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum', + 'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', + 'pill bottle', 'pillow', 'ping-pong ball', 'pinwheel', 'pirate ship', + 'drink pitcher', 'block plane', 'planetarium', 'plastic 
bag', 'plate rack', + 'farm plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', + 'pool table', 'soda bottle', 'plant pot', "potter's wheel", 'power drill', + 'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck', + 'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', + 'radiator', 'radio', 'radio telescope', 'rain barrel', + 'recreational vehicle', 'fishing casting reel', 'reflex camera', + 'refrigerator', 'remote control', 'restaurant', 'revolver', 'rifle', + 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', + 'ruler measuring stick', 'sneaker', 'safe', 'safety pin', 'salt shaker', + 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale', + 'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw', + 'screwdriver', 'seat belt', 'sewing machine', 'shield', 'shoe store', + 'shoji screen / room divider', 'shopping basket', 'shopping cart', + 'shovel', 'shower cap', 'shower curtain', 'ski', 'balaclava ski mask', + 'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', + 'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock', + 'solar thermal collector', 'sombrero', 'soup bowl', 'keyboard space bar', + 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', + 'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive', + 'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall', + 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', + 'submarine', 'suit', 'sundial', 'sunglasses', 'sunglasses', 'sunscreen', + 'suspension bridge', 'mop', 'sweatshirt', 'swim trunks / shorts', 'swing', + 'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player', + 'teapot', 'teddy bear', 'television', 'tennis ball', 'thatched roof', + 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', + 'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole', + 'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray', + 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', + 'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard', + 'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase', + 'vaulted or arched ceiling', 'velvet fabric', 'vending machine', + 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', + 'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine', + 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', + 'hair wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', + 'airplane wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence', + 'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword', + 'traffic or street sign', 'traffic light', 'dust jacket', 'menu', 'plate', + 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream', 'popsicle', + 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', + 'mashed potatoes', 'cabbage', 'broccoli', 'cauliflower', 'zucchini', + 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', + 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple', + 'strawberry', 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit', + 'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara', + 'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie', 'burrito', + 'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble', 'cliff', + 'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach', + 'valley', 'volcano', 
'baseball player', 'bridegroom', 'scuba diver', + 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip', + 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', + 'stinkhorn mushroom', 'earth star fungus', 'hen of the woods mushroom', + 'bolete', 'corn cob', 'toilet paper') diff --git a/mmpretrain/datasets/coco_retrieval.py b/mmpretrain/datasets/coco_retrieval.py index 60d1586ad86..be8a0bcb864 100644 --- a/mmpretrain/datasets/coco_retrieval.py +++ b/mmpretrain/datasets/coco_retrieval.py @@ -1,18 +1,45 @@ # Copyright (c) OpenMMLab. All rights reserved. import json +import os.path as osp from collections import OrderedDict -from typing import List +from os import PathLike +from typing import List, Sequence, Union from mmengine import get_file_backend -from mmpretrain.registry import DATASETS +from mmpretrain.registry import DATASETS, TRANSFORMS from .base_dataset import BaseDataset + +def expanduser(data_prefix): + if isinstance(data_prefix, (str, PathLike)): + return osp.expanduser(data_prefix) + else: + return data_prefix + + @DATASETS.register_module() class COCORetrieval(BaseDataset): """COCO Retrieval dataset. + COCO (Common Objects in Context): The COCO dataset contains more than + 330K images, each of which has approximately 5 descriptive annotations. + This dataset was released in collaboration between Microsoft and Carnegie + Mellon University. + + COCO_2014 dataset directory: :: + + COCO_2014 + ├── val2014 + ├── train2014 + ├── annotations + ├── instances_train2014.json + ├── instances_val2014.json + ├── person_keypoints_train2014.json + ├── person_keypoints_val2014.json + ├── captions_train2014.json + ├── captions_val2014.json + Args: ann_file (str): Annotation file path. test_mode (bool): Whether dataset is used for evaluation. This will @@ -23,8 +50,52 @@ class COCORetrieval(BaseDataset): data_prefix (str | dict): Prefix for training data. Defaults to ''. pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. **kwargs: Other keyword arguments in :class:`BaseDataset`. 
+ + Examples: + >>> from mmpretrain.datasets import COCORetrieval + >>> train_dataset=COCORetrieval(data_root='coco2014/') + >>> train_dataset + Dataset COCORetrieval + Number of samples: 414113 + Annotation file: /coco2014/annotations/captions_train2014.json + Prefix of images: /coco2014/ + >>> from mmpretrain.datasets import COCORetrieval + >>> val_dataset = COCORetrieval(data_root='coco2014/') + >>> val_dataset + Dataset COCORetrieval + Number of samples: 202654 + Annotation file: /coco2014/annotations/captions_val2014.json + Prefix of images: /coco2014/ """ + def __init__(self, + ann_file: str, + test_mode: bool = False, + data_prefix: Union[str, dict] = '', + data_root: str = '', + pipeline: Sequence = (), + **kwargs): + + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + + ann_file = expanduser(ann_file) + transforms = [] + for transform in pipeline: + if isinstance(transform, dict): + transforms.append(TRANSFORMS.build(transform)) + else: + transforms.append(transform) + + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + pipeline=transforms, + ann_file=ann_file, + **kwargs, + ) + def load_data_list(self) -> List[dict]: """Load data list.""" # get file backend diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py index 072c0f84f72..e68504c6167 100644 --- a/mmpretrain/models/multimodal/__init__.py +++ b/mmpretrain/models/multimodal/__init__.py @@ -5,11 +5,13 @@ from .blip import * # noqa: F401,F403 from .blip2 import * # noqa: F401,F403 from .chinese_clip import * # noqa: F401, F403 + from .clip import * # noqa: F401, F403 from .flamingo import * # noqa: F401, F403 from .llava import * # noqa: F401, F403 from .minigpt4 import * # noqa: F401, F403 from .ofa import * # noqa: F401, F403 from .otter import * # noqa: F401, F403 + from .ram import * # noqa: F401, F403 else: from mmpretrain.registry import MODELS from mmpretrain.utils.dependency import register_multimodal_placeholder @@ -17,5 +19,6 @@ register_multimodal_placeholder([ 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', - 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter' + 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', + 'CLIPZeroShot', 'RAM', 'RAMNormal', 'RAMOpenset' ], MODELS) diff --git a/mmpretrain/models/multimodal/clip/__init__.py b/mmpretrain/models/multimodal/clip/__init__.py new file mode 100644 index 00000000000..f7a117ea7ca --- /dev/null +++ b/mmpretrain/models/multimodal/clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..clip.clip import CLIP, CLIPZeroShot +from ..clip.clip_transformer import CLIPProjection, CLIPTransformer + +__all__ = ['CLIP', 'CLIPZeroShot', 'CLIPTransformer', 'CLIPProjection'] diff --git a/mmpretrain/models/multimodal/clip/clip.py b/mmpretrain/models/multimodal/clip/clip.py new file mode 100644 index 00000000000..b509a63b3be --- /dev/null +++ b/mmpretrain/models/multimodal/clip/clip.py @@ -0,0 +1,364 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import abstractmethod +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES, + IMAGENET_SIMPLE_CATEGORIES) +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT, + OPENAI_IMAGENET_PROMPT_SUB) + +CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES] +PROTOTYPE_MAP = { + 'imagenet': IMAGENET_SIMPLE_CATEGORIES, + 'cifar100': CIFAR100_CATEGORIES, +} +PROMPT_MAP = { + 'openai_imagenet': OPENAI_IMAGENET_PROMPT, + 'openai_cifar100': OPENAI_CIFAR100_PROMPT, + 'vanilla': [lambda c: f'a photo of a {c}'], + 'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB +} + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class CLIP(BaseModel): + """The implementation of `CLIP `_. + + Args: + vision_backbone (dict): Config dict for vision backbone. + text_backbone (dict): Config dict for text backbone. + tokenizer (dict): Config dict for text tokenizer. + proj_dim (int): Projection dimension for similarity computation. + text_prototype (str): Text prototype, which can be a key in + `PROTOTYPE_MAP` or list of text. + text_prompt (str): The prompt for text prototype. + Defaults to 'vanilla',which refers to "a photo of {cls}". + context_length (int): The context length to use. Defaults to 77. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The config to control the initialization. + Defaults to None. 
+ """ + + def __init__(self, + vision_backbone: dict, + projection: dict, + text_backbone: dict, + tokenizer: dict, + vocab_size: int, + transformer_width: int, + proj_dim: int, + context_length: int = 77, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.context_length = context_length + + # build the vision transformer + self.visual = MODELS.build(vision_backbone) + + # build the visual projection + self.visual_proj = MODELS.build(projection) + + # build attn_mask for casual-attn + text_backbone['attn_mask'] = self.build_attention_mask() + + # build the text transformer + self.transformer = MODELS.build(text_backbone) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, proj_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + self.tokenizer = TOKENIZER.build(tokenizer) + + self.tokenizer.vocab = self.tokenizer.get_vocab( + ) # CLIPTokenizer has no attribute named 'vocab', so manually + + def initialize_parameters(self) -> None: + """Initialize the parameters. + + The pretrained weight will override the initialized parameters by this + function. + """ + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, + # with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. 
+ """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: + """The function to extract image latent features.""" + return self.visual_proj(self.visual(images))[0] + + def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: + """The function to extract text latent features.""" + x = self.token_embedding(texts) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x)[0] + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding + # (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + texts.argmax(dim=-1)] @ self.text_projection + + return x + + def extract_feat( + self, images: torch.Tensor, + texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """The function to extract image and text latent features, the input + image or text can not both be None.""" + + assert images is not None or texts is not None, \ + 'text and image cannot both be None!' + if images is None: + return self.extract_text_feat(texts) + elif texts is None: + return self.extract_image_feat(images) + + image_features = self.extract_image_feat(images) + text_features = self.extract_text_feat(texts) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + return image_features, text_features + + def compute_similarity(self, images, texts): + """Extract images and texts features and compute cosine similarity.""" + image_features, text_features = self.extract_feat( + images=images, texts=texts) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape (N, N) + return logits_per_image, logits_per_text + + @abstractmethod + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + raise NotImplementedError + + def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Args: + texts (Union[str, List[str]]): An input string or a list of input + strings to tokenize + context_length (int): The context length to use. Defaults to 52. + + Returns: + torch.Tensor: Resulting tokens. 
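A small illustration of why ``extract_text_feat`` above can use ``texts.argmax(dim=-1)`` to find the end-of-text position in the zero-padded output of ``tokenize``: ``<|endoftext|>`` has the largest id in the standard CLIP BPE vocabulary (49407, with ``<|startoftext|>`` at 49406) while padding is 0. The other token ids below are made up:

import torch

sot, eot = 49406, 49407
tokens = torch.tensor([[sot, 320, 1125, 539, 320, 2368, eot, 0, 0, 0]])
eot_pos = tokens.argmax(dim=-1)  # index of the <|endoftext|> token
print(eot_pos)  # tensor([6])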
+ """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + # adapt the text to Chinese BERT vocab + # text = text.lower().replace('“', "\"").replace('”', "\"") + + # add special tokens + all_tokens.append( + [self.tokenizer.vocab['<|startoftext|>'] + ] + # <|startoftext|> represents the [CLS] token + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:self.context_length - 2] + + [self.tokenizer.vocab['<|endoftext|>']]) + + result = torch.zeros( + len(all_tokens), self.context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= self.context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +@MODELS.register_module() +class CLIPZeroShot(CLIP): + + def __init__( + self, + vision_backbone: dict, + projection: dict, + text_backbone: dict, + tokenizer: dict, + vocab_size: int, + transformer_width: int, + proj_dim: int, + context_length: int = 77, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + text_prototype: Union[str, List[str]] = 'imagenet', + text_prompt: str = 'vanilla', + ): + super(CLIPZeroShot, + self).__init__(vision_backbone, projection, text_backbone, + tokenizer, vocab_size, transformer_width, + proj_dim, context_length, data_preprocessor, + init_cfg) + + # for zero-shot classification + if isinstance(text_prototype, + str) and text_prototype in PROTOTYPE_MAP.keys(): + self.prototype = PROTOTYPE_MAP[text_prototype] + else: + self.prototype = text_prototype + self.text_prototype_embeds = None + + self.prompt = PROMPT_MAP[text_prompt] + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + """Predict the classes of the input images. + + The prediction is for zero-shot classification and the text prototypes + will be prepared in this function. + + Args: + images (torch.Tensor): The input images. + data_samples (DataSample): The data samples with information from + dataset. + + Returns: + DataSample: The results of prediction.
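A self-contained sketch of what ``prepare_text_prototype`` and this method compute, with random features standing in for the real encoders and two toy templates standing in for the OPENAI_* prompt lists: each class gets one prompt-ensembled, L2-normalised text embedding, and image features are compared to them by cosine similarity, scaled by ``logit_scale.exp()`` and softmaxed into class scores:

import torch
import torch.nn.functional as F

templates = [lambda c: f'a photo of a {c}.', lambda c: f'a drawing of a {c}.']
classnames = ['apple', 'bicycle', 'castle']

prototypes = []
for name in classnames:
    texts = [t(name) for t in templates]  # one prompt per template
    # stand-in for extract_text_feat(tokenize(texts)) in the real model
    feats = F.normalize(torch.randn(len(texts), 512), dim=-1)
    proto = F.normalize(feats.mean(dim=0), dim=0)  # ensemble, renormalise
    prototypes.append(proto)
text_prototype_embeds = torch.stack(prototypes, dim=1)  # (512, num_classes)

image_feats = F.normalize(torch.randn(8, 512), dim=-1)  # (batch, 512)
logit_scale = torch.tensor(1 / 0.07)  # value of logit_scale.exp() at init
logits = image_feats @ text_prototype_embeds * logit_scale
pred_scores = F.softmax(logits, dim=1)
pred_labels = pred_scores.argmax(dim=1, keepdim=True)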
+ """ + + if self.text_prototype_embeds is None: + self.prepare_text_prototype(device=images.device) + + image_features = self.extract_image_feat(images=images) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = image_features @ self.text_prototype_embeds.to( + image_features.device) * self.logit_scale.exp() + + pred_scores = F.softmax(logits_per_image, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples + + def prepare_text_prototype(self, device) -> None: + """The function to prepare text prototypes with prompt.""" + class_embeddings = [] + for classname in track_on_main_process(self.prototype, + 'Prepare text prototype...'): + # format with class + texts = [prompt(classname) for prompt in self.prompt] + tokenized_texts = self.tokenize(texts) + class_features = self.extract_text_feat(tokenized_texts.to(device)) + class_features /= class_features.norm(dim=-1, keepdim=True) + class_feature = class_features.mean(dim=0) + class_feature /= class_feature.norm() + class_embeddings.append(class_feature) + self.text_prototype_embeds = torch.stack( + class_embeddings, dim=1).to(device) diff --git a/mmpretrain/models/multimodal/clip/clip_transformer.py b/mmpretrain/models/multimodal/clip/clip_transformer.py new file mode 100644 index 00000000000..4b5f76661cb --- /dev/null +++ b/mmpretrain/models/multimodal/clip/clip_transformer.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/zejiangh/MILAN +from typing import Optional, Tuple + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.models.utils.clip_generator_helper import \ + ResidualAttentionBlock +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CLIPTransformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +@MODELS.register_module() +class CLIPProjection(BaseModule): + """Neck with CLIP Projection. + + Args: + in_channels (int): Number of channels in the input. + out_channels (int): Number of channels in the output. + init_cfg (dict | list[dict], optional): Initialization config dict. 
+ Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + init_cfg: Optional[dict] = None): + super(CLIPProjection, self).__init__(init_cfg=init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + scale = in_channels**-0.5 + self.proj = nn.Parameter(scale * + torch.randn(in_channels, out_channels)) + + def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]: + """forward function. + + Args: + inputs (Tuple): The features extracted from + the backbone. Multiple stage inputs are acceptable but only + the last stage will be used. + Returns: + Tuple(torch.Tensor)): A tuple of reducted features. + """ + if isinstance(inputs, tuple): + inputs = inputs[-1] + out = inputs @ self.proj + elif isinstance(inputs, torch.Tensor): + out = inputs @ self.proj + else: + raise TypeError( + '`CLIPProjection` neck inputs should be tuple or torch.tensor') + return (out, ) diff --git a/mmpretrain/models/multimodal/clip/utils.py b/mmpretrain/models/multimodal/clip/utils.py new file mode 100644 index 00000000000..65239bc37d6 --- /dev/null +++ b/mmpretrain/models/multimodal/clip/utils.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +OPENAI_CIFAR100_PROMPT = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'a low contrast photo of a {c}.', + lambda c: f'a high contrast photo of a {c}.', + lambda c: f'a bad photo of a {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a photo of a big {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a low contrast photo of the {c}.', + lambda c: f'a high contrast photo of the {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the big {c}.', +] + +OPENAI_IMAGENET_PROMPT_SUB = [ + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +] + +OPENAI_IMAGENET_PROMPT = [ + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a 
jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] diff --git a/mmpretrain/models/multimodal/ofa/ofa_modules.py b/mmpretrain/models/multimodal/ofa/ofa_modules.py index 1c79049b617..ef5c8533755 100644 --- a/mmpretrain/models/multimodal/ofa/ofa_modules.py +++ b/mmpretrain/models/multimodal/ofa/ofa_modules.py @@ -1301,6 +1301,7 @@ class OFAEncoderDecoder(BaseModule, GenerationMixin): Defaults to an empty dict. init_cfg (dict, optional): The initialization config. Defaults to None. """ + base_model_prefix = '' def __init__( self, diff --git a/mmpretrain/models/multimodal/ram/__init__.py b/mmpretrain/models/multimodal/ram/__init__.py new file mode 100644 index 00000000000..35619d88516 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ram import RAM, RAMNormal, RAMOpenset + +__all__ = ['RAM', 'RAMNormal', 'RAMOpenset'] diff --git a/mmpretrain/models/multimodal/ram/bert.py b/mmpretrain/models/multimodal/ram/bert.py new file mode 100644 index 00000000000..f54b2ce8e47 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/bert.py @@ -0,0 +1,1197 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Modify from: +# https://github.com/xinyu1205/recognize-anything/blob/main/ram/models/bert.py + +import math +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class BertEmbeddings_nopos(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + # self.position_embeddings = nn.Embedding( + # config.max_position_embeddings, config.hidden_size) + '''self.LayerNorm is not snake-cased to stick with + TensorFlow model variable name and be able to load''' + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous + # in memory and exported when serialized + # self.register_buffer("position_ids", + # torch.arange(config.max_position_embeddings).expand((1, -1))) + # self.position_embedding_type = \ + # getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] # noqa: F841 + + # if position_ids is None: + # position_ids = self.position_ids[:, \ + # past_key_values_length : seq_length + \ + # past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + # if self.position_embedding_type == "absolute": + # position_embeddings = self.position_embeddings(position_ids) + # # print('add position_embeddings!!!!') + # embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with + # TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous + # in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + 
self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + # print('add position_embeddings!!!!') + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and \ + not hasattr(config, 'embedding_size'): + raise ValueError('''The hidden size (%d) is not a multiple of + the number of attention heads (%d)''' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * \ + self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
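+        # Branch guide: for cross-attention the keys/values come from
+        # ``encoder_hidden_states`` and the encoder's padding mask replaces
+        # ``attention_mask``; for cached decoding the new key/value is
+        # concatenated with ``past_key_value`` along the sequence dim (dim=2);
+        # otherwise this is ordinary self-attention over ``hidden_states``.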
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + # print(self.key.weight.shape) + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # compatible with higher versions of transformers + if key_layer.shape[0] > query_layer.shape[0]: + key_layer = key_layer[:query_layer.shape[0], :, :, :] + attention_mask = attention_mask[:query_layer.shape[0], :, :] + value_layer = value_layer[:query_layer.shape[0], :, :, :] + + # Take the dot product between "query" and "key" + # to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + \ + relative_position_scores_query + \ + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for + # all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * \ + self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = 
self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + + if mode == 'tagging': + + assert encoder_hidden_states is not None, \ + '''encoder_hidden_states must be given + for cross-attention layers''' + + cross_attention_outputs = self.crossattention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + present_key_value = cross_attention_outputs[-1] + + else: + # decoder uni-directional self-attention + # cached key/values tuple is at positions 1,2 + self_attn_past_key_value = \ + (past_key_value[:2] + if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode == 'multimodal': + assert encoder_hidden_states is not None, \ + '''encoder_hidden_states must be + given for cross-attention layers''' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1: + -1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) 
if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn('''`use_cache=True` is incompatible with + gradient checkpointing. Setting `use_cache=False`...''' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
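+        # With ``config.tie_word_embeddings`` enabled (the default),
+        # ``PreTrainedModel`` ties ``decoder.weight`` to the input word
+        # embeddings, so the logits are effectively
+        # ``hidden_states @ word_embeddings.weight.T + self.bias``.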
+ self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that + # the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple + interface for downloading and loading pretrained models.""" + + config_class = BertConfig + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version + # which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need `__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: + dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, + zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, + with a the same dtype as :obj:`attention_mask.dtype`. 
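For the common 2-D case, the conversion implemented below is easy to reproduce in isolation (a toy sketch, not part of the PR): real tokens get an additive bias of 0 and padding gets -10000, which effectively removes padded positions from the softmax:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # 1 = token, 0 = padding
extended = attention_mask[:, None, None, :].to(torch.float32)
extended = (1.0 - extended) * -10000.0  # shape (1, 1, 1, 5)
# real positions -> 0.0, padded positions -> -10000.0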
+ """ + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it + # broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask + # in addition to the padding mask + # - if the model is an encoder, make the mask + # broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat( + batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type + # with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + '''Wrong shape for input_ids (shape {}) or attention_mask + (shape {})'''.format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 + # for positions we want to attend and 0.0 + # for masked positions, this operation will + # create a tensor which is 0.0 for positions + # we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores + # before the softmax, this is effectively + # the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token + indices of the encoder input. This mask is used in + the cross-attention if the model is configured as + a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length : + obj:`config.n_layers` with each tuple having 4 tensors of shape : + obj:`(batch_size, num_heads, sequence_length - 1, + embed_size_per_head)`): + Contains precomputed key and value hidden states of the + attention blocks. 
Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally + input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to + this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj: + `(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value + states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + if is_decoder: + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache) + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError('''You cannot specify both + input_ids and inputs_embeds at the same time''') + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError('''You have to specify either + input_ids or inputs_embeds or encoder_embeds''') + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to + # make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = \ + (self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder)) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = \ + (encoder_hidden_states[0].size()) + else: + encoder_batch_size, encoder_sequence_length, _ = \ + (encoder_hidden_states.size()) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape + # [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape + # [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + 
output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token + indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right + language modeling loss (next word prediction). + Indices should be in + ``[-100, 0, ..., config.vocab_size]`` + (see ``input_ids`` docstring) Tokens with indices set to + ``-100`` are ignored (masked), the loss is only computed + for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj: + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention + blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally + input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to + this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj: + `(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states + are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import (BertTokenizer, + BertLMHeadModel, BertConfig) + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained( + 'bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", + return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + # sequence_output.shape torch.Size([85, 30, 768]) + # prediction_scores.shape torch.Size([85, 30, 30524]) + # labels.shape torch.Size([85, 30]) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift + # prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, + # the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/ram/config/__init__.py b/mmpretrain/models/multimodal/ram/config/__init__.py new file mode 100644 index 00000000000..ef101fec61e --- /dev/null +++ b/mmpretrain/models/multimodal/ram/config/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
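For reference, the next-token shift used in ``BertLMHeadModel.forward`` above can be reproduced in isolation (toy tensors; the label-smoothing value mirrors the code):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
prediction_scores = torch.randn(2, 6, vocab_size)  # (batch, seq, vocab)
labels = torch.randint(0, vocab_size, (2, 6))      # (batch, seq)

shifted_scores = prediction_scores[:, :-1, :].contiguous()  # predict t+1 from t
shifted_labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(label_smoothing=0.1)
lm_loss = loss_fct(
    shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))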
diff --git a/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py b/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py new file mode 100644 index 00000000000..e4b88653b3b --- /dev/null +++ b/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# data settings +test_transforms_cfg = [ + dict(type='Resize', scale=(384, 384), interpolation='bicubic'), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + + +def get_ram_cfg(mode='normal'): + assert mode in ['normal', 'openset'], 'mode must "normal" or "openset"' + model_type = 'RAMNormal' if mode == 'normal' else 'RAMOpenset' + model_cfg = dict( + type=model_type, + tokenizer=dict( + type='BertTokenizer', + name_or_path='/public/DATA/qbw/ckpt/bert-base-uncased', + use_fast=False), + vision_backbone=dict( + type='SwinTransformer', + arch='large', + img_size=384, + window_size=12, + ), + tag_encoder={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30524, + 'encoder_width': 512, + 'add_cross_attention': True + }, + text_decoder={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30524, + 'encoder_width': 768, + 'add_cross_attention': True + }, + tagging_head={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 4, + 'num_hidden_layers': 2, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30522, + 'encoder_width': 512, + 'add_cross_attention': True, + 'add_tag_cross_attention': False + }, + data_preprocessor=dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, + ), + ) + return model_cfg diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle new file mode 100644 index 00000000000..0519d1ee759 Binary files /dev/null and b/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle differ diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle new file mode 100644 index 00000000000..4abe105e3b3 Binary files /dev/null and b/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle differ diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle new file mode 100644 index 00000000000..2be681d6f0a Binary files /dev/null and 
b/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle differ diff --git a/mmpretrain/models/multimodal/ram/gradio_demo.py b/mmpretrain/models/multimodal/ram/gradio_demo.py new file mode 100644 index 00000000000..206e6b40fd8 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/gradio_demo.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import gradio as gr +import torch + +from mmpretrain.registry import MODELS, TRANSFORMS +from .config.ram_swin_large_14m import get_ram_cfg, test_transforms_cfg +from .run.inference import inference + +parser = argparse.ArgumentParser( + description='RAM(Recognize Anything Model) demo') +parser.add_argument( + 'ram_ckpt', type=str, help='pretrained file for ram (absolute path)') +parser.add_argument( + 'clip_ckpt', + type=str, + help='clip vit-base-p16 pretrained file (absolute path)') +args = parser.parse_args() + +if torch.cuda.is_available(): + devices = [ + torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count()) + ] +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + devices = [torch.device('mps')] +else: + devices = [torch.device('cpu')] + + +def get_free_device(): + if hasattr(torch.cuda, 'mem_get_info'): + free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices] + select = max(zip(free, range(len(free))))[1] + else: + import random + select = random.randint(0, len(devices) - 1) + return devices[select] + + +device = get_free_device() + + +def ram_inference(image, tag_list, mode, threshold): + test_transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg) + model = MODELS.build(get_ram_cfg(mode=mode)) + model.load_state_dict(torch.load(args.ram_ckpt)) + model.device = device + + if mode == 'openset': + categories = tag_list + if categories != '': + categories = categories.strip().split() + else: + categories = None + model.set_openset( + categories=categories, + clip_ckpt=args.clip_ckpt, + threshold=threshold) + + sample = dict(img=image) + result = inference(sample, model, test_transforms, mode=mode) + tag, tag_chinese, logits = \ + result.get('tag_output')[0][0], result.get('tag_output')[1][0],\ + result.get('logits_output')[0] + + def wrap(tags, logits): + if tags is None: + return 'Openset mode has no tag_en' + tag_lst = tags.split('|') + rt_lst = [] + for i, tag in enumerate(tag_lst): + tag = tag.strip() + rt_lst.append(tag + f': {logits[i]:.2f}') + return ' | '.join(rt_lst) + + return [wrap(tag, logits), wrap(tag_chinese, logits)] + + +def build_gradio(): + inputs = [ + gr.components.Image(label='image'), + gr.components.Textbox( + lines=2, + label='tag_list', + placeholder= + 'please input the categories split by keyboard "blank": ', + value=''), + gr.components.Radio(['normal', 'openset'], + label='mode', + value='normal'), + gr.components.Slider( + minimum=0, maximum=1, value=0.68, step=0.01, label='threshold') + ] + return gr.Interface( + fn=ram_inference, + inputs=inputs, + outputs=[ + gr.components.Textbox(), + gr.components.Textbox(info="it's translated from the english tags") + ]) + + +def main(): + build_gradio().launch() + + +if __name__ == '__main__': + main() diff --git a/mmpretrain/models/multimodal/ram/openset_utils.py b/mmpretrain/models/multimodal/ram/openset_utils.py new file mode 100644 index 00000000000..5fa0f52e26e --- /dev/null +++ b/mmpretrain/models/multimodal/ram/openset_utils.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
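For reference, a rough sketch of what `ram_inference` in the Gradio demo above boils down to when called programmatically. The checkpoint and image paths are placeholders, and reading the image with `mmcv.imread` is an assumption (the demo itself receives an already decoded array from Gradio):

```python
import mmcv
import torch

from mmpretrain.models.multimodal.ram.config.ram_swin_large_14m import (
    get_ram_cfg, test_transforms_cfg)
from mmpretrain.models.multimodal.ram.run.inference import inference
from mmpretrain.registry import MODELS, TRANSFORMS

RAM_CKPT = 'path/to/ram_converted.pth'  # placeholder, see ram2mmpretrain.py

# Mirror the demo: build the test transforms and the RAM model, then run inference.
transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg)
model = MODELS.build(get_ram_cfg(mode='normal'))
model.load_state_dict(torch.load(RAM_CKPT, map_location='cpu'))

sample = dict(img=mmcv.imread('demo/demo.JPEG'))  # placeholder image path
result = inference(sample, model, transforms, mode='normal')
tags_en, tags_zh = result.get('tag_output')
print(tags_en[0])  # e.g. 'dog | grass | ...'
```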
+import torch + +from mmpretrain.registry import MODELS + + +def article(name): + return 'an' if name[0] in 'aeiou' else 'a' + + +def processed_name(name, rm_dot=False): + # _ for lvis + # / for obj365 + res = name.replace('_', ' ').replace('/', ' or ').lower() + if rm_dot: + res = res.rstrip('.') + return res + + +single_template = ['a photo of a {}.'] + +multiple_templates = [ + 'There is {article} {} in the scene.', + 'There is the {} in the scene.', + 'a photo of {article} {} in the scene.', + 'a photo of the {} in the scene.', + 'a photo of one {} in the scene.', + 'itap of {article} {}.', + 'itap of my {}.', # itap: I took a picture of + 'itap of the {}.', + 'a photo of {article} {}.', + 'a photo of my {}.', + 'a photo of the {}.', + 'a photo of one {}.', + 'a photo of many {}.', + 'a good photo of {article} {}.', + 'a good photo of the {}.', + 'a bad photo of {article} {}.', + 'a bad photo of the {}.', + 'a photo of a nice {}.', + 'a photo of the nice {}.', + 'a photo of a cool {}.', + 'a photo of the cool {}.', + 'a photo of a weird {}.', + 'a photo of the weird {}.', + 'a photo of a small {}.', + 'a photo of the small {}.', + 'a photo of a large {}.', + 'a photo of the large {}.', + 'a photo of a clean {}.', + 'a photo of the clean {}.', + 'a photo of a dirty {}.', + 'a photo of the dirty {}.', + 'a bright photo of {article} {}.', + 'a bright photo of the {}.', + 'a dark photo of {article} {}.', + 'a dark photo of the {}.', + 'a photo of a hard to see {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of {article} {}.', + 'a low resolution photo of the {}.', + 'a cropped photo of {article} {}.', + 'a cropped photo of the {}.', + 'a close-up photo of {article} {}.', + 'a close-up photo of the {}.', + 'a jpeg corrupted photo of {article} {}.', + 'a jpeg corrupted photo of the {}.', + 'a blurry photo of {article} {}.', + 'a blurry photo of the {}.', + 'a pixelated photo of {article} {}.', + 'a pixelated photo of the {}.', + 'a black and white photo of the {}.', + 'a black and white photo of {article} {}.', + 'a plastic {}.', + 'the plastic {}.', + 'a toy {}.', + 'the toy {}.', + 'a plushie {}.', + 'the plushie {}.', + 'a cartoon {}.', + 'the cartoon {}.', + 'an embroidered {}.', + 'the embroidered {}.', + 'a painting of the {}.', + 'a painting of a {}.', +] + +openimages_rare_unseen = [ + 'Aerial photography', 'Aircraft engine', 'Ale', 'Aloe', 'Amphibian', + 'Angling', 'Anole', 'Antique car', 'Arcade game', 'Arthropod', + 'Assault rifle', 'Athletic shoe', 'Auto racing', 'Backlighting', + 'Bagpipes', 'Ball game', 'Barbecue chicken', 'Barechested', 'Barquentine', + 'Beef tenderloin', 'Billiard room', 'Billiards', 'Bird of prey', + 'Black swan', 'Black-and-white', 'Blond', 'Boating', 'Bonbon', + 'Bottled water', 'Bouldering', 'Bovine', 'Bratwurst', 'Breadboard', + 'Briefs', 'Brisket', 'Brochette', 'Calabaza', 'Camera operator', 'Canola', + 'Childbirth', 'Chordophone', 'Church bell', 'Classical sculpture', + 'Close-up', 'Cobblestone', 'Coca-cola', 'Combat sport', 'Comics', + 'Compact car', 'Computer speaker', 'Cookies and crackers', + 'Coral reef fish', 'Corn on the cob', 'Cosmetics', 'Crocodilia', + 'Digital camera', 'Dishware', 'Divemaster', 'Dobermann', 'Dog walking', + 'Domestic rabbit', 'Domestic short-haired cat', 'Double-decker bus', + 'Drums', 'Electric guitar', 'Electric piano', 'Electronic instrument', + 'Equestrianism', 'Equitation', 'Erinaceidae', 'Extreme sport', 'Falafel', + 'Figure skating', 'Filling station', 'Fire apparatus', 'Firearm', + 'Flatbread', 
'Floristry', 'Forklift truck', 'Freight transport', + 'Fried food', 'Fried noodles', 'Frigate', 'Frozen yogurt', 'Frying', + 'Full moon', 'Galleon', 'Glacial landform', 'Gliding', 'Go-kart', 'Goats', + 'Grappling', 'Great white shark', 'Gumbo', 'Gun turret', 'Hair coloring', + 'Halter', 'Headphones', 'Heavy cruiser', 'Herding', 'High-speed rail', + 'Holding hands', 'Horse and buggy', 'Horse racing', 'Hound', + 'Hunting knife', 'Hurdling', 'Inflatable', 'Jackfruit', 'Jeans', 'Jiaozi', + 'Junk food', 'Khinkali', 'Kitesurfing', 'Lawn game', 'Leaf vegetable', + 'Lechon', 'Lifebuoy', 'Locust', 'Lumpia', 'Luxury vehicle', 'Machine tool', + 'Medical imaging', 'Melee weapon', 'Microcontroller', 'Middle ages', + 'Military person', 'Military vehicle', 'Milky way', 'Miniature Poodle', + 'Modern dance', 'Molluscs', 'Monoplane', 'Motorcycling', 'Musical theatre', + 'Narcissus', 'Nest box', 'Newsagent\'s shop', 'Nile crocodile', + 'Nordic skiing', 'Nuclear power plant', 'Orator', 'Outdoor shoe', + 'Parachuting', 'Pasta salad', 'Peafowl', 'Pelmeni', 'Perching bird', + 'Performance car', 'Personal water craft', 'Pit bull', 'Plant stem', + 'Pork chop', 'Portrait photography', 'Primate', 'Procyonidae', + 'Prosciutto', 'Public speaking', 'Racewalking', 'Ramen', + 'Rear-view mirror', 'Residential area', 'Ribs', 'Rice ball', + 'Road cycling', 'Roller skating', 'Roman temple', 'Rowing', 'Rural area', + 'Sailboat racing', 'Scaled reptile', 'Scuba diving', 'Senior citizen', + 'Shallot', 'Shinto shrine', 'Shooting range', 'Siberian husky', 'Sledding', + 'Soba', 'Solar energy', 'Sport climbing', 'Sport utility vehicle', + 'Steamed rice', 'Stemware', 'Sumo', 'Surfing Equipment', 'Team sport', + 'Touring car', 'Toy block', 'Trampolining', 'Underwater diving', + 'Vegetarian food', 'Wallaby', 'Water polo', 'Watercolor paint', 'Whiskers', + 'Wind wave', 'Woodwind instrument', 'Yakitori', 'Zeppelin' +] + + +def get_clip_model(): + model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')), + pre_norm=True, + ), + projection=dict( + type='CLIPProjection', in_channels=768, out_channels=512), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + context_length=77, + data_preprocessor=dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, + ), + ) + return MODELS.build(model) + + +def build_openset_label_embedding(categories=None, clip_ckpt_path=''): + if categories is None: + print('Categories is None, so using rare_unseen categories') + categories = openimages_rare_unseen + model = get_clip_model() + model.load_state_dict(torch.load(clip_ckpt_path)) + templates = multiple_templates + + run_on_gpu = torch.cuda.is_available() + + with torch.no_grad(): + openset_label_embedding = [] + for category in categories: + texts = [ + template.format( + processed_name(category, rm_dot=True), + article=article(category)) for template in templates + ] + texts = [ + 'This is ' + text + if text.startswith('a') or text.startswith('the') else text + for text in texts + ] + texts = model.tokenize(texts) # tokenize + if run_on_gpu: + texts = 
texts.cuda() + model = model.cuda() + text_embeddings = model.extract_text_feat(texts) + text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) + text_embedding = text_embeddings.mean(dim=0) + text_embedding /= text_embedding.norm() + openset_label_embedding.append(text_embedding) + openset_label_embedding = torch.stack(openset_label_embedding, dim=1) + if run_on_gpu: + openset_label_embedding = openset_label_embedding.cuda() + + openset_label_embedding = openset_label_embedding.t() + return openset_label_embedding, categories diff --git a/mmpretrain/models/multimodal/ram/ram.py b/mmpretrain/models/multimodal/ram/ram.py new file mode 100644 index 00000000000..c5d22f07817 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/ram.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import pickle +from abc import abstractmethod +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .bert import BertConfig, BertLMHeadModel, BertModel +from .openset_utils import build_openset_label_embedding +from .utils import tie_encoder_decoder_weights + + +def get_path(path): + file_path = os.path.abspath(os.path.dirname(__file__)) + if not os.path.isabs(path): + return os.path.join(file_path, path) + + +class RAM(BaseModel): + """The implementation of `RAM `_.""" + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.device = device + # build the visual encoder + self.visual_encoder = MODELS.build(vision_backbone) + + # build the tokenizer + self.tokenizer = TOKENIZER.build(tokenizer) + self.tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + self.tokenizer.add_special_tokens( + {'additional_special_tokens': ['[ENC]']}) + self.tokenizer.enc_token_id = \ + self.tokenizer.additional_special_tokens_ids[0] + + # build the tag encoder + # encoder_config = BertConfig.from_json_file(med_config) + # encoder_config.encoder_width = 512 + encoder_config = BertConfig.from_dict(tag_encoder) + self.tag_encoder = BertModel( + config=encoder_config, add_pooling_layer=False) + + # build image-tag-text decoder + # decoder_config = BertConfig.from_json_file(med_config) + decoder_config = BertConfig.from_dict(text_decoder) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + self.delete_tag_index = delete_tag_index + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + # load tag list + self.tag_list = self.load_tag_list(get_path(tag_list)) + self.tag_list_chinese = self.load_tag_list(get_path(tag_list_chinese)) + + # create image-tag recognition decoder + self.threshold = threshold + self.num_class = len(self.tag_list) + # q2l_config = \ + # BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') + # 
q2l_config.encoder_width = 512 + q2l_config = BertConfig.from_dict(tagging_head) + self.tagging_head = BertModel( + config=q2l_config, add_pooling_layer=False) + self.tagging_head.resize_token_embeddings(len(self.tokenizer)) + self.label_embed = nn.Parameter( + torch.zeros(self.num_class, q2l_config.encoder_width)) + + if q2l_config.hidden_size != 512: + self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) + else: + self.wordvec_proj = nn.Identity() + + self.fc = nn.Linear(q2l_config.hidden_size, 1) + + self.del_selfattention() + + # share weights of the lowest 2-layer of + # "image-tag interaction encoder" with + # the "image-tag recogntion decoder" + tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', + ' ') + self.image_proj = nn.Linear(vision_width, 512) + # self.label_embed = nn.Parameter(torch.load( + # f'{CONFIG_PATH}/data/textual_label_embedding.pth', + # map_location='cpu').float()) + + # adjust thresholds for some tags + self.class_threshold = torch.ones(self.num_class) * self.threshold + ram_class_threshold_path = get_path( + './data/ram_tag_list_threshold.pickle') + with open(ram_class_threshold_path, 'rb') as f: + ram_class_threshold = pickle.load(f) + for key, value in enumerate(ram_class_threshold): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, 'rb') as f: + tag_list = pickle.load(f) + tag_list = np.array(tag_list) + return tag_list + + # delete self-attention layer of image-tag recognition decoder + # to reduce computation, follower Query2Label + def del_selfattention(self): + del self.tagging_head.embeddings + for layer in self.tagging_head.encoder.layer: + del layer.attention + + def get_label_embed(self): + return torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) + + def extract_visual_feature(self, images): + image_embeds = self.visual_encoder(images)[0] + image_embeds = image_embeds.flatten(2, 3) + attn_pool = nn.AdaptiveAvgPool1d(1) + cls_token = attn_pool(image_embeds).permute(0, 2, 1).contiguous() + image_embeds = image_embeds.permute(0, 2, 1).contiguous() + image_embeds = torch.cat([cls_token, image_embeds], dim=1) + image_embeds = self.image_proj(image_embeds) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(images.device) + return image_embeds, image_atts + + def image2tag(self, label_embed, image_embeds, image_atts): + # recognized image tags using image-tag recogntiion decoder + # image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', + ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + return logits + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + @abstractmethod + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + raise NotImplementedError + + +@MODELS.register_module() +class RAMNormal(RAM): + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, 
+ prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__( + tokenizer, + vision_backbone, + tag_encoder, + tagging_head, + text_decoder, + device, + vision_width, + prompt, + threshold, + delete_tag_index, + tag_list, + tag_list_chinese, + data_preprocessor, + init_cfg, + ) + + def tag_process(self, logits): + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(logits.device), + torch.tensor(1.0).to(logits.device), + torch.zeros(self.num_class).to(logits.device)) + + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + tag_output = [] + tag_output_chinese = [] + logits_output = [] + + bs = logits.shape[0] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + logits_output.append( + torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy()) + tag_output.append(' | '.join(token)) + token_chinese = self.tag_list_chinese[index].squeeze(axis=1) + tag_output_chinese.append(' | '.join(token_chinese)) + + return [(tag_output, tag_output_chinese), logits_output] + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + self.eval() + self.to(self.device) + images = images.to(self.device) + label_embed = self.get_label_embed() + image_embeds, image_atts = self.extract_visual_feature(images) + logits = self.image2tag(label_embed, image_embeds, image_atts) + tag_output, logits_output = self.tag_process(logits) + data_samples.set_field(logits_output, 'logits_output') + data_samples.set_field(tag_output, 'tag_output') + return data_samples + + +@MODELS.register_module() +class RAMOpenset(RAMNormal): + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__( + tokenizer, + vision_backbone, + tag_encoder, + tagging_head, + text_decoder, + device, + vision_width, + prompt, + threshold, + delete_tag_index, + tag_list, + tag_list_chinese, + data_preprocessor, + init_cfg, + ) + + def set_openset(self, + categories: List[str] = None, + clip_ckpt: str = '', + threshold: float = 0.68): + openset_label_embedding, openset_categories = \ + build_openset_label_embedding( + categories, clip_ckpt + ) + self.tag_list = np.array(openset_categories) + self.label_embed = nn.Parameter(openset_label_embedding.float()) + self.num_class = len(openset_categories) + + # the threshold for unseen categories is often lower + self.class_threshold = torch.ones(self.num_class) * threshold + + def tag_process(self, logits): + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(logits.device), + torch.tensor(1.0).to(logits.device), + torch.zeros(self.num_class).to(logits.device)) + + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + + bs = logits.shape[0] + tag_output = [] + logits_output = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + logits_output.append( + torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy()) + tag_output.append(' | 
'.join(token)) + + return [(tag_output, [None]), logits_output] diff --git a/mmpretrain/models/multimodal/ram/run/__init__.py b/mmpretrain/models/multimodal/ram/run/__init__.py new file mode 100644 index 00000000000..ef101fec61e --- /dev/null +++ b/mmpretrain/models/multimodal/ram/run/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/models/multimodal/ram/run/inference.py b/mmpretrain/models/multimodal/ram/run/inference.py new file mode 100644 index 00000000000..da5afcf5e9d --- /dev/null +++ b/mmpretrain/models/multimodal/ram/run/inference.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def inference_ram(sample, model): + + with torch.no_grad(): + result = model.test_step(sample) + + return result + + +def inference_ram_openset(sample, model): + with torch.no_grad(): + result = model.test_step(sample) + + return result + + +def inference(sample, model, transforms, mode='normal'): + sample = transforms(sample) + if sample['inputs'].ndim == 3: + sample['inputs'] = sample['inputs'].unsqueeze(dim=0) + assert mode in ['normal', 'openset' + ], 'mode of inference must be "normal" or "openset"' + if mode == 'normal': + return inference_ram(sample, model) + else: + return inference_ram_openset(sample, model) diff --git a/mmpretrain/models/multimodal/ram/utils.py b/mmpretrain/models/multimodal/ram/utils.py new file mode 100644 index 00000000000..32cb115be64 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/utils.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from torch import nn + + +def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, + base_model_prefix: str, skip_key: str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + print(f'''{decoder.__class__} and {encoder.__class__} are not equal. 
+ In this case make sure that + all encoder weights are correctly initialized.''') + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f'{decoder_pointer} and {encoder_pointer}' + \ + 'have to be of type torch.nn.Module' + if hasattr(decoder_pointer, 'weight') and skip_key not in module_name: + assert hasattr(encoder_pointer, 'weight') + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, 'bias'): + assert hasattr(encoder_pointer, 'bias') + encoder_pointer.bias = decoder_pointer.bias + print(module_name + ' is tied') + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert (len(encoder_modules) > + 0), f'''Encoder module {encoder_pointer} + does not match decoder module {decoder_pointer}''' + + all_encoder_weights = set([ + module_name + '/' + sub_name + for sub_name in encoder_modules.keys() + ]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance( + decoder_modules[decoder_name], + type(encoder_modules[encoder_name])) and len( + encoder_modules) != len(decoder_modules): + # this can happen if the name corresponds to + # the position in a list module list of layers + # in this case the decoder has added a + # cross-attention that the encoder doesn't have + # thus skip this step and + # subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + '''Max depth of recursive function `tie_encoder_to_decoder` reached. + It seems that there is a circular dependency + between two or more `nn.Modules` of your model.''') + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + '/' + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + '/' + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, + uninitialized_encoder_weights, skip_key) diff --git a/mmpretrain/models/utils/tokenizer.py b/mmpretrain/models/utils/tokenizer.py index 5b8a324bad0..fddda432ff7 100644 --- a/mmpretrain/models/utils/tokenizer.py +++ b/mmpretrain/models/utils/tokenizer.py @@ -12,6 +12,7 @@ register_hf_tokenizer(AutoTokenizer) register_hf_tokenizer(LlamaTokenizer) +register_hf_tokenizer(BertTokenizer) @register_hf_tokenizer() diff --git a/mmpretrain/version.py b/mmpretrain/version.py index 24b33124c0d..32f800cd969 100644 --- a/mmpretrain/version.py +++ b/mmpretrain/version.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved -__version__ = '1.0.2' +__version__ = '1.1.0' def parse_version_info(version_str): diff --git a/projects/dino/README.md b/projects/dino/README.md new file mode 100644 index 00000000000..3458fa4cdb3 --- /dev/null +++ b/projects/dino/README.md @@ -0,0 +1,26 @@ +# Implementation for DINO + +**NOTE**: We only guarantee correctness of the forward pass, not responsible for full reimplementation. 
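As an aside on the `tie_encoder_decoder_weights` utility added in `ram/utils.py` above, here is a toy illustration of its effect (the module names are invented for the example): every submodule that has a `weight` ends up sharing the decoder's parameters, unless its qualified name contains the skip key.

```python
from torch import nn

from mmpretrain.models.multimodal.ram.utils import tie_encoder_decoder_weights

# Invented toy modules; RAM uses this to tie its tag encoder to the tagging head.
encoder = nn.ModuleDict(
    {'proj': nn.Linear(4, 4), 'crossattention': nn.Linear(4, 4)})
decoder = nn.ModuleDict(
    {'proj': nn.Linear(4, 4), 'crossattention': nn.Linear(4, 4)})

tie_encoder_decoder_weights(encoder, decoder, '', skip_key='crossattention')

print(encoder['proj'].weight is decoder['proj'].weight)  # True, now shared
print(encoder['crossattention'].weight
      is decoder['crossattention'].weight)  # False, skipped via skip_key
```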
+ +First, ensure you are in the root directory of MMPretrain, then you have two ways +to play with DINO in MMPretrain: + +## Slurm + +If you are using a cluster managed by Slurm, you can use the following command to +start your job: + +```shell +GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp +``` + +The above command will pre-train the model on a single node with 8 GPUs. + +## PyTorch + +If you are using a single machine without any cluster management software, you can use the following command: + +```shell +NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8 \ + --amp +``` diff --git a/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py b/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py new file mode 100644 index 00000000000..d4a1c240218 --- /dev/null +++ b/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py @@ -0,0 +1,104 @@ +model = dict( + type='DINO', + data_preprocessor=dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='mmpretrain.VisionTransformer', arch='b', patch_size=16), + neck=dict( + type='DINONeck', + in_channels=768, + out_channels=65536, + hidden_channels=2048, + bottleneck_channels=256), + head=dict( + type='DINOHead', + out_channels=65536, + num_crops=10, + student_temp=0.1, + center_momentum=0.9)) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='DINOMultiCrop', + global_crops_scale=(0.4, 1.0), + local_crops_scale=(0.05, 0.4), + local_crops_number=8), + dict(type='PackInputs') +] +train_dataloader = dict( + batch_size=32, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type='mmpretrain.ImageNet', + data_root='/data/imagenet/', + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline, + )) +optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=dict( + type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05), + paramwise_cfg=dict( + custom_keys=dict( + ln=dict(decay_mult=0.0), + bias=dict(decay_mult=0.0), + pos_embed=dict(decay_mult=0.0), + mask_token=dict(decay_mult=0.0), + cls_token=dict(decay_mult=0.0))), + loss_scale='dynamic') +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-09, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=90, + by_epoch=True, + begin=10, + end=100, + convert_to_iter_based=True) +] +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100) +default_scope = 'mmpretrain' +default_hooks = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=100), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1), + sampler_seed=dict(type='DistSamplerSeedHook')) +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) +log_processor = dict( + window_size=10, + custom_cfg=[dict(data_src='', method='mean', window_size='global')]) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( +
type='UniversalVisualizer', + vis_backends=[dict(type='LocalVisBackend')], + name='visualizer') +log_level = 'INFO' +load_from = None +resume = True +randomness = dict(seed=2, diff_rank_seed=True) +custom_hooks = [ + dict( + type='DINOTeacherTempWarmupHook', + warmup_teacher_temp=0.04, + teacher_temp=0.04, + teacher_temp_warmup_epochs=0, + max_epochs=100) +] diff --git a/projects/dino/dataset/__init__.py b/projects/dino/dataset/__init__.py new file mode 100644 index 00000000000..da65f2853ad --- /dev/null +++ b/projects/dino/dataset/__init__.py @@ -0,0 +1 @@ +from .transform import * # noqa: F401,F403 diff --git a/projects/dino/dataset/transform/__init__.py b/projects/dino/dataset/transform/__init__.py new file mode 100644 index 00000000000..00dacb3f3c9 --- /dev/null +++ b/projects/dino/dataset/transform/__init__.py @@ -0,0 +1,3 @@ +from .processing import DINOMultiCrop + +__all__ = ['DINOMultiCrop'] diff --git a/projects/dino/dataset/transform/processing.py b/projects/dino/dataset/transform/processing.py new file mode 100644 index 00000000000..df4bf0be9dd --- /dev/null +++ b/projects/dino/dataset/transform/processing.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +from mmcv.transforms import RandomApply # noqa: E501 +from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale + +from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur, + RandomResizedCrop, Solarize) +from mmpretrain.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class DINOMultiCrop(BaseTransform): + """Multi-crop transform for DINO. + + This module applies the multi-crop transform for DINO. + + Args: + global_crops_scale (int): Scale of global crops. + local_crops_scale (int): Scale of local crops. + local_crops_number (int): Number of local crops. 
+ """ + + def __init__(self, global_crops_scale: int, local_crops_scale: int, + local_crops_number: int) -> None: + super().__init__() + self.global_crops_scale = global_crops_scale + self.local_crops_scale = local_crops_scale + + flip_and_color_jitter = Compose([ + RandomFlip(prob=0.5, direction='horizontal'), + RandomApply([ + ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1) + ], + prob=0.8), + RandomGrayscale( + prob=0.2, + keep_channels=True, + channel_weights=(0.114, 0.587, 0.2989), + ) + ]) + + self.global_transform_1 = Compose([ + RandomResizedCrop( + 224, + crop_ratio_range=global_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + ]) + + self.global_transform_2 = Compose([ + RandomResizedCrop( + 224, + crop_ratio_range=global_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + Solarize(thr=128, prob=0.2), + ]) + + self.local_crops_number = local_crops_number + self.local_transform = Compose([ + RandomResizedCrop( + 96, + crop_ratio_range=local_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + ]) + + def transform(self, results: dict) -> dict: + ori_img = results['img'] + crops = [] + results['img'] = ori_img + crops.append(self.global_transform_1(results)['img']) + results['img'] = ori_img + crops.append(self.global_transform_2(results)['img']) + for _ in range(self.local_crops_number): + results['img'] = ori_img + crops.append(self.local_transform(results)['img']) + results['img'] = crops + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(global_crops_scale = {self.global_crops_scale}, ' + repr_str += f'local_crops_scale = {self.local_crops_scale}, ' + repr_str += f'local_crop_number = {self.local_crops_number})' + return repr_str diff --git a/projects/dino/engine/__init__.py b/projects/dino/engine/__init__.py new file mode 100644 index 00000000000..41422545e61 --- /dev/null +++ b/projects/dino/engine/__init__.py @@ -0,0 +1 @@ +from .hooks import * # noqa diff --git a/projects/dino/engine/hooks/__init__.py b/projects/dino/engine/hooks/__init__.py new file mode 100644 index 00000000000..df43c492e52 --- /dev/null +++ b/projects/dino/engine/hooks/__init__.py @@ -0,0 +1,3 @@ +from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook + +__all__ = ['DINOTeacherTempWarmupHook'] diff --git a/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py b/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py new file mode 100644 index 00000000000..d66b0250e72 --- /dev/null +++ b/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class DINOTeacherTempWarmupHook(Hook): + """Warmup teacher temperature for DINO. + + This hook warmups the temperature for teacher to stabilize the training + process. + + Args: + warmup_teacher_temp (float): Warmup temperature for teacher. + teacher_temp (float): Temperature for teacher. + teacher_temp_warmup_epochs (int): Warmup epochs for teacher + temperature. + max_epochs (int): Maximum epochs for training. 
+ """ + + def __init__(self, warmup_teacher_temp: float, teacher_temp: float, + teacher_temp_warmup_epochs: int, max_epochs: int) -> None: + super().__init__() + self.teacher_temps = np.concatenate( + (np.linspace(warmup_teacher_temp, teacher_temp, + teacher_temp_warmup_epochs), + np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp)) + + def before_train_epoch(self, runner) -> None: + runner.model.module.head.teacher_temp = self.teacher_temps[ + runner.epoch] diff --git a/projects/dino/models/__init__.py b/projects/dino/models/__init__.py new file mode 100644 index 00000000000..49d014874ad --- /dev/null +++ b/projects/dino/models/__init__.py @@ -0,0 +1,3 @@ +from .algorithm import * # noqa +from .head import * # noqa +from .neck import * # noqa diff --git a/projects/dino/models/algorithm/__init__.py b/projects/dino/models/algorithm/__init__.py new file mode 100644 index 00000000000..1125b63f851 --- /dev/null +++ b/projects/dino/models/algorithm/__init__.py @@ -0,0 +1,3 @@ +from .dino import DINO + +__all__ = ['DINO'] diff --git a/projects/dino/models/algorithm/dino.py b/projects/dino/models/algorithm/dino.py new file mode 100644 index 00000000000..2d78922f1f6 --- /dev/null +++ b/projects/dino/models/algorithm/dino.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from torch import nn + +from mmpretrain.models import BaseSelfSupervisor, CosineEMA +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class DINO(BaseSelfSupervisor): + """Implementation for DINO. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + backbone (dict): Config for backbone. + neck (dict): Config for neck. + head (dict): Config for head. + pretrained (str, optional): Path for pretrained model. + Defaults to None. + base_momentum (float, optional): Base momentum for momentum update. + Defaults to 0.99. + data_preprocessor (dict, optional): Config for data preprocessor. + Defaults to None. + init_cfg (list[dict] | dict, optional): Config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + pretrained: Optional[str] = None, + base_momentum: float = 0.99, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.teacher = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) + # weight normalization layer + self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer) + self.neck.last_layer.weight_g.data.fill_(1) + self.neck.last_layer.weight_g.requires_grad = False + self.teacher.module[1].last_layer = nn.utils.weight_norm( + self.teacher.module[1].last_layer) + self.teacher.module[1].last_layer.weight_g.data.fill_(1) + self.teacher.module[1].last_layer.weight_g.requires_grad = False + + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + global_crops = torch.cat(inputs[:2]) + local_crops = torch.cat(inputs[2:]) + # teacher forward + teacher_output = self.teacher(global_crops) + + # student forward global + student_output_global = self.backbone(global_crops) + student_output_global = self.neck(student_output_global) + + # student forward local + student_output_local = self.backbone(local_crops) + student_output_local = self.neck(student_output_local) + + student_output = torch.cat( + (student_output_global, student_output_local)) + + # compute loss + loss = self.head(student_output, teacher_output) + + return dict(loss=loss) diff --git a/projects/dino/models/head/__init__.py b/projects/dino/models/head/__init__.py new file mode 100644 index 00000000000..fe31e084cd3 --- /dev/null +++ b/projects/dino/models/head/__init__.py @@ -0,0 +1,3 @@ +from .dino_head import DINOHead + +__all__ = ['DINOHead'] diff --git a/projects/dino/models/head/dino_head.py b/projects/dino/models/head/dino_head.py new file mode 100644 index 00000000000..e817bfade38 --- /dev/null +++ b/projects/dino/models/head/dino_head.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from mmengine.dist import all_reduce, get_world_size +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DINOHead(BaseModule): + """Implementation for DINO head. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + out_channels (int): Output channels of the head. + num_crops (int): Number of crops. + student_temp (float): Temperature for student output. + center_momentum (float): Momentum for center update. 
+ """ + + def __init__(self, out_channels: int, num_crops: int, student_temp: float, + center_momentum: float) -> None: + super().__init__() + self.student_temp = student_temp + self.teacher_temp = 0 + self.center_momentum = center_momentum + self.num_crops = num_crops + self.register_buffer('center', torch.zeros(1, out_channels)) + + def forward(self, student_output: torch.Tensor, + teacher_output: torch.Tensor) -> torch.Tensor: + + current_teacher_output = teacher_output + student_output = student_output / self.student_temp + student_output = student_output.chunk(self.num_crops, dim=0) + + # teacher centering and sharpening + teacher_output = F.softmax( + (teacher_output - self.center) / self.teacher_temp, dim=-1) + teacher_output = teacher_output.detach().chunk(2, dim=0) + + total_loss = 0 + n_loss_terms = 0 + + for i in range(len(teacher_output)): + for j in range(len(student_output)): + if i == j: + continue + total_loss += (-teacher_output[i] * + student_output[j].log_softmax(dim=-1)).sum( + dim=-1).mean() + n_loss_terms += 1 + total_loss /= n_loss_terms + self.update_center(current_teacher_output) + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output: torch.Tensor) -> None: + + batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + all_reduce(batch_center) + batch_center = batch_center / (len(teacher_output) * get_world_size()) + + # ema update batch center + self.center = self.center * self.center_momentum + batch_center * ( + 1 - self.center_momentum) diff --git a/projects/dino/models/neck/__init__.py b/projects/dino/models/neck/__init__.py new file mode 100644 index 00000000000..e5f4aadb09d --- /dev/null +++ b/projects/dino/models/neck/__init__.py @@ -0,0 +1,3 @@ +from .dino_neck import DINONeck + +__all__ = ['DINONeck'] diff --git a/projects/dino/models/neck/dino_neck.py b/projects/dino/models/neck/dino_neck.py new file mode 100644 index 00000000000..8d8881ea24a --- /dev/null +++ b/projects/dino/models/neck/dino_neck.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DINONeck(BaseModule): + """Implementation for DINO neck. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + in_channels (int): Input channels. + hidden_channels (int): Hidden channels. + out_channels (int): Output channels. + bottleneck_channels (int): Bottleneck channels. 
+ """ + + def __init__(self, in_channels: int, hidden_channels: int, + out_channels: int, bottleneck_channels: int) -> None: + super().__init__() + self.mlp = nn.Sequential(*[ + nn.Linear(in_channels, hidden_channels), + nn.GELU(), + nn.Linear(hidden_channels, hidden_channels), + nn.GELU(), + nn.Linear(hidden_channels, bottleneck_channels), + ]) + + self.last_layer = nn.Linear( + bottleneck_channels, out_channels, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(x[0]) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x diff --git a/projects/dino/tools/dist_train.sh b/projects/dino/tools/dist_train.sh new file mode 100644 index 00000000000..3fca7641dec --- /dev/null +++ b/projects/dino/tools/dist_train.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --launcher pytorch ${@:3} diff --git a/projects/dino/tools/slurm_train.sh b/projects/dino/tools/slurm_train.sh new file mode 100644 index 00000000000..7e2ad297d84 --- /dev/null +++ b/projects/dino/tools/slurm_train.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:4} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u projects/dino/tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} diff --git a/projects/dino/tools/train.py b/projects/dino/tools/train.py new file mode 100644 index 00000000000..b9482c3b75a --- /dev/null +++ b/projects/dino/tools/train.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +from dataset import * # noqa: F401,F403 +from engine import * # noqa: F401,F403 +from mmengine.config import Config, DictAction +from mmengine.runner import Runner +from models.algorithm import * # noqa: F401,F403 +from models.head import * # noqa: F401,F403 +from models.neck import * # noqa: F401,F403 + +from mmpretrain.utils import register_all_modules + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume', + nargs='?', + type=str, + const='auto', + help='If specify checkpint path, resume from it, while if not ' + 'specify, try to auto resume from the latest checkpoint ' + 'in the work directory.') + parser.add_argument( + '--amp', + action='store_true', + help='enable automatic-mixed-precision training') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + args = parse_args() + + # register all modules in mmpretrain into the registries + # do not init the default scope here because it will be init in the runner + register_all_modules(init_default_scope=False) + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + work_type = args.config.split('/')[1] + cfg.work_dir = osp.join('./work_dirs', work_type, + osp.splitext(osp.basename(args.config))[0]) + + # enable automatic-mixed-precision training + if args.amp is True: + optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper') + assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \ + '`--amp` is not supported custom optimizer wrapper type ' \ + f'`{optim_wrapper}.' + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.setdefault('loss_scale', 'dynamic') + + # resume training + if args.resume == 'auto': + cfg.resume = True + cfg.load_from = None + elif args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start training + runner.train() + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/openai-clip_to_mmpretrain-clip.py b/tools/model_converters/openai-clip_to_mmpretrain-clip.py new file mode 100644 index 00000000000..72725502551 --- /dev/null +++ b/tools/model_converters/openai-clip_to_mmpretrain-clip.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_clip(ckpt): + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + if k.startswith('visual.conv1'): + new_k = k.replace('conv1', 'patch_embed.projection') + elif k.startswith('visual.positional_embedding'): + new_k = k.replace('positional_embedding', 'pos_embed') + new_v = v.unsqueeze(dim=0) + elif k.startswith('visual.class_embedding'): + new_k = k.replace('class_embedding', 'cls_token') + new_v = v.unsqueeze(dim=0).unsqueeze(dim=0) + elif k.startswith('visual.ln_pre'): + new_k = k.replace('ln_pre', 'pre_norm') + elif k.startswith('visual.transformer.resblocks'): + new_k = k.replace('transformer.resblocks', 'layers') + if 'ln_1' in k: + new_k = new_k.replace('ln_1', 'ln1') + elif 'ln_2' in k: + new_k = new_k.replace('ln_2', 'ln2') + elif 'mlp.c_fc' in k: + new_k = new_k.replace('mlp.c_fc', 'ffn.layers.0.0') + elif 'mlp.c_proj' in k: + new_k = new_k.replace('mlp.c_proj', 'ffn.layers.1') + elif 'attn.in_proj_weight' in k: + new_k = new_k.replace('in_proj_weight', 'qkv.weight') + elif 'attn.in_proj_bias' in k: + new_k = new_k.replace('in_proj_bias', 'qkv.bias') + elif 'attn.out_proj' in k: + new_k = new_k.replace('out_proj', 'proj') + elif k.startswith('visual.ln_post'): + new_k = k.replace('ln_post', 'ln1') + elif k.startswith('visual.proj'): + new_k = k.replace('visual.proj', 'visual_proj.proj') + else: + new_k = k + + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in pretrained clip ' + 'models to mmpretrain style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + weight = convert_clip(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + print('Done!!') + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/ram2mmpretrain.py b/tools/model_converters/ram2mmpretrain.py new file mode 100644 index 00000000000..5ee3b47677f --- /dev/null +++ b/tools/model_converters/ram2mmpretrain.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict +from copy import deepcopy + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_swin(ckpt): + new_ckpt = OrderedDict() + convert_mapping = dict() + + def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, + 2).reshape(out_channel, in_channel) + return x + + def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + for k, v in ckpt.items(): + if 'attn_mask' in k: + continue + if k.startswith('head'): + continue + elif k.startswith('layers'): + new_v = v + if 'attn.' in k: + new_k = k.replace('attn.', 'attn.w_msa.') + elif 'mlp.' in k: + if 'mlp.fc1.' in k: + new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') + elif 'mlp.fc2.' 
in k: + new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + else: + new_k = k.replace('mlp.', 'ffn.') + elif 'downsample' in k: + new_k = k + if 'reduction.' in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace('layers', 'stages', 1) + elif k.startswith('patch_embed'): + new_v = v + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + elif k.startswith('norm'): + new_v = v + new_k = k.replace('norm', 'norm3') + else: + new_v = v + new_k = k + + new_ckpt[new_k] = new_v + convert_mapping[k] = new_k + + return new_ckpt, convert_mapping + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained RAM models to' + 'MMPretrain style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + + visual_ckpt = OrderedDict() + for key in state_dict: + if key.startswith('visual_encoder.'): + new_key = key.replace('visual_encoder.', '') + visual_ckpt[new_key] = state_dict[key] + + new_visual_ckpt, convert_mapping = convert_swin(visual_ckpt) + new_ckpt = deepcopy(state_dict) + for key in state_dict: + if key.startswith('visual_encoder.'): + if 'attn_mask' in key: + del new_ckpt[key] + continue + del new_ckpt[key] + old_key = key.replace('visual_encoder.', '') + new_ckpt[key.replace(old_key, + convert_mapping[old_key])] = deepcopy( + new_visual_ckpt[key.replace( + old_key, + convert_mapping[old_key]).replace( + 'visual_encoder.', '')]) + + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(new_ckpt, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/train.py b/tools/train.py index 84c1eec93aa..89c8548fc32 100644 --- a/tools/train.py +++ b/tools/train.py @@ -91,10 +91,6 @@ def merge_args(cfg, args): # enable automatic-mixed-precision training if args.amp is True: - optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper') - assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \ - '`--amp` is not supported custom optimizer wrapper type ' \ - f'`{optim_wrapper}.' cfg.optim_wrapper.type = 'AmpOptimWrapper' cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
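Finally, a small standalone check of the patch-merging weight reorder performed by `convert_swin` above. The helper is copied here purely for illustration, since the converter is a CLI script rather than an importable module; the reorder accounts for the different patch-unfold ordering between the official Swin implementation and MMPretrain's `PatchMerging`.

```python
import torch


def correct_unfold_reduction_order(x):
    # Copied from the converter above for a self-contained check.
    out_channel, in_channel = x.shape
    x = x.reshape(out_channel, 4, in_channel // 4)
    x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
    return x


# One output channel, four groups of two input channels.
w = torch.arange(8, dtype=torch.float32).reshape(1, 8)
print(correct_unfold_reduction_order(w))
# tensor([[0., 4., 2., 6., 1., 5., 3., 7.]])
```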