Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test_cpu.py for CPU-based testing #287

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ The code is developed under the following configurations.
- Hardware: >=4 GPUs for training, >=1 GPU for testing (set ```[--gpus GPUS]``` accordingly)
- Software: Ubuntu 16.04.3 LTS, ***CUDA>=8.0, Python>=3.5, PyTorch>=0.4.0***
- Dependencies: numpy, scipy, opencv, yacs, tqdm
- gdown --fuzzy https://drive.google.com/file/d/1Il1Pcb13syeHi9LA9KjXz8KqMFN9izgo -O ckpt.zip
- unzip -j ckpt.zip

## Quick start: Test on an image using our trained model
1. Here is a simple demo to do inference on a single image:
Expand Down
2 changes: 2 additions & 0 deletions demo_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ python3 -u test.py \
TEST.checkpoint epoch_20.pth

fi
# MODEL_NAME=ade20k-hrnetv2-c1
# python3 -u test_cpu.py --imgs ADE_val_00001519.jpg --cfg config/ade20k-hrnetv2.yaml
12 changes: 6 additions & 6 deletions mit_semseg/lib/utils/th.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
def as_variable(obj):
if isinstance(obj, Variable):
return obj
if isinstance(obj, collections.Sequence):
if isinstance(obj, collections.abc.Sequence):
return [as_variable(v) for v in obj]
elif isinstance(obj, collections.Mapping):
elif isinstance(obj, collections.abc.Mapping):
return {k: as_variable(v) for k, v in obj.items()}
else:
return Variable(obj)

def as_numpy(obj):
if isinstance(obj, collections.Sequence):
if isinstance(obj, collections.abc.Sequence):
return [as_numpy(v) for v in obj]
elif isinstance(obj, collections.Mapping):
elif isinstance(obj, collections.abc.Mapping):
return {k: as_numpy(v) for k, v in obj.items()}
elif isinstance(obj, Variable):
return obj.data.cpu().numpy()
Expand All @@ -33,9 +33,9 @@ def mark_volatile(obj):
if isinstance(obj, Variable):
obj.no_grad = True
return obj
elif isinstance(obj, collections.Mapping):
elif isinstance(obj, collections.abc.Mapping):
return {k: mark_volatile(o) for k, o in obj.items()}
elif isinstance(obj, collections.Sequence):
elif isinstance(obj, collections.abc.Sequence):
return [mark_volatile(o) for o in obj]
else:
return obj
15 changes: 12 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
--extra-index-url https://download.pytorch.org/whl/cpu
numpy
scipy
pytorch==0.4.1
torchvision
opencv3
torch==1.12.1+cpu
torchvision==0.13.1+cpu
opencv-python
yacs
tqdm
aiohttp
aiofiles
aiohttp_cors
pillow
gdown
uvicorn[standard]
fastapi
python-multipart
191 changes: 191 additions & 0 deletions test_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# System libs
import os
import argparse
from distutils.version import LooseVersion
# Numerical libs
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat
import csv
# Our libs
from mit_semseg.dataset import TestDataset
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import colorEncode, find_recursive, setup_logger
from mit_semseg.lib.nn import user_scattered_collate, async_copy_to
from mit_semseg.lib.utils import as_numpy
from PIL import Image
from tqdm import tqdm
from mit_semseg.config import cfg

# Color table handed to colorEncode when painting predicted class maps.
colors = loadmat('data/color150.mat')['colors']

# Lookup built from the CSV: integer parsed from column 0 -> text of column 5
# up to (not including) its first ';'.
names = {}
with open('data/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)  # discard the first row (treated as a header)
    names.update(
        (int(row[0]), row[5].split(";")[0]) for row in reader
    )


def visualize_result(data, pred, cfg):
    """Print the class breakdown of a prediction and save a side-by-side image.

    Args:
        data: ``(img, info)`` pair. ``img`` is the image array placed to the
            left of the colorized prediction; ``info`` is the image path,
            used in the log line and to derive the output file name.
        pred: Array of predicted class indices. Presumably H x W and matching
            ``img`` in height -- TODO confirm against the caller.
        cfg: Config object; only ``cfg.TEST.result`` (output directory) is
            read here.

    The result is always written as ``<input stem>.png`` inside
    ``cfg.TEST.result``, whatever the extension (or lack of one) of the
    input path.
    """
    (img, info) = data

    # print predictions in descending order
    pred = np.int32(pred)
    pixs = pred.size
    uniques, counts = np.unique(pred, return_counts=True)
    print("Predictions in [{}]:".format(info))
    for idx in np.argsort(counts)[::-1]:
        # `names` is looked up with the predicted index shifted up by one.
        name = names[uniques[idx] + 1]
        ratio = counts[idx] / pixs * 100
        # Only classes covering more than 0.1% of the pixels are reported.
        if ratio > 0.1:
            print(" {}: {:.2f}%".format(name, ratio))

    # colorize prediction
    pred_color = colorEncode(pred, colors).astype(np.uint8)

    # aggregate images and save (input on the left, prediction on the right)
    im_vis = np.concatenate((img, pred_color), axis=1)

    # Derive the output name from the path's stem rather than string-replacing
    # '.jpg': that missed '.JPG' / '.jpeg' / other extensions and could also
    # rewrite a '.jpg' occurring in the middle of a file name.
    stem = os.path.splitext(os.path.basename(info))[0]
    Image.fromarray(im_vis).save(
        os.path.join(cfg.TEST.result, stem + '.png'))


def test(segmentation_module, loader):
    """Run inference over every item of `loader` and visualize each result.

    Args:
        segmentation_module: Callable model; invoked as
            ``segmentation_module(feed_dict, segSize=segSize)`` and switched
            to eval mode first.
        loader: Iterable of batches with a ``len()``. Only element ``[0]`` of
            each batch is used; it must be a dict with the keys ``img_ori``,
            ``img_data`` and ``info``.

    Reads the module-level ``cfg`` (``cfg.DATASET.num_class`` and
    ``cfg.DATASET.imgSizes``) and passes it on to ``visualize_result``.
    """
    segmentation_module.eval()

    # Context manager guarantees the bar is closed (and its line terminated)
    # both on normal completion and when inference raises.
    with tqdm(total=len(loader)) as pbar:
        for batch_data in loader:
            # process data
            batch_data = batch_data[0]
            img_ori = batch_data['img_ori']
            # Predictions are produced at the first two dims of the original image.
            segSize = (img_ori.shape[0], img_ori.shape[1])
            img_resized_list = batch_data['img_data']

            with torch.no_grad():
                scores = torch.zeros(
                    1, cfg.DATASET.num_class, segSize[0], segSize[1])

                for img in img_resized_list:
                    # Shallow copy so the per-scale edits below do not touch
                    # batch_data, which is still needed for visualization.
                    feed_dict = batch_data.copy()
                    feed_dict['img_data'] = img
                    del feed_dict['img_ori']
                    del feed_dict['info']

                    # forward pass; each entry contributes an equal share
                    pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                    scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)

                # Highest-scoring class along dim 1, then drop the batch dim.
                _, pred = torch.max(scores, dim=1)
                pred = as_numpy(pred.squeeze(0))

            # visualization
            visualize_result(
                (batch_data['img_ori'], batch_data['info']),
                pred,
                cfg
            )

            pbar.update(1)



