Commit: release code

HellwayXue committed Feb 7, 2023
1 parent 7914f98 · commit 4cb577f

Showing 17 changed files with 2,639 additions and 2 deletions.
138 changes: 138 additions & 0 deletions .gitignore
@@ -0,0 +1,138 @@
output_dir/
outputs/
selected/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
**/*.pyc

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# custom
/data
.vscode
.idea
*.pkl
*.pkl.json
*.log.json
benchlist.txt
work_dirs/

# Pytorch
*.pth

# Profile
*.prof

# lmdb
*.mdb

# unignore some data file in tests/data
!tests/data/**/*.pkl
!tests/data/**/*.pkl.json
!tests/data/**/*.log.json
!tests/data/**/*.pth

# avoid soft links created by MIM
mmaction/configs/*
mmaction/tools/*
99 changes: 97 additions & 2 deletions README.md
@@ -1,4 +1,99 @@
# MaskAlign
> This is the official repository for paper "Stare at What You See: Masked Image Modeling without Reconstruction".

Stay tuned for new details!
<p align="center">
<img src="figs/framework.png" alt="statistics" width="80%"/>
</p>


This is the official PyTorch repository for the paper [Stare at What You See: Masked Image Modeling without Reconstruction](https://arxiv.org/abs/2211.08887):
```
@article{xue2022stare,
title={Stare at What You See: Masked Image Modeling without Reconstruction},
author={Xue, Hongwei and Gao, Peng and Li, Hongyang and Qiao, Yu and Sun, Hao and Li, Houqiang and Luo, Jiebo},
journal={arXiv preprint arXiv:2211.08887},
year={2022}
}
```

* This repo is a modification of the [MAE repo](https://github.com/facebookresearch/mae). Installation and data preparation follow that repo.

* The teacher models in this repo are loaded from [Hugging Face](https://huggingface.co/). Please install the `transformers` package by running: <br> `pip install transformers`. A minimal loading sketch is shown below.
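
For reference, a teacher can be instantiated from Hugging Face roughly as follows. This is only an illustrative sketch: the repo's own model code handles teacher loading and feature extraction, and keeping the teacher frozen here is our assumption about how it is used.

```
import torch
from transformers import CLIPVisionModel

# CLIP ViT-B/16 vision tower, matching --teacher_model openai/clip-vit-base-patch16.
# For the DINO teacher, transformers' ViTModel with "facebook/dino-vitb16" works the same way.
teacher = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")

# The teacher only provides target features, so we keep it frozen (our assumption).
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

# Intermediate layer features are available via output_hidden_states.
with torch.no_grad():
    out = teacher(torch.randn(1, 3, 224, 224), output_hidden_states=True)
    features = out.hidden_states  # tuple: embeddings + one tensor per transformer layer
```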

## Pre-training

To pre-train ViT-base (recommended default) with **distributed training**, run the following on 8 GPUs:

```
python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \
--batch_size 128 \
--model mae_vit_base_patch16 \
--blr 1.5e-4 \
--min_lr 1e-5 \
--data_path ${IMAGENET_DIR} \
--output_dir ${OUTPUT_DIR} \
--target_norm whiten \
--loss_type smoothl1 \
--drop_path 0.1 \
--head_type linear \
--epochs 200 \
--warmup_epochs 20 \
--mask_type attention \
--mask_ratio 0.7 \
--loss_weights top5 \
--fusion_type linear \
--teacher_model openai/clip-vit-base-patch16
```

- Here the effective batch size is 128 (`batch_size` per gpu) * 8 (gpus) = 1024. If memory or the number of GPUs is limited, use `--accum_iter` to maintain the effective batch size, which is `batch_size` (per gpu) * `nodes` * 8 (gpus per node) * `accum_iter`.
- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256 (a short worked example follows this list).
- This repo automatically resumes interrupted training by keeping a "latest" checkpoint.
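
As a quick sanity check, the defaults above give the following learning rate (a standalone sketch of the scaling rule, not code from this repo):

```
batch_size = 128     # per GPU
num_gpus = 8
accum_iter = 1       # raise this to keep the effective batch size when using fewer GPUs
blr = 1.5e-4

eff_batch_size = batch_size * num_gpus * accum_iter   # 1024
lr = blr * eff_batch_size / 256                       # 6.0e-4
print(f"effective batch size: {eff_batch_size}, lr: {lr}")
```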

To train ViT-Large, please set `--model mae_vit_large_patch16` and `--drop_path 0.2`. Currently, this repo supports three teacher models via `--teacher_model ${TEACHER}`, where `${TEACHER}` is one of `openai/clip-vit-base-patch16`, `openai/clip-vit-large-patch14`, and `facebook/dino-vitb16`.

## Fine-tuning

Get our pre-trained checkpoints from [here](TODO).

To fine-tune ViT-base (recommended default) with **distributed training**, run the following on 8 GPUs:
```
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
--epochs 100 \
--batch_size 128 \
--model vit_base_patch16 \
--blr 3e-4 \
--layer_decay 0.55 \
--weight_decay 0.05 \
--drop_path 0.2 \
--reprob 0.25 \
--mixup 0.8 \
--cutmix 1.0 \
--dist_eval \
--finetune ${PT_CHECKPOINT} \
--data_path ${IMAGENET_DIR} \
--output_dir ${OUTPUT_DIR}
```

- Here the effective batch size is 128 (`batch_size` per gpu) * 8 (gpus) = 1024.
- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256.

To fine-tune ViT-Large, please set `--model vit_large_patch16 --epochs 50 --drop_path 0.4 --layer_decay 0.75 --blr 3e-4`.


## Linear Probing

Run the following on 8 GPUs:
```
python -m torch.distributed.launch --nproc_per_node=8 main_linprobe.py \
--epochs 90 \
--batch_size 2048 \
--model vit_base_patch16 \
--blr 0.025 \
--weight_decay 0.0 \
--dist_eval \
--finetune ${PT_CHECKPOINT} \
--data_path ${IMAGENET_DIR} \
--output_dir ${OUTPUT_DIR}
```
- Here the effective batch size is 2048 (`batch_size` per gpu) * 8 (gpus) = 16384.
- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256.

130 changes: 130 additions & 0 deletions engine_finetune.py
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy

import util.misc as misc
import util.lr_sched as lr_sched


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
mixup_fn: Optional[Mixup] = None, log_writer=None,
args=None):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 20

accum_iter = args.accum_iter

optimizer.zero_grad()

if log_writer is not None:
print('log_dir: {}'.format(log_writer.log_dir))

for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

# we use a per iteration (instead of per epoch) lr scheduler
if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

samples = samples.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)

if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)

with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(outputs, targets)

loss_value = loss.item()

if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)

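        # Average the loss over accum_iter micro-batches; loss_scaler runs backward here
        # and only steps the optimizer when update_grad is True.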
loss /= accum_iter
loss_scaler(loss, optimizer, clip_grad=max_norm,
parameters=model.parameters(), create_graph=False,
update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()

# torch.cuda.synchronize()

metric_logger.update(loss=loss_value)
min_lr = 10.
max_lr = 0.
for group in optimizer.param_groups:
min_lr = min(min_lr, group["lr"])
max_lr = max(max_lr, group["lr"])

metric_logger.update(lr=max_lr)

loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', max_lr, epoch_1000x)

# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
criterion = torch.nn.CrossEntropyLoss()

metric_logger = misc.MetricLogger(delimiter=" ")
header = 'Test:'

# switch to evaluation mode
model.eval()

for batch in metric_logger.log_every(data_loader, 10, header):
images = batch[0]
target = batch[-1]
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)

# compute output
with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)

acc1, acc5 = accuracy(output, target, topk=(1, 5))

batch_size = images.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

return {k: meter.global_avg for k, meter in metric_logger.meters.items()}