Commit 4cb577f (1 parent: 7914f98), showing 17 changed files with 2,639 additions and 2 deletions.
@@ -0,0 +1,138 @@
output_dir/
outputs/
selected/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
**/*.pyc

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# custom
/data
.vscode
.idea
*.pkl
*.pkl.json
*.log.json
benchlist.txt
work_dirs/

# Pytorch
*.pth

# Profile
*.prof

# lmdb
*.mdb

# unignore some data file in tests/data
!tests/data/**/*.pkl
!tests/data/**/*.pkl.json
!tests/data/**/*.log.json
!tests/data/**/*.pth

# avoid soft links created by MIM
mmaction/configs/*
mmaction/tools/*
@@ -1,4 +1,99 @@
# MaskAlign
> This is the official repository for the paper "Stare at What You See: Masked Image Modeling without Reconstruction".

Stay tuned for new details!
<p align="center">
  <img src="figs/framework.png" alt="statistics" width="80%"/>
</p>

This is the official PyTorch repository for the paper [Stare at What You See: Masked Image Modeling without Reconstruction](https://arxiv.org/abs/2211.08887):
```
@article{xue2022stare,
  title={Stare at What You See: Masked Image Modeling without Reconstruction},
  author={Xue, Hongwei and Gao, Peng and Li, Hongyang and Qiao, Yu and Sun, Hao and Li, Houqiang and Luo, Jiebo},
  journal={arXiv preprint arXiv:2211.08887},
  year={2022}
}
```

* This repo is a modification of the [MAE repo](https://github.com/facebookresearch/mae). Installation and data preparation follow that repo.

* The teacher models in this repo are loaded from [Hugging Face](https://huggingface.co/). Please install the transformers package by running: <br> `pip install transformers`. A minimal loading sketch is shown below.

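For reference, a CLIP teacher such as `openai/clip-vit-base-patch16` can be instantiated through the transformers API roughly as follows. This is an illustrative sketch only; the exact wrapper and preprocessing used inside `main_pretrain.py` may differ.
```
from transformers import CLIPVisionModel

# Download the CLIP image encoder from Hugging Face and use it as a frozen teacher.
teacher = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False  # only the student encoder is updated during pre-training
```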
## Pre-training

To pre-train ViT-Base (the recommended default) with **distributed training**, run the following on 8 GPUs:

```
python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \
    --batch_size 128 \
    --model mae_vit_base_patch16 \
    --blr 1.5e-4 \
    --min_lr 1e-5 \
    --data_path ${IMAGENET_DIR} \
    --output_dir ${OUTPUT_DIR} \
    --target_norm whiten \
    --loss_type smoothl1 \
    --drop_path 0.1 \
    --head_type linear \
    --epochs 200 \
    --warmup_epochs 20 \
    --mask_type attention \
    --mask_ratio 0.7 \
    --loss_weights top5 \
    --fusion_type linear \
    --teacher_model openai/clip-vit-base-patch16
```

- Here the effective batch size is 128 (`batch_size` per GPU) * 8 (GPUs) = 1024. If memory or the number of GPUs is limited, use `--accum_iter` to maintain the effective batch size, which is `batch_size` (per GPU) * `nodes` * 8 (GPUs) * `accum_iter`; see the example after this list.
- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256.
- This repo automatically resumes interrupted runs by keeping a "latest checkpoint".

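As a concrete, hypothetical example: on a single node with only 4 GPUs, doubling `--accum_iter` keeps the effective batch size at 128 * 1 * 4 * 2 = 1024, and the linear scaling rule then gives `lr` = 1.5e-4 * 1024 / 256 = 6e-4.

```
# effective batch size: 128 (per GPU) * 1 (node) * 4 (GPUs) * 2 (accum_iter) = 1024
python -m torch.distributed.launch --nproc_per_node=4 main_pretrain.py \
    --batch_size 128 \
    --accum_iter 2 \
    --blr 1.5e-4 \
    --model mae_vit_base_patch16
    # remaining arguments unchanged from the 8-GPU command above
```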
To pre-train ViT-Large, please set `--model mae_vit_large_patch16` and `--drop_path 0.2`. Currently, this repo supports three teacher models via `--teacher_model ${TEACHER}`, where `${TEACHER}` is one of `openai/clip-vit-base-patch16`, `openai/clip-vit-large-patch14`, and `facebook/dino-vitb16`.

## Fine-tuning

Get our pre-trained checkpoints from [here](TODO).

To fine-tune ViT-Base (the recommended default) with **distributed training**, run the following on 8 GPUs:
```
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --epochs 100 \
    --batch_size 128 \
    --model vit_base_patch16 \
    --blr 3e-4 \
    --layer_decay 0.55 \
    --weight_decay 0.05 \
    --drop_path 0.2 \
    --reprob 0.25 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --dist_eval \
    --finetune ${PT_CHECKPOINT} \
    --data_path ${IMAGENET_DIR} \
    --output_dir ${OUTPUT_DIR}
```

- Here the effective batch size is 128 (`batch_size` per GPU) * 8 (GPUs) = 1024.
- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256.

To fine-tune ViT-Large, please set `--model vit_large_patch16 --epochs 50 --drop_path 0.4 --layer_decay 0.75 --blr 3e-4` (a fully assembled command is sketched below).

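Merging those overrides into the ViT-Base command above gives roughly the following ViT-Large fine-tuning invocation; this is just the flags from the sentence above applied to the same launcher, not a separately documented recipe.
```
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --epochs 50 \
    --batch_size 128 \
    --model vit_large_patch16 \
    --blr 3e-4 \
    --layer_decay 0.75 \
    --weight_decay 0.05 \
    --drop_path 0.4 \
    --reprob 0.25 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --dist_eval \
    --finetune ${PT_CHECKPOINT} \
    --data_path ${IMAGENET_DIR} \
    --output_dir ${OUTPUT_DIR}
```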
## Linear Probing

Run the following on 8 GPUs:
```
python -m torch.distributed.launch --nproc_per_node=8 main_linprobe.py \
    --epochs 90 \
    --batch_size 2048 \
    --model vit_base_patch16 \
    --blr 0.025 \
    --weight_decay 0.0 \
    --dist_eval \
    --finetune ${PT_CHECKPOINT} \
    --data_path ${IMAGENET_DIR} \
    --output_dir ${OUTPUT_DIR}
```
- Here the effective batch size is 2048 (`batch_size` per GPU) * 8 (GPUs) = 16384.
- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256.

@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy

import util.misc as misc
import util.lr_sched as lr_sched


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        # torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0]
        target = batch[-1]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
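
For orientation, `train_one_epoch` and `evaluate` are called once per epoch by the fine-tuning entry point (`main_finetune.py` in the MAE-style layout). The sketch below is illustrative only and simplifies the argument plumbing; names such as `data_loader_train`, `data_loader_val`, and `log_writer` are assumed from that script.
```
# Hypothetical per-epoch driver loop, in the spirit of MAE's main_finetune.py (simplified).
for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        data_loader_train.sampler.set_epoch(epoch)  # reshuffle the distributed sampler each epoch
    train_stats = train_one_epoch(
        model, criterion, data_loader_train, optimizer, device, epoch,
        loss_scaler, args.clip_grad, mixup_fn, log_writer=log_writer, args=args)
    test_stats = evaluate(data_loader_val, model, device)
    print(f"Top-1 accuracy on the validation set: {test_stats['acc1']:.2f}%")
```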