Adding ViT to torchvision/models #4594

Merged
39 commits merged on Nov 27, 2021
Changes from 10 commits
Commits
fbd0024
[vit] Adding ViT to torchvision/models
yiwen-song Oct 12, 2021
7521ffe
adding pre-logits layer + resolving comments
yiwen-song Oct 20, 2021
7e63685
Merge branch 'pytorch:main' into main
yiwen-song Oct 22, 2021
2dd878a
Merge branch 'pytorch:main' into main
yiwen-song Oct 25, 2021
53b6967
Fix the model attribute bug
yiwen-song Oct 26, 2021
fe248f0
Merge branch 'main' of https://github.com/sallysyw/vision into main
yiwen-song Oct 26, 2021
a84361a
Change version to arch
yiwen-song Oct 26, 2021
f981519
Merge branch 'pytorch:main' into main
yiwen-song Oct 26, 2021
9d2ef95
Merge branch 'main' into main
datumbox Oct 27, 2021
0aaac5b
Merge branch 'pytorch:main' into main
yiwen-song Nov 1, 2021
1cf8b92
Merge branch 'pytorch:main' into main
yiwen-song Nov 5, 2021
c2f3826
fix failing unittests
yiwen-song Nov 6, 2021
35c1d22
remove useless prints
yiwen-song Nov 6, 2021
1aff5cd
Merge branch 'pytorch:main' into main
yiwen-song Nov 13, 2021
568c560
reduce input size to fix unittests
yiwen-song Nov 15, 2021
8e71e4b
Increase windows-cpu executor to 2xlarge
yiwen-song Nov 16, 2021
f9860ec
Use `batch_first=True` and remove classifier
yiwen-song Nov 17, 2021
4d7d7fe
Merge branch 'pytorch:main' into main
yiwen-song Nov 17, 2021
b795e85
Change resource_class back to xlarge
yiwen-song Nov 17, 2021
ff64591
Remove vit_h_14
yiwen-song Nov 17, 2021
bd3a747
Remove vit_h_14 from __all__
yiwen-song Nov 17, 2021
8f88592
Move vision_transformer.py into prototype
yiwen-song Nov 19, 2021
22025ac
Fix formatting issue
yiwen-song Nov 19, 2021
26bc529
remove arch in builder
yiwen-song Nov 19, 2021
cc22238
Fix type err in model builder
yiwen-song Nov 19, 2021
1d4e2aa
Merge branch 'main' into main
yiwen-song Nov 19, 2021
091bf6b
Merge branch 'pytorch:main' into main
yiwen-song Nov 23, 2021
41edd15
address comments and trigger unittests
yiwen-song Nov 24, 2021
48ce69e
remove the prototype import in torchvision.models
yiwen-song Nov 24, 2021
0caf745
Merge branch 'main' into main
yiwen-song Nov 24, 2021
3a6b445
Adding vit back to models to trigger CircleCI test
yiwen-song Nov 24, 2021
72c5af7
fix test_jit_forward_backward
yiwen-song Nov 24, 2021
aae308c
Move all to prototype.
datumbox Nov 25, 2021
7b1e59e
Merge branch 'main' into main
datumbox Nov 25, 2021
717b6af
Merge branch 'main' into main
datumbox Nov 25, 2021
f0df7f8
Adopt new helper methods and fix prototype tests.
datumbox Nov 25, 2021
3807b23
Remove unused import.
datumbox Nov 25, 2021
eabec95
Merge branch 'main' into main
yiwen-song Nov 26, 2021
40b566b
Merge branch 'main' into main
yiwen-song Nov 27, 2021
Binary file added test/expect/ModelTester.test_vit_b_16_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_b_32_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_h_14_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_l_16_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_l_32_expect.pkl
Binary file not shown.
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -10,6 +10,7 @@
from .shufflenetv2 import *
from .efficientnet import *
from .regnet import *
from .vision_transformer import *
from . import detection
from . import feature_extraction
from . import quantization
353 changes: 353 additions & 0 deletions torchvision/models/vision_transformer.py
@@ -0,0 +1,353 @@
# References:
# https://github.com/google-research/vision_transformer
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py

import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Optional

import torch
import torch.nn as nn
from torch import Tensor

__all__ = [
"VisionTransformer",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
"vit_h_14",
]


class MLPBlock(nn.Sequential):
"""Transformer MLP block."""

def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__()
self.linear_1 = nn.Linear(in_dim, mlp_dim)
self.act = nn.GELU()
self.dropout_1 = nn.Dropout(dropout)
self.linear_2 = nn.Linear(mlp_dim, in_dim)
self.dropout_2 = nn.Dropout(dropout)
self._init_weights()

def _init_weights(self):
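# Xavier-uniform weights and near-zero normal biases, as in the reference JAX
# ViT implementation's MLP initialization.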
nn.init.xavier_uniform_(self.linear_1.weight)
nn.init.xavier_uniform_(self.linear_2.weight)
nn.init.normal_(self.linear_1.bias, std=1e-6)
nn.init.normal_(self.linear_2.bias, std=1e-6)


class EncoderBlock(nn.Module):
"""Transformer encoder block."""

def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads

# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout)
self.dropout = nn.Dropout(dropout)

# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

def forward(self, input: Tensor):
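# Pre-norm residual block: LayerNorm is applied before the attention and MLP
# sub-blocks, and the residual connection is added to their outputs.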
torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
x = self.dropout(x)
x = x + input

y = self.ln_2(x)
y = self.mlp(y)
return x + y


class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""

def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
# Note that batch_size is on the second dim because
# we have batch_first=False in nn.MultiheadAttention() by default
self.pos_embedding = nn.Parameter(torch.empty(seq_length, 1, hidden_dim).normal_(std=0.02)) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)

def forward(self, input: Tensor):
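# Example: ViT-B/16 on 224x224 inputs gives seq_length = 14 * 14 + 1 = 197 and
# hidden_dim = 768, so the encoder maps (197, batch_size, 768) to the same shape.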
torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
input = input + self.pos_embedding
return self.ln(self.layers(self.dropout(input)))


class VisionTransformer(nn.Module):
"""Vision Transformer as per https://arxiv.org/abs/2010.11929."""

def __init__(
self,
image_size: int,
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float = 0.0,
attention_dropout: float = 0.0,
classifier: str = "token",
num_classes: int = 1000,
representation_size: Optional[int] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
torch._assert(classifier in ["token", "gap"], "Unexpected classifier mode!")
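# "token" classifies from a learnable class token (BERT-style);
# "gap" applies global average pooling over the patch tokens instead.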
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
self.classifier = classifier
self.num_classes = num_classes
self.representation_size = representation_size
self.norm_layer = norm_layer

input_channels = 3

# The conv_proj is a more efficient version of reshaping, permuting
# and projecting the input
self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
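# (A Conv2d with kernel_size == stride == patch_size applies one linear
# projection per non-overlapping patch, i.e. the ViT patch embedding.)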

seq_length = (image_size // patch_size) ** 2
if self.classifier == "token":
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1

self.encoder = Encoder(
seq_length,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.seq_length = seq_length

heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
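# Optionally insert a "pre_logits" representation layer (Linear + Tanh) before
# the classification head, mirroring the representation layer in the reference
# implementation.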
if representation_size is None:
heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
else:
heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
heads_layers["act"] = nn.Tanh()
heads_layers["head"] = nn.Linear(representation_size, num_classes)

self.heads = nn.Sequential(heads_layers)
self._init_weights()

def _init_weights(self):
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.conv_proj.bias)

if hasattr(self.heads, "pre_logits"):
fan_in = self.heads.pre_logits.in_features
nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.heads.pre_logits.bias)

nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)

def forward(self, x: torch.Tensor):
n, c, h, w = x.shape
p = self.patch_size
torch._assert(h == self.image_size, "Wrong image height!")
torch._assert(w == self.image_size, "Wrong image width!")
n_h = h // p
n_w = w // p
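# e.g. a 224x224 input with patch_size=16 gives n_h = n_w = 14, i.e. 196 patch
# tokens (plus one class token in "token" mode)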

# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)

# (n, hidden_dim, (n_h * n_w)) -> ((n_h * n_w), n, hidden_dim)
# The self attention layer expects inputs in the format (S, N, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(2, 0, 1)

if self.classifier == "token":
# Expand the class token to the full batch.
batch_class_token = self.class_token.expand(-1, n, -1)
x = torch.cat([batch_class_token, x], dim=0)

x = self.encoder(x)

if self.classifier == "token":
# Classifier as used by standard language architectures
x = x[0, :, :]
elif self.classifier == "gap":
# Classifier as used by standard vision architectures
x = x.mean(dim=0)
else:
raise ValueError(f"Invalid classifier={self.classifier}")

x = self.heads(x)

return x


def _vision_transformer(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> VisionTransformer:
image_size = kwargs.pop("image_size", 224)
model = VisionTransformer(image_size=image_size, **kwargs)
# TODO: Adding pre-trained models
return model


def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a ViT_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
return _vision_transformer(
arch="b_16",
pretrained=pretrained,
progress=progress,
patch_size=16,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
**kwargs,
)


def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a ViT_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
return _vision_transformer(
arch="b_32",
pretrained=pretrained,
progress=progress,
patch_size=32,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
**kwargs,
)


def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a ViT_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
return _vision_transformer(
arch="l_16",
pretrained=pretrained,
progress=progress,
patch_size=16,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
**kwargs,
)


def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a ViT_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
return _vision_transformer(
arch="l_32",
pretrained=pretrained,
progress=progress,
patch_size=32,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
**kwargs,
)


def vit_h_14(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a ViT_h_14 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

Args:
pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
return _vision_transformer(
arch="h_14",
pretrained=pretrained,
progress=progress,
patch_size=14,
num_layers=32,
num_heads=16,
hidden_dim=1280,
mlp_dim=5120,
**kwargs,
)
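
For reference, a minimal usage sketch of the new builders (illustrative only; in this revision the pretrained flag is accepted but loading pretrained weights is still a TODO):

import torch
from torchvision.models import vit_b_16, vit_l_32

model = vit_b_16()  # randomly initialized weights; pretrained loading is not wired up yet
images = torch.randn(2, 3, 224, 224)  # the builders default to 224x224 inputs
logits = model(images)
print(logits.shape)  # torch.Size([2, 1000]) -- 1000 classes by default

# Architecture and head hyper-parameters can be overridden via kwargs;
# num_classes=10 below is just an illustrative value.
model_l32 = vit_l_32(num_classes=10)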