diff --git a/papers/squeezedet/LICENSE b/papers/squeezedet/LICENSE new file mode 100644 index 0000000..36f6b4a --- /dev/null +++ b/papers/squeezedet/LICENSE @@ -0,0 +1,39 @@ +Parts of the code in Squeezedet_Theano are derived from the +following version of SqueezeDet by Bichen Wu: + + Author : Bichen Wu + Date : 22 November 2017 + Source : https://github.com/BichenWuUCB/squeezeDet/ + Commit : e7c0860 (Wed Nov 22 15:47:02 2017 -0800) + +Copyright for portions of the derived code in Squeezedet_Theano are +held by Bichen Wu, 2016-2107. All other code in Squeezedet_Theano is +copyright Corvidim, 2017. This code is published in accordance with +and retains the same license as SqueezeDet by Bichen Wu. + +BSD 2-Clause License + +Copyright (c) 2016-2017, Bichen Wu +Copyright (c) 2017, Corvidim +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/papers/squeezedet/README.md b/papers/squeezedet/README.md new file mode 100644 index 0000000..25d59f2 --- /dev/null +++ b/papers/squeezedet/README.md @@ -0,0 +1,82 @@ +# SqueezeDet_Theano + +This is a reimplementation of Bichen Wu's SqueezeDet [1],[2] for +Theano [3] and Lasagne [4] by +[@corvidim](https://github.com/corvidim). The version contained in +this repo corresponds to a snapshot of +[squeezedet_theano](https://github.com/corvidim/squeezedet_theano) +0.1.0. + +SqueezeDet is particularly well-suited to low-power, +low-memory-consumption applications. In particular, the trained +weights take up less than 8Mb! + +This release includes: + +* Network weights trained on the KITTI dataset (`data/squeezedet_kitti.pkl.gz`, <8MB). + +* The detector ("prediction graph" and "interpretation graph" +of the original TensorFlow implementation); training is not yet +supported. + + +References: + +- [1] https://arxiv.org/abs/1612.01051 +- [2] https://github.com/BichenWuUCB/squeezeDet +- [3] http://deeplearning.net/software/theano/ +- [4] http://lasagne.readthedocs.io/en/latest/ + +## Setup + +Requirements: + +- Python 2.7+ (Python 3 experimental support; tested under Python 3.6.3) +- Theano 0.9+ +- Lasagne 0.2+ +- Numpy, PIL, Matplotlib + +Sample docker files for Ubuntu installation: + +- `./setup/sqzdet_theano_dockerfile.txt` (Python 2, Theano 0.10) +- `./setup/sqzdet_theano_0.9_dockerfile.txt` (Python 2, Theano 0.9) +- `./setup/sqzdet_theano_py3k_dockerfile.txt` (Python 3, Theano 0.10) + + +To run detection on a default image (`data/sample.png` from the original SqueezeDet demo): + +``` +THEANO_FLAGS='floatX=float32' python src/sqz_det_thn.py --network_weights=data/squeezedet_kitti.pkl.gz +``` + +Use the `--img_in` flag to specify an image file, for example: + +``` +THEANO_FLAGS='floatX=float32' python src/sqz_det_thn.py --network_weights=data/squeezedet_kitti.pkl.gz --img_in=data/Boston_01.png +``` + +or specify a comma-separated list of images (all of which must be the same size), for example: + +``` +THEANO_FLAGS='floatX=float32' python src/sqz_det_thn.py --network_weights=data/squeezedet_kitti.pkl.gz --img_in=data/sample.png,data/Boston_00.png +``` + +Output visualization(s) will be saved in PDF format to the current +directory (e.g., `./out_sample.pdf`) and displayed (if running +interactively) as shown below. To specify a different directory for +output use the `--out_dir` flag. + +## Output visualizations of sample input images +--- +![output from sample.png](data/output/out_sample.png) + +--- + +![output from Boston_00.png](data/output/out_Boston_00.png) + +--- + +![output from Boston_01.png](data/output/out_Boston_01.png) + +--- + diff --git a/papers/squeezedet/data/Boston_00.png b/papers/squeezedet/data/Boston_00.png new file mode 100644 index 0000000..3e366af Binary files /dev/null and b/papers/squeezedet/data/Boston_00.png differ diff --git a/papers/squeezedet/data/Boston_01.png b/papers/squeezedet/data/Boston_01.png new file mode 100644 index 0000000..000eb56 Binary files /dev/null and b/papers/squeezedet/data/Boston_01.png differ diff --git a/papers/squeezedet/data/output/out_Boston_00.png b/papers/squeezedet/data/output/out_Boston_00.png new file mode 100644 index 0000000..c42b25e Binary files /dev/null and b/papers/squeezedet/data/output/out_Boston_00.png differ diff --git a/papers/squeezedet/data/output/out_Boston_01.png b/papers/squeezedet/data/output/out_Boston_01.png new file mode 100644 index 0000000..9149abf Binary files /dev/null and b/papers/squeezedet/data/output/out_Boston_01.png differ diff --git a/papers/squeezedet/data/output/out_sample.png b/papers/squeezedet/data/output/out_sample.png new file mode 100644 index 0000000..e541283 Binary files /dev/null and b/papers/squeezedet/data/output/out_sample.png differ diff --git a/papers/squeezedet/data/sample.png b/papers/squeezedet/data/sample.png new file mode 100644 index 0000000..2afc926 Binary files /dev/null and b/papers/squeezedet/data/sample.png differ diff --git a/papers/squeezedet/data/squeezedet_kitti.pkl.gz b/papers/squeezedet/data/squeezedet_kitti.pkl.gz new file mode 100644 index 0000000..ac1bb28 Binary files /dev/null and b/papers/squeezedet/data/squeezedet_kitti.pkl.gz differ diff --git a/papers/squeezedet/setup/sqzdet_theano_0.9_dockerfile.txt b/papers/squeezedet/setup/sqzdet_theano_0.9_dockerfile.txt new file mode 100644 index 0000000..699297e --- /dev/null +++ b/papers/squeezedet/setup/sqzdet_theano_0.9_dockerfile.txt @@ -0,0 +1,37 @@ +#Ubuntu CUDA Py2 docker file for sqzdet_theano under Theano 0.9 +# nvidia-docker build -t sqzdet_theano_0.9 . -- file sqzdet_theano_0.9.txt + +FROM nvidia/cuda:8.0-cudnn5-devel-ubuntu16.04 + +RUN apt-get update && apt-get install -y \ + apt-utils \ + git \ + wget + +RUN wget --no-check-certificate https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda2.sh + +RUN chmod +x miniconda2.sh && bash ./miniconda2.sh -b -p /opt/miniconda2 + +ENV PATH /opt/miniconda2/bin:$PATH + +RUN conda update conda + +# as of 20171226: matplotlib=2.1.1 numpy=1.13.3 pillow=4.3.0 +RUN conda install \ + matplotlib \ + numpy \ + pillow \ + mkl-service + +#NB: we require Lasagne>=0.2, still "bleeding edge" (0.2.dev1 as of 20171225) +RUN pip install \ + numpy \ + easydict \ + future \ + Theano==0.9 \ + --upgrade https://github.com/Lasagne/Lasagne/archive/master.zip + +# to avoid RuntimeError('To use MKL 2018 with Theano...') +ENV MKL_THREADING_LAYER GNU + +# nvidia-docker run --rm -it --name st_0 -v /home/ubuntu:/home/ubuntu sqzdet_theano_0.9 diff --git a/papers/squeezedet/setup/sqzdet_theano_dockerfile.txt b/papers/squeezedet/setup/sqzdet_theano_dockerfile.txt new file mode 100644 index 0000000..be77d90 --- /dev/null +++ b/papers/squeezedet/setup/sqzdet_theano_dockerfile.txt @@ -0,0 +1,36 @@ +#Ubuntu CUDA Python2 docker file for sqzdet_theano under Theano 1.0 +# nvidia-docker build -t sqzdet_theano . --file sqzdet_theano.txt + +FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 + +RUN apt-get update && apt-get install -y \ + apt-utils \ + git \ + wget + +RUN wget --no-check-certificate https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda2.sh + +RUN chmod +x miniconda2.sh && bash ./miniconda2.sh -b -p /opt/miniconda2 + +ENV PATH /opt/miniconda2/bin:$PATH + +RUN conda update conda + +# as of 20171226: matplotlib=2.1.1 numpy=1.13.3 pillow=4.3.0 +RUN conda install \ + matplotlib \ + numpy \ + pillow \ + mkl-service + +# as of 20171226, bleeding-edge installs below: theano=1.0.1+unknown, lasagne=0.2.dev1 +RUN pip install \ + easydict \ + future \ + --upgrade https://github.com/Theano/Theano/archive/master.zip \ + --upgrade https://github.com/Lasagne/Lasagne/archive/master.zip + +#to avoid RuntimeError('To use MKL 2018 with Theano...') +ENV MKL_THREADING_LAYER GNU + +# nvidia-docker run --rm -it --name st_1 -v /home/ubuntu:/home/ubuntu sqzdet_theano diff --git a/papers/squeezedet/setup/sqzdet_theano_py3k_dockerfile.txt b/papers/squeezedet/setup/sqzdet_theano_py3k_dockerfile.txt new file mode 100644 index 0000000..9309677 --- /dev/null +++ b/papers/squeezedet/setup/sqzdet_theano_py3k_dockerfile.txt @@ -0,0 +1,36 @@ +#Ubuntu CUDA Python3 docker file for sqzdet_theano under Theano 1.0 +# nvidia-docker build -t sqzdet_theano_py3k . --file sqzdet_theano_py3k.txt + +FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 + +RUN apt-get update && apt-get install -y \ + apt-utils \ + git \ + wget + +RUN wget --no-check-certificate https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda3.sh + +RUN chmod +x miniconda3.sh && bash ./miniconda3.sh -b -p /opt/miniconda3 + +ENV PATH /opt/miniconda3/bin:$PATH + +RUN conda update conda + +# as of 20171226: matplotlib=2.1.1 numpy=1.13.3 pillow=4.3.0 +RUN conda install \ + matplotlib \ + numpy \ + pillow \ + mkl-service + +# as of 20171226, bleeding-edge installs below: theano=1.0.1+unknown, lasagne=0.2.dev1 +RUN pip install \ + easydict \ + future \ + --upgrade https://github.com/Theano/Theano/archive/master.zip \ + --upgrade https://github.com/Lasagne/Lasagne/archive/master.zip + +#to avoid RuntimeError('To use MKL 2018 with Theano...') +ENV MKL_THREADING_LAYER GNU + +# nvidia-docker run --rm -it --name st_3k -v /home/ubuntu:/home/ubuntu sqzdet_theano_py3k diff --git a/papers/squeezedet/src/config.py b/papers/squeezedet/src/config.py new file mode 100644 index 0000000..b2f9569 --- /dev/null +++ b/papers/squeezedet/src/config.py @@ -0,0 +1,142 @@ +# Author: Bichen Wu (bichen@berkeley.edu) 08/25/2016 + +"""Base Model configurations""" + +import os +import os.path as osp +import numpy as np +from easydict import EasyDict as edict + +def base_model_config(dataset='PASCAL_VOC'): + assert dataset.upper()=='PASCAL_VOC' or dataset.upper()=='KITTI', \ + 'Currently only support PASCAL_VOC or KITTI dataset' + + cfg = edict() + + # Dataset used to train/val/test model. Now support PASCAL_VOC or KITTI + cfg.DATASET = dataset.upper() + + if cfg.DATASET == 'PASCAL_VOC': + # object categories to classify + cfg.CLASS_NAMES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', + 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', + 'sofa', 'train', 'tvmonitor') + elif cfg.DATASET == 'KITTI': + cfg.CLASS_NAMES = ('car', 'pedestrian', 'cyclist') + + # number of categories to classify + cfg.CLASSES = len(cfg.CLASS_NAMES) + + # ROI pooling output width + cfg.GRID_POOL_WIDTH = 7 + + # ROI pooling output height + cfg.GRID_POOL_HEIGHT = 7 + + # parameter used in leaky ReLU + cfg.LEAKY_COEF = 0.1 + + # Probability to keep a node in dropout + cfg.KEEP_PROB = 0.5 + + # image width + cfg.IMAGE_WIDTH = 224 + + # image height + cfg.IMAGE_HEIGHT = 224 + + # anchor box, array of [cx, cy, w, h]. To be defined later + cfg.ANCHOR_BOX = [] + + # number of anchor boxes + cfg.ANCHORS = len(cfg.ANCHOR_BOX) + + # number of anchor boxes per grid + cfg.ANCHOR_PER_GRID = -1 + + # batch size + cfg.BATCH_SIZE = 20 + + # Only keep boxes with probability higher than this threshold + cfg.PROB_THRESH = 0.005 + + # Only plot boxes with probability higher than this threshold + cfg.PLOT_PROB_THRESH = 0.5 + + # Bounding boxes with IOU larger than this are going to be removed + cfg.NMS_THRESH = 0.2 + + # Pixel mean values (BGR order) as a (1, 1, 3) array. Below is the BGR mean + # of VGG16 + cfg.BGR_MEANS = np.array([[[103.939, 116.779, 123.68]]]) + + # loss coefficient for confidence regression + cfg.LOSS_COEF_CONF = 1.0 + + # loss coefficient for classification regression + cfg.LOSS_COEF_CLASS = 1.0 + + # loss coefficient for bounding box regression + cfg.LOSS_COEF_BBOX = 10.0 + + # reduce step size after this many steps + cfg.DECAY_STEPS = 10000 + + # multiply the learning rate by this factor + cfg.LR_DECAY_FACTOR = 0.1 + + # learning rate + cfg.LEARNING_RATE = 0.005 + + # momentum + cfg.MOMENTUM = 0.9 + + # weight decay + cfg.WEIGHT_DECAY = 0.0005 + + # wether to load pre-trained model + cfg.LOAD_PRETRAINED_MODEL = True + + # path to load the pre-trained model + cfg.PRETRAINED_MODEL_PATH = '' + + # print log to console in debug mode + cfg.DEBUG_MODE = False + + # a small value used to prevent numerical instability + cfg.EPSILON = 1e-16 + + # threshold for safe exponential operation + cfg.EXP_THRESH=1.0 + + # gradients with norm larger than this is going to be clipped. + cfg.MAX_GRAD_NORM = 10.0 + + # Whether to do data augmentation + cfg.DATA_AUGMENTATION = False + + # The range to randomly shift the image widht + cfg.DRIFT_X = 0 + + # The range to randomly shift the image height + cfg.DRIFT_Y = 0 + + # Whether to exclude images harder than hard-category. Only useful for KITTI + # dataset. + cfg.EXCLUDE_HARD_EXAMPLES = True + + # small value used in batch normalization to prevent dividing by 0. The + # default value here is the same with caffe's default value. + cfg.BATCH_NORM_EPSILON = 1e-5 + + # number of threads to fetch data + cfg.NUM_THREAD = 4 + + # capacity for FIFOQueue + cfg.QUEUE_CAPACITY = 100 + + # indicate if the model is in training mode + cfg.IS_TRAINING = False + + return cfg diff --git a/papers/squeezedet/src/kitti_squeezeDet_config.py b/papers/squeezedet/src/kitti_squeezeDet_config.py new file mode 100644 index 0000000..ed82d8a --- /dev/null +++ b/papers/squeezedet/src/kitti_squeezeDet_config.py @@ -0,0 +1,79 @@ +# Author: Bichen Wu (bichen@berkeley.edu) 08/25/2016 + +"""Model configuration for pascal dataset""" + +import numpy as np + +from config import base_model_config + +def kitti_squeezeDet_config(): + """Specify the parameters to tune below.""" + mc = base_model_config('KITTI') + + mc.IMAGE_WIDTH = 1248 + mc.IMAGE_HEIGHT = 384 + mc.BATCH_SIZE = 20 + + mc.WEIGHT_DECAY = 0.0001 + mc.LEARNING_RATE = 0.01 + mc.DECAY_STEPS = 10000 + mc.MAX_GRAD_NORM = 1.0 + mc.MOMENTUM = 0.9 + mc.LR_DECAY_FACTOR = 0.5 + + mc.LOSS_COEF_BBOX = 5.0 + mc.LOSS_COEF_CONF_POS = 75.0 + mc.LOSS_COEF_CONF_NEG = 100.0 + mc.LOSS_COEF_CLASS = 1.0 + + mc.PLOT_PROB_THRESH = 0.4 + mc.NMS_THRESH = 0.4 + mc.PROB_THRESH = 0.005 + mc.TOP_N_DETECTION = 64 + + mc.DATA_AUGMENTATION = True + mc.DRIFT_X = 150 + mc.DRIFT_Y = 100 + mc.EXCLUDE_HARD_EXAMPLES = False + + mc.ANCHOR_BOX = set_anchors(mc) + mc.ANCHORS = len(mc.ANCHOR_BOX) + mc.ANCHOR_PER_GRID = 9 + + return mc + +def set_anchors(mc): + H, W, B = 24, 78, 9 + anchor_shapes = np.reshape( + [np.array( + [[ 36., 37.], [ 366., 174.], [ 115., 59.], + [ 162., 87.], [ 38., 90.], [ 258., 173.], + [ 224., 108.], [ 78., 170.], [ 72., 43.]])] * H * W, + (H, W, B, 2) + ) + center_x = np.reshape( + np.transpose( + np.reshape( + np.array([np.arange(1, W+1)*float(mc.IMAGE_WIDTH)/(W+1)]*H*B), + (B, H, W) + ), + (1, 2, 0) + ), + (H, W, B, 1) + ) + center_y = np.reshape( + np.transpose( + np.reshape( + np.array([np.arange(1, H+1)*float(mc.IMAGE_HEIGHT)/(H+1)]*W*B), + (B, W, H) + ), + (2, 1, 0) + ), + (H, W, B, 1) + ) + anchors = np.reshape( + np.concatenate((center_x, center_y, anchor_shapes), axis=3), + (-1, 4) + ) + + return anchors diff --git a/papers/squeezedet/src/sqz_det_thn.py b/papers/squeezedet/src/sqz_det_thn.py new file mode 100644 index 0000000..05bc073 --- /dev/null +++ b/papers/squeezedet/src/sqz_det_thn.py @@ -0,0 +1,311 @@ +""" +Copyright (c) 2017 Corvidim (corvidim.net) +Licensed under the BSD 2-Clause License (see LICENSE for details) +Authors: V. Ablavsky, A. J. Fox +""" + +from __future__ import (absolute_import, division, + print_function, unicode_literals) + +import argparse, os, sys, gzip +from pdb import set_trace as keyboard + +import utils + +import numpy as np +import PIL.Image + +python_ver = sys.version_info[0] +import pickle + +import lasagne +import theano, theano.tensor as T +import collections + +import kitti_squeezeDet_config + +import matplotlib +# Python 2 vs. 3 backends +if sys.platform=='darwin' and sys.version_info[0] == 3: + matplotlib.use('qt5agg') +# if headless, we need this *before* importing plt +if utils.is_headless(): + orig_backend = matplotlib.rcParams['backend'] + matplotlib.use('Agg') + print('Headless; resetting backend from {} to {}' + ' and assuming no_gui'.format(orig_backend, matplotlib.rcParams['backend'])) +import matplotlib.pyplot as plt + +###################################################################################### +# load_network_weights() +###################################################################################### +def load_network_weights(par): + print('********* load_network_weights()') + try: + path_ = par['network_weights'] + dir_, file_ = os.path.split(path_) + base_file, base_ext = os.path.splitext(file_) + assert base_ext == '.gz' and os.path.splitext(base_file)[1] == '.pkl' + + with gzip.open(path_, 'rb') as ar: + # This .pkl was produced with Py2K; if loading in Py3K we need to specify encoding + if python_ver < 3: + network_weights = pickle.load(ar) + else: + network_weights = pickle.load(ar, encoding='latin-1') + except: + raise ValueError('Expected --network_weights=/PATH/TO/FILE.pkl.gz') + return network_weights + +###################################################################################### +# viz_det_roi +###################################################################################### +def viz_det_roi(img, det_roi, det_label, par, out_file_name, plot_title=''): + print('********* viz_det_roi()') + """ + img: RGB order, normalized to [0,1] + """ + + if utils.is_headless(): + par['no_gui'] = True + + cls2clr = { + 'car': (0, 0.75, 1), + 'cyclist': (1, 0.75, 0), + 'ped':(1, 0, 0.75) + } + + if not par['no_gui']: + plt.ion() + + plt.figure() + plt.suptitle('viz_det_roi(): {} bounding box(es)'.format(len(det_roi))) + plt.title(plot_title) + + plt.imshow(img) + plt.axis('off') + ax=plt.gca() + + for bbox, label in zip(det_roi, det_label): + # bbox of form [cx, cy, w, h] + w = bbox[2] + h = bbox[3] + [xmin, ymin, xmax, ymax] = utils.bbox_transform(bbox) + class_str = label.split(':')[0] + color_val = cls2clr[class_str] + + print('adding box: {} [~{}x{}: xmin:{:.2f}, ymin:{:.2f}),' + ' (xmax:{:.2f}, ymax:{:.2f}] '.format(label, + int(ymax-ymin), + int(xmax-xmin), + xmin, + ymin, + xmax, + ymax)) + + rect = matplotlib.patches.Rectangle((xmin,ymin),w,h,edgecolor=color_val, + facecolor='none',linewidth=0.5) + ax.add_patch(rect) + ax.text(xmin,ymin+h,label,fontdict=dict(color=color_val,fontsize=6)) + + if not par['no_gui']: + plt.show() + + plt.savefig(out_file_name,bbox_inches='tight',pad_inches=0.0) + print ('Image detection output saved to {}'.format(out_file_name)) + +###################################################################################### +# load_img() +###################################################################################### +def load_img(img_path,mean_img_BGR,par): + print('********* load_img()') + raw_PIL = PIL.Image.open(img_path) + img_RGB = np.array(raw_PIL).astype('float32') # shape (375, 1242, 3), pixel values 0..255 + img_BGR = img_RGB[:,:,::-1] # RGB -> BGR order, to match sqzdet + img_ = img_BGR - mean_img_BGR # subtract mean image, to match sqzdet + + img_lasagne = np.expand_dims(np.rollaxis(img_,2),axis=0) # shape (1, 3, 375, 1242) + + if par['verbose']: + print('incoming image shape: {} ({})'.format(img_RGB.shape, img_path)) + return (img_RGB, img_lasagne) + + +###################################################################################### +# make_sqz_det_net() +###################################################################################### +def make_sqz_det_net(thn_x, par): + """ + Input: thn_x a 4-d tensor (batch_idx,channel_idx,row_idx,col_idx) + par a dictionary of network specification + + Output: Lasagne network as an ordered dictionary + """ + print('-'*5 + 'make_sqz_det_net()') + input_shape = (None,3,None,None) # assume 3 channels, but not input image dimensions + net = collections.OrderedDict() + + net['input'] = lasagne.layers.InputLayer(shape=input_shape, input_var=thn_x) + net['conv1'] = lasagne.layers.Conv2DLayer(incoming=net['input'], num_filters=64, filter_size=3, stride=2, flip_filters=False, pad='valid') + net['pool1'] = lasagne.layers.MaxPool2DLayer(incoming=net['conv1'], stride=2, pool_size=3) + net['conv2_fire1_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['pool1'], num_filters=16, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv3_fire1_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv2_fire1_sqz'], num_filters=64, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv4_fire1_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv2_fire1_sqz'], num_filters=64, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat1'] = lasagne.layers.ConcatLayer((net['conv3_fire1_exp'],net['conv4_fire1_exp']),axis=1) + net['conv5_fire2_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['concat1'], num_filters=16, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv6_fire2_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv5_fire2_sqz'], num_filters=64, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv7_fire2_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv5_fire2_sqz'], num_filters=64, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat2'] = lasagne.layers.ConcatLayer((net['conv6_fire2_exp'],net['conv7_fire2_exp']),axis=1) + net['pool2'] = lasagne.layers.MaxPool2DLayer(incoming=net['concat2'], stride=2, pool_size=3) + net['conv8_fire3_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['pool2'], num_filters=32, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv9_fire3_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv8_fire3_sqz'], num_filters=128, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv10_fire3_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv8_fire3_sqz'], num_filters=128, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat3'] = lasagne.layers.ConcatLayer((net['conv9_fire3_exp'],net['conv10_fire3_exp']),axis=1) + net['conv11_fire4_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['concat3'], num_filters=32, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv12_fire4_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv11_fire4_sqz'], num_filters=128, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv13_fire4_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv11_fire4_sqz'], num_filters=128, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat4'] = lasagne.layers.ConcatLayer((net['conv12_fire4_exp'],net['conv13_fire4_exp']),axis=1) + net['pool3'] = lasagne.layers.MaxPool2DLayer(incoming=net['concat4'], stride=2, pool_size=3) + net['conv14_fire5_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['pool3'], num_filters=48, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv15_fire5_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv14_fire5_sqz'], num_filters=192, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv16_fire5_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv14_fire5_sqz'], num_filters=192, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat5'] = lasagne.layers.ConcatLayer((net['conv15_fire5_exp'],net['conv16_fire5_exp']),axis=1) + net['conv17_fire6_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['concat5'], num_filters=48, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv18_fire6_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv17_fire6_sqz'], num_filters=192, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv19_fire6_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv17_fire6_sqz'], num_filters=192, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat6'] = lasagne.layers.ConcatLayer((net['conv18_fire6_exp'],net['conv19_fire6_exp']),axis=1) + net['conv20_fire7_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['concat6'], num_filters=64, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv21_fire7_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv20_fire7_sqz'], num_filters=256, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv22_fire7_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv20_fire7_sqz'], num_filters=256, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat7'] = lasagne.layers.ConcatLayer((net['conv21_fire7_exp'],net['conv22_fire7_exp']),axis=1) + net['conv23_fire8_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['concat7'], num_filters=64, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv24_fire8_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv23_fire8_sqz'], num_filters=256, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv25_fire8_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv23_fire8_sqz'], num_filters=256, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat8'] = lasagne.layers.ConcatLayer((net['conv24_fire8_exp'],net['conv25_fire8_exp']),axis=1) + net['conv26_fire9_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['concat8'], num_filters=96, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv27_fire9_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv26_fire9_sqz'], num_filters=384, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv28_fire9_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv26_fire9_sqz'], num_filters=384, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat9'] = lasagne.layers.ConcatLayer((net['conv27_fire9_exp'],net['conv28_fire9_exp']),axis=1) + net['conv29_fire10_sqz'] = lasagne.layers.Conv2DLayer(incoming=net['concat9'], num_filters=96, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv30_fire10_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv29_fire10_sqz'], num_filters=384, filter_size=1, stride=1, flip_filters=False, pad='same') + net['conv31_fire10_exp'] = lasagne.layers.Conv2DLayer(incoming=net['conv29_fire10_sqz'], num_filters=384, filter_size=3, stride=1, flip_filters=False, pad='same') + net['concat10'] = lasagne.layers.ConcatLayer((net['conv30_fire10_exp'],net['conv31_fire10_exp']),axis=1) + net['conv32'] = lasagne.layers.Conv2DLayer(incoming=net['concat10'], num_filters=72, filter_size=3, stride=1, flip_filters=False, pad='same', nonlinearity=None) + + return net + +###################################################################################### +# run_sqz_det_net() +###################################################################################### +def run_sqz_det_net(par): + """ + Create a Lasagne network corresponding to SqueezeDet; load its weights from a .pkl file; + run it end-to-end and visualize bounding boxes. + """ + print('-'*5 + 'run_sqz_det_net()') + thn_x = T.tensor4('thn_x') # (batch, single-channel,h,w) + + verbose = par['verbose'] + + cfg_mc = kitti_squeezeDet_config.kitti_squeezeDet_config() # SqueezeDet: "Model config for pascal dataset" + + # for convenience/legibility, shorten 'pedestrian' -> 'ped' + cfg_mc.CLASS_NAMES = ('car', 'ped', 'cyclist') + + net = make_sqz_det_net(thn_x, par) + + network_weights = load_network_weights(par) + + weights_list = [] + param_idx = 0 # at end: 64 + for net_idx,layer_name in enumerate(net): + if 'conv' in layer_name: + if verbose: + print('net_idx:{}\tW:{} \tb:{} \t{}'.format(net_idx,network_weights[param_idx].shape,network_weights[param_idx+1].shape,layer_name)) + weights_list.append([network_weights[param_idx],network_weights[param_idx+1]]) # W,b + param_idx += 2 + else: + if verbose: print(' {} [skipped] {}'.format(net_idx,layer_name)) + final_layer_name = layer_name + + weights_list = [num for elem in weights_list for num in elem] # flatten list of lists + + lasagne.layers.set_all_param_values(net[final_layer_name], weights_list) + + # load image(s); we assume they are all the same size (but do not yet bother checking that) + img_list = par['img_in'] + mean_img_BGR = cfg_mc['BGR_MEANS'].astype(np.float32) + for (img_idx, img_str) in enumerate(img_list): + img = img_str.strip() + (img_RGB, img_lasagne) = load_img(img,mean_img_BGR,par) # shape of $SQDT_ROOT/data/sample.png (1, 3, 375, 1242) + if img_idx == 0: + img_RGB_list = [img_RGB] + img_lasagne_stack = img_lasagne + else: + img_RGB_list = img_RGB_list + [img_RGB] + img_lasagne_stack = np.concatenate((img_lasagne_stack,img_lasagne)) + + # run image(s) through network + net_out = np.array(lasagne.layers.get_output(net[final_layer_name], + img_lasagne_stack, + deterministic=True).eval()) + + # visualize bounding boxes, relative to original image + for (img_idx, img_str) in enumerate(img_list): + img_shape = img_RGB_list[img_idx].shape + (det_roi, det_label) = utils.get_det_roi(cfg_mc, img_shape, np.expand_dims(net_out[img_idx],axis=0), par) + + file_name = os.path.split(img_list[img_idx])[1] + base, ext = os.path.splitext(file_name) + out_file_name = os.path.join(par['out_dir'], 'out_'.format(img_idx+1)+base+'.pdf') + + viz_det_roi(img_RGB_list[img_idx]/255., det_roi, det_label, par, out_file_name, plot_title='img {} of {} ({})'.format(img_idx+1, len(img_list), file_name)) + + if not par['no_gui']: + print('pausing for examination of visualization(s)...') + keyboard() + + + +###################################################################################### +# get_valid_modes() +###################################################################################### +def get_valid_modes(): + valid_modes = ['run_sqz_det_net'] + return valid_modes + + + +###################################################################################### +# main() +###################################################################################### +def main(par): + mode=par['mode'] + if par['out_dir'] and not os.path.exists(par['out_dir']): + raise IOError('Specified --out_dir does not exist') + if mode == 'run_sqz_det_net': + try: + run_sqz_det_net(par) + except ValueError as e: + print('ValueError: {}'.format(e)) + except IOError as e: + print('IOError: {}: {}'.format(e.strerror,e.filename)) + +###################################################################################### +# __main__ +###################################################################################### +if __name__ == '__main__': + + def csv_list(string): + return string.split(',') + + parser = argparse.ArgumentParser() + parser.add_argument('--mode', default='run_sqz_det_net', choices=get_valid_modes()) + parser.add_argument('--network_weights', required=True, help='e.g., data/squeezedet_kitti.pkl.gz') + parser.add_argument('--img_in', type=csv_list, default='data/sample.png', help='/PATH/TO/INPUT/IMG or comma-separated list of paths') + parser.add_argument('--out_dir', default='.', help='dir path for output images') + parser.add_argument('--no_gui', action='store_true') + parser.add_argument('--verbose', action='store_true') + par=vars(parser.parse_args(sys.argv[1:])) + + main(par) diff --git a/papers/squeezedet/src/utils.py b/papers/squeezedet/src/utils.py new file mode 100644 index 0000000..bc7a719 --- /dev/null +++ b/papers/squeezedet/src/utils.py @@ -0,0 +1,395 @@ +""" +Copyright (c) 2017 Corvidim (corvidim.net) +Licensed under the BSD 2-Clause License (see LICENSE for details) +Authors: V. Ablavsky, A. J. Fox +""" + +from __future__ import (absolute_import, division, + print_function, unicode_literals) + +from pdb import set_trace as keyboard + +import os, platform + +import numpy as np + +import theano, theano.tensor as T + +###################################################################################### +# is_headless() +###################################################################################### +def is_headless(): + sysname = platform.system() + if sysname == 'Linux': + if os.getenv('DISPLAY') is None: + return True + else: + return False + +###################################################################################### +# av_elu +###################################################################################### +def av_elu(thn_x,t): + """ + analog of SqueezDet src/utils/util.py:safe_exp() + and inspired by lasagne.layers.elu() + https://github.com/Lasagne/Lasagne/commit/c4e3f81d6b1e6f7518b3efa4681e548f87b2fd72 + which, in turn, is a one-liner + theano.tensor.switch(x > 0, x, theano.tensor.exp(x) - 1) + + x (theano tensor): tensor to be transformed + t (floatX_t): threshold + """ + return T.switch(thn_x > t, T.exp(t)*(thn_x - t + 1),T.exp(thn_x)) + +###################################################################################### +# bbox_transform() +###################################################################################### +def bbox_transform(bbox): + """ + [analog of SqueezeDet src/utils/util.py:bbox_transform()] + convert a bbox of form [cx, cy, w, h] to [xmin, ymin, xmax, ymax]. + Works for numpy array or list of tensors. + """ + cx, cy, w, h = bbox + out_box = [[]]*4 + out_box[0] = cx-w/2 + out_box[1] = cy-h/2 + out_box[2] = cx+w/2 + out_box[3] = cy+h/2 + + return out_box + +###################################################################################### +# bbox_transform_inv() +###################################################################################### +def bbox_transform_inv(bbox): + """ + [analog of SqueezeDet src/utils/util.py:bbox_transform()] + convert a bbox of form [xmin, ymin, xmax, ymax] to [cx, cy, w, h]. + Works for numpy array or list of tensors. + """ + xmin, ymin, xmax, ymax = bbox + out_box = [[]]*4 + + #NB SqueezeDet assumed OpenCV ==> origin pixel (0,0) at center of upper-left pixel; PIL uses upper-left + width = xmax - xmin + 1.0 + height = ymax - ymin + 1.0 + out_box[0] = xmin + 0.5*width + out_box[1] = ymin + 0.5*height + out_box[2] = width + out_box[3] = height + + return out_box + +###################################################################################### +# filter_prediction() +###################################################################################### +def filter_prediction(mc, boxes, probs, cls_idx): + """ + [from SqueezeDet src/nn_skeleton.py] + + Filter bounding box predictions with probability threshold and + non-maximum supression. + + Args: + boxes: array of [cx, cy, w, h]. + probs: array of probabilities + cls_idx: array of class indices + Returns: + final_boxes: array of filtered bounding boxes. + final_probs: array of filtered probabilities + final_cls_idx: array of filtered class indices + """ + if mc.TOP_N_DETECTION < len(probs) and mc.TOP_N_DETECTION > 0: + order = probs.argsort()[:-mc.TOP_N_DETECTION-1:-1] + probs = probs[order] + boxes = boxes[order] + cls_idx = cls_idx[order] + else: + filtered_idx = np.nonzero(probs>mc.PROB_THRESH)[0] + probs = probs[filtered_idx] + boxes = boxes[filtered_idx] + cls_idx = cls_idx[filtered_idx] + + final_boxes = [] + final_probs = [] + final_cls_idx = [] + + for c in range(mc.CLASSES): + idx_per_class = [i for i in range(len(probs)) if cls_idx[i] == c] + keep = sqzdet_nms(boxes[idx_per_class], probs[idx_per_class], mc.NMS_THRESH) + for i in range(len(keep)): + if keep[i]: + final_boxes.append(boxes[idx_per_class[i]]) + final_probs.append(probs[idx_per_class[i]]) + final_cls_idx.append(c) + return final_boxes, final_probs, final_cls_idx + +###################################################################################### +# sqzdet_nms() +###################################################################################### +def sqzdet_nms(boxes, probs, threshold): + """ + [from SqueezeDet src/utils/util.py] + + Non-Maximum supression. + Args: + boxes: array of [cx, cy, w, h] (center format) + probs: array of probabilities + threshold: two boxes are considered overlapping if their IOU is largher than + this threshold + form: 'center' or 'diagonal' + Returns: + keep: array of True or False. + """ + + order = probs.argsort()[::-1] + keep = [True]*len(order) + + for i in range(len(order)-1): + ovps = batch_iou(boxes[order[i+1:]], boxes[order[i]]) + for j, ov in enumerate(ovps): + if ov > threshold: + keep[order[j+i+1]] = False + return keep + +###################################################################################### +# batch_iou() +###################################################################################### +def batch_iou(boxes, box): + """ + [from SqueezeDet src/utils/util.py] + + Compute the Intersection-Over-Union of a batch of boxes with another + box. + + Args: + box1: 2D array of [cx, cy, width, height]. + box2: a single array of [cx, cy, width, height] + Returns: + ious: array of a float number in range [0, 1]. + """ + lr = np.maximum( + np.minimum(boxes[:,0]+0.5*boxes[:,2], box[0]+0.5*box[2]) - \ + np.maximum(boxes[:,0]-0.5*boxes[:,2], box[0]-0.5*box[2]), + 0 + ) + tb = np.maximum( + np.minimum(boxes[:,1]+0.5*boxes[:,3], box[1]+0.5*box[3]) - \ + np.maximum(boxes[:,1]-0.5*boxes[:,3], box[1]-0.5*box[3]), + 0 + ) + inter = lr*tb + union = boxes[:,2]*boxes[:,3] + box[2]*box[3] - inter + return inter/union + +###################################################################################### +# set_anchors() +###################################################################################### +def set_anchors(img_h,img_w,H=22,W=76): + """ + generalized from kitti_squeezeDet_config, to support arbitrary-size input image + """ + B = 9 + anchor_shapes = np.reshape( + [np.array( + [[ 36., 37.], [ 366., 174.], [ 115., 59.], + [ 162., 87.], [ 38., 90.], [ 258., 173.], + [ 224., 108.], [ 78., 170.], [ 72., 43.]])] * H * W, + (H, W, B, 2) + ) + center_x = np.reshape( + np.transpose( + np.reshape( + np.array([np.arange(1, W+1)*float(img_w)/(W+1)]*H*B), + (B, H, W) + ), + (1, 2, 0) + ), + (H, W, B, 1) + ) + center_y = np.reshape( + np.transpose( + np.reshape( + np.array([np.arange(1, H+1)*float(img_h)/(H+1)]*W*B), + (B, W, H) + ), + (2, 1, 0) + ), + (H, W, B, 1) + ) + anchors = np.reshape( + np.concatenate((center_x, center_y, anchor_shapes), axis=3), + (-1, 4) + ) + + return anchors + + +###################################################################################### +# get_det_roi() +###################################################################################### +def get_det_roi(cfg_mc, img_shape, net_out, par): + """ + net_out: a numpy tensor from lasagne-net final layer, + e.g., net_out.shape == (1, 72, 22, 76) + + The second axis, i.e., net_out[0,:,r,c] + contains (a) class-conditional and marginal probabilities, + (b) confidence scores, and (c) anchor-deformations for all the + anchors at anchor-site (i,j). + + The code in this function parses the second axis in order (a), (b), and (c) + to extract these probabilities and confidence scores, and compute + the ROIs given anchor-deformations. It then runs non-max-suppression, + and then returns the most confidence ROIs that are left. + + + NB: much of the code below follows SqueezeDet src/nn_skeleton.py + and retains tf (tensorflow) order, as opposed to th (Theano/Lasagne): + + tf: (rows, columns, channels, filters) aka (h, w, channels, filters) + thn: (filters, channels, rows, columns) aka (filters, channels, h, w) + """ + + n_class_probs = cfg_mc.ANCHOR_PER_GRID * cfg_mc.CLASSES # 27 + + # for each anchor, the network predicts "confidence score" (a scalar) + # n_conf_scores = n_anchor_per_grid * (1+n_class) + # = n_anchor_per_grid + n_class_prob + n_conf_scores = cfg_mc.ANCHOR_PER_GRID + + # analog of pred_class_probs from nn_skeleton.py:147 + """ + self.pred_class_probs = tf.reshape( + tf.nn.softmax( + tf.reshape( + preds[:, :, :, :num_class_probs], + [-1, mc.CLASSES] + ) + ), + [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES], + name='pred_class_probs' + ) + """ + + idx_begin_probs = 0 + idx_end_probs = n_class_probs + idx_begin_conf = idx_end_probs + idx_end_conf = idx_begin_conf + n_conf_scores + idx_begin_pred_box_delta = idx_end_conf + # end index will equal net_out.shape[1] + idx_end_pred_box_delta = idx_begin_pred_box_delta + 4*cfg_mc.ANCHOR_PER_GRID + + net_out_tf_order = np.transpose(net_out,(2,3,0,1)) # (22, 76, 1, 72 + class_probs0 = net_out_tf_order[:, :, :, idx_begin_probs:idx_end_probs] # (22, 76, 1, 27) + class_probs1 = np.reshape(class_probs0, (-1, cfg_mc.CLASSES)) # (15048, 3) + + thn_x = T.dmatrix('thn_x') + thn_y1 = T.nnet.softmax(thn_x) + f_softmax = theano.function([thn_x],thn_y1) + class_probs2 = f_softmax(class_probs1) # (15048, 3) + n_anchors = net_out.shape[2] * net_out.shape[3] * n_conf_scores # net_out.shape (1, 72, 22, 76) + class_probs3 = np.reshape(class_probs2, + (1, n_anchors, cfg_mc.CLASSES)) + + #analog of 158: #confidence + """ + # confidence + num_confidence_scores = mc.ANCHOR_PER_GRID+num_class_probs + self.pred_conf = tf.sigmoid( + tf.reshape( + preds[:, :, :, num_class_probs:num_confidence_scores], + [mc.BATCH_SIZE, mc.ANCHORS] + ), + name='pred_confidence_score' + ) + """ + + conf_scores0 = net_out_tf_order[:, :, :, idx_begin_conf:idx_end_conf] # (22, 76, 1, 9) + conf_scores1 = np.reshape(conf_scores0, (1, n_anchors)) # (1, 15048) + thn_y2 = T.nnet.sigmoid(thn_x) + f_sigmoid = theano.function([thn_x],thn_y2) + conf_scores2 = f_sigmoid(conf_scores1) # (1, 15048) + + # analog of 267: with tf.variable_scope('probability') as scope: + """ + probs = tf.multiply( + self.pred_class_probs, + tf.reshape(self.pred_conf, [mc.BATCH_SIZE, mc.ANCHORS, 1]), + name='final_class_prob' + ) + + 278: self.det_probs = tf.reduce_max(probs, 2, name='score') + + 279: self.det_class = tf.argmax(probs, 2, name='class_idx') + """ + + + conf_scores3 = np.reshape(conf_scores2, (1, n_anchors, 1)) # (1, 15048, 1) + class_probs4 = np.multiply(class_probs3, conf_scores3) # (1, 15048, 3) + class_probs_max = np.max(class_probs4, axis=2).squeeze() # (15048,) + class_decisions = np.argmax(class_probs4, axis=2).squeeze() # (1, 15048) + + #analog of pred_box_delta from nn_skeleton.py:169 ("bbox_delta") + net_out_slice_box = net_out_tf_order[:, :, :, idx_end_conf:] # (22, 76, 1, 36) + pred_box_delta = np.reshape(net_out_slice_box, (1, n_anchors, 4)) # (1, 15048, 4) + + # analog of box_center_* from nn_skeleton.py:180ff + # conceptually, unpack pred_box_delta of shape (1, 15048, 4) + # into 4-tuple, each of shape (1,15048) delta_x, delta_y, delta_w, delta_h + + img_h = img_shape[0] + img_w = img_shape[1] + anchor_box = set_anchors(img_h,img_w,H=net_out.shape[2],W=net_out.shape[3]) + + anchor_x = anchor_box[:, 0] # (15048,) + anchor_y = anchor_box[:, 1] + anchor_w = anchor_box[:, 2] + anchor_h = anchor_box[:, 3] + + # analog of 179: with tf.variable_scope('stretching'): + delta_x = pred_box_delta.squeeze()[:,0] + delta_y = pred_box_delta.squeeze()[:,1] + delta_w = pred_box_delta.squeeze()[:,2] + delta_h = pred_box_delta.squeeze()[:,3] + box_center_x = anchor_x + delta_x * anchor_w + box_center_y = anchor_y + delta_y * anchor_h + box_width = anchor_w * av_elu(delta_w, cfg_mc.EXP_THRESH).eval() + box_height = anchor_h * av_elu(delta_h, cfg_mc.EXP_THRESH).eval() + + # analog of 209: with tf.variable_scope('trimming'): + xmins, ymins, xmaxs, ymaxs = bbox_transform([box_center_x, box_center_y, box_width, box_height]) + + xmins = np.minimum(np.maximum(0.0, xmins), img_w-1.0) + ymins = np.minimum(np.maximum(0.0, ymins), img_h-1.0) + xmax = np.maximum(np.minimum(img_w-1.0, xmaxs), 0.0) + ymax = np.maximum(np.minimum(img_w-1.0, ymaxs), 0.0) + + box_centers_ = bbox_transform_inv([xmins, ymins, xmaxs, ymaxs]) # tf: each is TensorShape([Dimension(1), Dimension(15048)]) + box_centers_stacked = np.stack(box_centers_) # (4, 15048) tf: TensorShape([Dimension(4), Dimension(1), Dimension(15048)]) + det_roi_ = np.transpose(box_centers_stacked) # (15048, 4) tf: transpose(...,(1, 2, 0)) => tf: (1, 15048, 4) + + + t_ = filter_prediction(cfg_mc, det_roi_, class_probs_max, class_decisions) + (filtered_roi, filtered_probs, filtered_class) = t_ + + keep_idx = [idx for idx in range(len(filtered_probs)) \ + if filtered_probs[idx] > cfg_mc.PLOT_PROB_THRESH] # list len 16 + filtered_roi = [filtered_roi[idx] for idx in keep_idx] # list len 16 + filtered_probs = [filtered_probs[idx] for idx in keep_idx] + filtered_class = [filtered_class[idx] for idx in keep_idx] + + detection_info = [cfg_mc.CLASS_NAMES[idx]+': {:.2f}'.format(prob) \ + for idx, prob in zip(filtered_class, filtered_probs)] + + if par['verbose']: + print('len(filtered_roi): {0} len(filtered_probs): ' + '{1} cfg_mc.PLOT_PROB_THRESH: {2} cfg_mc.NMS_THRESH: {3}'.format( + len(filtered_roi), + len(filtered_probs), + cfg_mc.PLOT_PROB_THRESH, + cfg_mc.NMS_THRESH)) + + return (filtered_roi, detection_info)