How to translate flash (ImageClassifier) inference code into vanilla PyTorch? #1040
Hi, I am loading a pretrained model:

from flash.image import ImageClassifier

flash_model = ImageClassifier(backbone="resnet50", num_classes=2)
flash_model = flash_model.load_from_checkpoint('image_classification_model.pt')
flash_model.predict(filename)

By looking at the source code, I thought that I should get equivalent results with the following code:

import torch
from collections import OrderedDict
from torch.nn import Linear
from PIL import Image
from torchvision import transforms
import numpy as np

# trying to get a resnet50 from the hub
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=False)
model.fc = Linear(in_features=2048, out_features=2, bias=True)

# not so great way to "translate" backbone keys, but should be legit...
# (a generator of (key, value) pairs rather than a set, so the key order is kept)
flash_state_dict = OrderedDict(
    (k.replace("backbone.", "").replace("head.0", "fc"), v)
    for (k, v) in flash_model.state_dict().items()
)
model.load_state_dict(flash_state_dict)
model.eval()

preprocess = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Resize(196),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_image = Image.open(filename)
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # add a batch dimension
with torch.no_grad():
    output = model(input_batch)
torch.nn.functional.softmax(output[0], dim=0)

The "translated" code runs fine, but the results differ from those of flash_model.predict. Is it something that should be tweaked in the …
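A slightly stricter version of the key renaming in the question, as a sketch. It assumes, exactly as the question's str.replace calls do, that Flash stores the ResNet weights under a "backbone." prefix and the final linear layer under "head.0."; the helper name translate_keys and the PREFIX_MAP table are hypothetical, not part of Flash or torchvision. Only the leading prefix is rewritten, so a substring like "head.0" elsewhere in a key is never touched.

```python
from collections import OrderedDict
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")

# Hypothetical prefix table, assumed from the question's str.replace calls:
#   "backbone.<name>" -> "<name>"      (torchvision ResNet body)
#   "head.0.<name>"   -> "fc.<name>"   (final Linear layer)
PREFIX_MAP = (("backbone.", ""), ("head.0.", "fc."))


def translate_keys(state_dict: Mapping[str, V]) -> Dict[str, V]:
    """Rename Flash-style state_dict keys to torchvision ResNet keys.

    Only a leading prefix is rewritten. Keys matching no rule are passed
    through unchanged, so a strict load_state_dict call will report them.
    """
    translated: Dict[str, V] = OrderedDict()
    for key, value in state_dict.items():
        for old, new in PREFIX_MAP:
            if key.startswith(old):
                translated[new + key[len(old):]] = value
                break
        else:
            translated[key] = value  # no rule matched: keep the key as-is
    return translated


# Demo with placeholder values instead of real tensors
flash_like = {
    "backbone.conv1.weight": "w0",
    "backbone.layer1.0.bn1.running_mean": "w1",
    "head.0.weight": "w2",
    "head.0.bias": "w3",
}
print(list(translate_keys(flash_like)))
```

With the real models this would be model.load_state_dict(translate_keys(flash_model.state_dict())). With the default strict=True, PyTorch raises an error if any key is missing or unexpected, so a successful load confirms the mapping covers every weight.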
Replies: 1 comment
Well... yes, the problem was in the preprocessing/resizing :/ I get correct results with:

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((196, 196)),  # This line fixed things: fixed (height, width) instead of only the shorter edge
    transforms.ConvertImageDtype(torch.float),  # no-op here, ToTensor already returns floats in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

Also, when opening the image, the following line should take care of the channels:
input_image = Image.open(filename).convert('RGB')
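Why Resize(196) ran without errors but changed the predictions: with an int size, torchvision's Resize scales the shorter edge to that value and keeps the aspect ratio, while a (height, width) pair forces an exact shape. The torchvision ResNet ends in an adaptive average pool, so it accepts either shape, but a non-square input gives different features. The sketch below re-derives the output size by hand in plain Python (the function name resized_hw is hypothetical, not torchvision's code); it assumes Flash resized to a fixed 196x196, which is what the fix suggests.

```python
from typing import Sequence, Tuple, Union


def resized_hw(h: int, w: int, size: Union[int, Sequence[int]]) -> Tuple[int, int]:
    """Hypothetical re-derivation of the (height, width) torchvision's Resize produces.

    - int size: the shorter edge becomes `size`; the longer edge is scaled by
      the same factor and truncated to an int (aspect ratio kept).
    - (height, width) pair: the image is resized to exactly that shape.
    """
    if isinstance(size, int):
        if w <= h:  # portrait or square: width is the shorter edge
            return int(size * h / w), size
        return size, int(size * w / h)  # landscape: height is the shorter edge
    out_h, out_w = size
    return out_h, out_w


# A 640x480 landscape photo (height 480, width 640)
print(resized_hw(480, 640, 196))         # what Resize(196) gives
print(resized_hw(480, 640, (196, 196)))  # what Resize((196, 196)) gives
```

For this photo, Resize(196) feeds the network a 196x261 image instead of 196x196, which is enough to shift the softmax outputs even though every weight was loaded correctly.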