This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit dc59f07

prigoyal authored and facebook-github-bot committed
Add a simple classy transformer wrapper to load the ViT models trained in classy vision (#505)
Summary:
Pull Request resolved: #505

as title

Reviewed By: iseessel, QuentinDuval

Differential Revision: D33795085

fbshipit-source-id: f40c9a5c92bf44a5377a361b254eef326ee97ef8
1 parent 722a7cc · commit dc59f07

File tree

1 file changed: +52 −0 lines

@@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from typing import List

import torch
import torch.nn as nn
from classy_vision.models import VisionTransformer as ClassyVisionTransformer
from vissl.config import AttrDict
from vissl.models.trunks import register_model_trunk


@register_model_trunk("classy_vit")
class ClassyViT(nn.Module):
    """
    Simple wrapper for the ClassyVision Vision Transformer model.

    This model is defined on the fly from a Vision Transformer base class and
    a configuration file.
    """

    def __init__(self, model_config: AttrDict, model_name: str):
        super().__init__()
        self.model_config = model_config

        assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
        trunk_config = copy.deepcopy(model_config.TRUNK.VISION_TRANSFORMERS)

        logging.info("Building model: Vision Transformer from yaml config")
        # Config keys are upper case in the yaml config; the ClassyVision
        # constructor expects lower-case keyword arguments.
        trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

        self.model = ClassyVisionTransformer(
            image_size=trunk_config.image_size,
            patch_size=trunk_config.patch_size,
            num_layers=trunk_config.num_layers,
            num_heads=trunk_config.num_heads,
            hidden_dim=trunk_config.hidden_dim,
            mlp_dim=trunk_config.mlp_dim,
            dropout_rate=trunk_config.dropout_rate,
            attention_dropout_rate=trunk_config.attention_dropout_rate,
            classifier=trunk_config.classifier,
        )

    def forward(
        self, x: torch.Tensor, out_feat_keys: List[str] = None
    ) -> List[torch.Tensor]:
        x = self.model(x)
        # VISSL trunks return a list of feature tensors; add a leading
        # dimension so the single ViT output can be indexed like one.
        x = x.unsqueeze(0)
        return x
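
Below is a minimal sketch (not part of the commit) of how this trunk could be instantiated directly. The config keys mirror the ones read in `__init__`; the concrete values (a standard ViT-B/16 layout), the import path of the new file, and the assumption that VISSL's AttrDict converts nested plain dicts are all illustrative, not taken from the commit.

# Minimal usage sketch, assuming ViT-B/16 hyperparameters and that AttrDict
# recursively converts nested dicts (if not, wrap each level in AttrDict).
import torch

from vissl.config import AttrDict
from vissl.models.trunks.classy_vit import ClassyViT  # hypothetical module path

model_config = AttrDict(
    {
        "INPUT_TYPE": "rgb",
        "TRUNK": {
            "VISION_TRANSFORMERS": {
                "IMAGE_SIZE": 224,
                "PATCH_SIZE": 16,
                "NUM_LAYERS": 12,
                "NUM_HEADS": 12,
                "HIDDEN_DIM": 768,
                "MLP_DIM": 3072,
                "DROPOUT_RATE": 0.0,
                "ATTENTION_DROPOUT_RATE": 0.0,
                "CLASSIFIER": "token",
            }
        },
    }
)

trunk = ClassyViT(model_config, model_name="classy_vit")
out = trunk(torch.randn(2, 3, 224, 224))
print(out.shape)  # trunk output with the extra leading dim added in forward()

Note that the upper-case keys are lower-cased inside `__init__` before being passed to the ClassyVision constructor, which is why the sketch writes them in the yaml-style upper case.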
