|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | + |
| 3 | +# This source code is licensed under the MIT license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import copy |
| 7 | +import logging |
| 8 | +from typing import List |
| 9 | + |
| 10 | +import torch |
| 11 | +import torch.nn as nn |
| 12 | +from classy_vision.models import VisionTransformer as ClassyVisionTransformer |
| 13 | +from vissl.config import AttrDict |
| 14 | +from vissl.models.trunks import register_model_trunk |
| 15 | + |
| 16 | + |
| 17 | +@register_model_trunk("classy_vit") |
| 18 | +class ClassyViT(nn.Module): |
| 19 | + """ |
| 20 | + Simple wrapper for ClassyVision Vision Transformer model. |
| 21 | + This model is defined on the fly from a Vision Transformer base class and |
| 22 | + a configuration file. |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__(self, model_config: AttrDict, model_name: str): |
| 26 | + super().__init__() |
| 27 | + self.model_config = model_config |
| 28 | + |
| 29 | + assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported" |
| 30 | + trunk_config = copy.deepcopy(model_config.TRUNK.VISION_TRANSFORMERS) |
| 31 | + |
| 32 | + logging.info("Building model: Vision Transformer from yaml config") |
| 33 | + trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()}) |
| 34 | + |
| 35 | + self.model = ClassyVisionTransformer( |
| 36 | + image_size=trunk_config.image_size, |
| 37 | + patch_size=trunk_config.patch_size, |
| 38 | + num_layers=trunk_config.num_layers, |
| 39 | + num_heads=trunk_config.num_heads, |
| 40 | + hidden_dim=trunk_config.hidden_dim, |
| 41 | + mlp_dim=trunk_config.mlp_dim, |
| 42 | + dropout_rate=trunk_config.dropout_rate, |
| 43 | + attention_dropout_rate=trunk_config.attention_dropout_rate, |
| 44 | + classifier=trunk_config.classifier, |
| 45 | + ) |
| 46 | + |
| 47 | + def forward( |
| 48 | + self, x: torch.Tensor, out_feat_keys: List[str] = None |
| 49 | + ) -> List[torch.Tensor]: |
| 50 | + x = self.model(x) |
| 51 | + x = x.unsqueeze(0) |
| 52 | + return x |
0 commit comments