def main(cfg):
    """Build the segmentation module on CPU and run `test` over cfg.list_test.

    Reads ``cfg.MODEL`` (architectures, ``fc_dim``, weight paths),
    ``cfg.DATASET``, ``cfg.TEST.batch_size`` and ``cfg.list_test``.
    """
    model_cfg = cfg.MODEL

    # Network builders: encoder and decoder are created from the weight files
    # named in the config.
    encoder = ModelBuilder.build_encoder(
        arch=model_cfg.arch_encoder,
        fc_dim=model_cfg.fc_dim,
        weights=model_cfg.weights_encoder,
    )
    decoder = ModelBuilder.build_decoder(
        arch=model_cfg.arch_decoder,
        fc_dim=model_cfg.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=model_cfg.weights_decoder,
        use_softmax=True,
    )

    criterion = nn.NLLLoss(ignore_index=-1)

    # Assemble the full module and place it on the CPU.
    model = SegmentationModule(encoder, decoder, criterion)
    model = model.cpu()

    # Dataset and loader
    images = TestDataset(cfg.list_test, cfg.DATASET)
    loader_options = dict(
        batch_size=cfg.TEST.batch_size,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=5,
        drop_last=True,
    )
    loader = torch.utils.data.DataLoader(images, **loader_options)

    # Main loop
    test(model, loader)

    print('Inference done!')


if __name__ == '__main__':
    # Script entry point: parse the command line, merge it into the global
    # config, resolve checkpoint paths, collect the input images and run main().
    #
    # Validation uses explicit raises rather than `assert`, because asserts
    # are stripped when Python runs with -O and the checks would vanish.
    if LooseVersion(torch.__version__) < LooseVersion('0.4.0'):
        raise RuntimeError('PyTorch>=0.4.0 is required')

    parser = argparse.ArgumentParser(
        description="PyTorch Semantic Segmentation Testing"
    )
    parser.add_argument(
        "--imgs",
        required=True,
        type=str,
        help="an image path, or a directory name"
    )
    parser.add_argument(
        "--cfg",
        default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)
    # The config is intentionally left unfrozen: cfg.MODEL.* and cfg.list_test
    # are assigned further down.

    logger = setup_logger(distributed_rank=0)   # TODO
    logger.info("Loaded configuration file {}".format(args.cfg))
    logger.info("Running with config:\n{}".format(cfg))

    cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower()
    cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower()

    # Paths of model weights: <cfg.DIR>/{encoder,decoder}_<cfg.TEST.checkpoint>
    cfg.MODEL.weights_encoder = os.path.join(
        cfg.DIR, 'encoder_' + cfg.TEST.checkpoint)
    cfg.MODEL.weights_decoder = os.path.join(
        cfg.DIR, 'decoder_' + cfg.TEST.checkpoint)

    print(cfg.MODEL.weights_encoder)
    print(cfg.MODEL.weights_decoder)

    # Fail early, naming the missing file, before any model is built.
    for weights_path in (cfg.MODEL.weights_encoder, cfg.MODEL.weights_decoder):
        if not os.path.exists(weights_path):
            raise FileNotFoundError(
                "checkpoint does not exist! ({})".format(weights_path))

    # generate testing image list
    if os.path.isdir(args.imgs):
        imgs = find_recursive(args.imgs)
    else:
        imgs = [args.imgs]
    if not imgs:
        raise ValueError(
            "imgs should be a path to image (.jpg) or directory.")
    cfg.list_test = [{'fpath_img': x} for x in imgs]

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(cfg.TEST.result, exist_ok=True)

    main(cfg)