timm is a deep-learning library created by Ross Wightman. It is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders and augmentations, along with training/validation scripts that can reproduce ImageNet training results.
pip install timm
Or, for an editable install:
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
import timm
import torch
model = timm.create_model('resnet34')
x = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 1000])
It is that simple to create a model using timm. The create_model function is a factory method that can be used to create over 300 models that are part of the timm library.
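The factory idea behind create_model can be illustrated with a small name-to-builder registry. This is a minimal sketch under stated assumptions, not timm's actual code: register_model, create_model, list_models and the tiny_resnet/tiny_vit builders below are hypothetical stand-ins, and the builders return plain dicts instead of real networks so the pattern stays visible.

```python
from typing import Callable, Dict, List

# Hypothetical registry mapping a model name to the function that builds it.
_MODEL_REGISTRY: Dict[str, Callable[..., dict]] = {}


def register_model(fn: Callable[..., dict]) -> Callable[..., dict]:
    """Decorator that records a builder under its own function name."""
    _MODEL_REGISTRY[fn.__name__] = fn
    return fn


@register_model
def tiny_resnet(num_classes: int = 1000, **kwargs) -> dict:
    # Stand-in for a real network constructor; returns a description instead.
    return {"arch": "tiny_resnet", "num_classes": num_classes, **kwargs}


@register_model
def tiny_vit(num_classes: int = 1000, **kwargs) -> dict:
    return {"arch": "tiny_vit", "num_classes": num_classes, **kwargs}


def create_model(model_name: str, **kwargs) -> dict:
    """Hypothetical factory: look up the builder by name and forward kwargs."""
    if model_name not in _MODEL_REGISTRY:
        raise RuntimeError(f"Unknown model: {model_name}")
    return _MODEL_REGISTRY[model_name](**kwargs)


def list_models() -> List[str]:
    """Hypothetical listing: every registered name, sorted."""
    return sorted(_MODEL_REGISTRY)


print(list_models())                             # ['tiny_resnet', 'tiny_vit']
print(create_model("tiny_vit", num_classes=10))  # {'arch': 'tiny_vit', 'num_classes': 10}
```

Because callers only pass a string, adding a new architecture means registering one more builder; the calling code does not change.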
To create a pretrained model, simply pass in pretrained=True.
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/tmabraham/.cache/torch/hub/checkpoints/resnet34-43635321.pth
To create a model with a custom number of classes, simply pass in num_classes=<number_of_classes>.
import timm
import torch
model = timm.create_model('resnet34', num_classes=10)
x = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 10])
timm.list_models() returns a complete list of the models available in timm. To see only the models that have pretrained weights, pass pretrained=True to list_models.
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]
(592,
['adv_inception_v3',
'bat_resnext26ts',
'beit_base_patch16_224',
'beit_base_patch16_224_in22k',
'beit_base_patch16_384'])
At the time of this run, there were 592 models with pretrained weights available in timm!
It is also possible to search for model architectures using wildcards, as below:
all_densenet_models = timm.list_models('*densenet*')
all_densenet_models
['densenet121',
'densenet121d',
'densenet161',
'densenet169',
'densenet201',
'densenet264',
'densenet264d_iabn',
'densenetblur121d',
'tv_densenet121']
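Wildcard patterns such as '*densenet*' behave like shell globs: '*' matches any run of characters. As an illustration of the matching idea only (not a claim about how timm implements list_models), Python's standard fnmatch module applies the same kind of pattern to a short, hand-picked list of model names taken from the outputs above.

```python
import fnmatch

# A few model names taken from the outputs above (not the full timm list).
names = ['adv_inception_v3', 'densenet121', 'densenet161',
         'densenetblur121d', 'resnet34', 'tv_densenet121']

# '*densenet*' keeps every name that contains 'densenet' anywhere.
print(fnmatch.filter(names, '*densenet*'))
# ['densenet121', 'densenet161', 'densenetblur121d', 'tv_densenet121']

# Without a leading '*', the name must start with the given prefix.
print(fnmatch.filter(names, 'densenet12*'))
# ['densenet121']
```

This is why '*densenet*' also returns tv_densenet121: the leading '*' allows a prefix before 'densenet'.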
The fastai library has support for fine-tuning models from timm:
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
path, get_image_files(path), valid_pct=0.2,
label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
# if a string is passed into the model argument, it will now use timm (if it is installed)
learn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)
learn.fine_tune(1)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.201583 | 0.024980 | 0.006766 | 00:08 |
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.040622 | 0.024036 | 0.005413 | 00:10 |
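The label_func in the fastai example relies on a naming convention in the Oxford-IIIT Pet images: file names for cat breeds start with an uppercase letter, while dog breeds are lowercase. ImageDataLoaders.from_name_func applies label_func to each file's name, so x[0].isupper() labels cats as True. The sketch below applies the same rule in plain Python to a few illustrative file names; the specific names (and the is_cat helper) are hypothetical examples, not files read from the dataset.

```python
from pathlib import Path


def is_cat(fname: str) -> bool:
    # Hypothetical helper mirroring the lambda above: cat-breed names are capitalised.
    return fname[0].isupper()


# Illustrative paths in the style of the Pets dataset (not read from disk).
files = [Path('images/Abyssinian_1.jpg'), Path('images/Bengal_12.jpg'),
         Path('images/american_bulldog_3.jpg'), Path('images/pug_7.jpg')]

# The label is computed from the file name only, not the full path.
labels = {f.name: is_cat(f.name) for f in files}
print(labels)
# {'Abyssinian_1.jpg': True, 'Bengal_12.jpg': True, 'american_bulldog_3.jpg': False, 'pug_7.jpg': False}
```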