From 67915e659d2097a96c82ba7740b9e43a8c69858d Mon Sep 17 00:00:00 2001 From: Taehoon Lee Date: Mon, 26 Mar 2018 21:16:33 +0900 Subject: [PATCH] Add generic object detection models --- tensornets/__init__.py | 9 +++ tensornets/darknets.py | 95 ++++++++++++++++++++++ tensornets/detections.py | 167 +++++++++++++++++++++++++++++++++++++++ tensornets/preprocess.py | 3 + tensornets/pretrained.py | 22 +++++- tensornets/utils.py | 16 ++++ tensornets/zf.py | 65 +++++++++++++++ 7 files changed, 375 insertions(+), 2 deletions(-) create mode 100644 tensornets/darknets.py create mode 100644 tensornets/detections.py create mode 100644 tensornets/zf.py diff --git a/tensornets/__init__.py b/tensornets/__init__.py index 74ff19b..8aae690 100644 --- a/tensornets/__init__.py +++ b/tensornets/__init__.py @@ -48,6 +48,15 @@ from .references import FasterRCNN_ZF_VOC from .references import FasterRCNN_VGG16_VOC +from .darknets import Darknet19 +from .darknets import TinyDarknet19 + +from .zf import ZF + +from .detections import YOLOv2 +from .detections import TinyYOLOv2 +from .detections import FasterRCNN + from .preprocess import preprocess from .pretrained import assign as pretrained diff --git a/tensornets/darknets.py b/tensornets/darknets.py new file mode 100644 index 0000000..90f990b --- /dev/null +++ b/tensornets/darknets.py @@ -0,0 +1,95 @@ +"""Darknet19 embedded in YOLO + +The reference paper: + + - YOLO9000: Better, Faster, Stronger, CVPR 2017 (Best Paper Honorable Mention) + - Joseph Redmon, Ali Farhadi + - https://arxiv.org/abs/1612.08242 + +The reference implementation: + +1. Darknet + - https://pjreddie.com/darknet/yolo/ +""" +from __future__ import absolute_import +from __future__ import division + +import tensorflow as tf + +from .layers import batch_norm +from .layers import bias_add +from .layers import conv2d +from .layers import darkconv as conv +from .layers import max_pool2d as pool + +from .ops import * +from .utils import set_args +from .utils import var_scope + + +def __args__(is_training): + return [([batch_norm], {'is_training': is_training}), + ([bias_add, conv2d], {}), + ([pool], {'padding': 'SAME'})] + + +@var_scope('stack') +def _stack(x, filters, blocks, scope=None): + for i in range(1, blocks+1): + if i % 2 > 0: + x = conv(x, filters, 3, scope=str(i)) + else: + x = conv(x, filters // 2, 1, scope=str(i)) + return x + + +@var_scope('darknet19') +@set_args(__args__) +def darknet19(x, is_training=False, classes=1000, + stem=False, scope=None, reuse=None): + x = _stack(x, 32, 1, scope='conv1') + x = pool(x, 2, stride=2, scope='pool1') + x = _stack(x, 64, 1, scope='conv2') + x = pool(x, 2, stride=2, scope='pool2') + x = _stack(x, 128, 3, scope='conv3') + x = pool(x, 2, stride=2, scope='pool3') + x = _stack(x, 256, 3, scope='conv4') + x = pool(x, 2, stride=2, scope='pool4') + x = p = _stack(x, 512, 5, scope='conv5') + x = pool(x, 2, stride=2, scope='pool5') + x = _stack(x, 1024, 5, scope='conv6') + x.p = p + if stem: return x + + x = reduce_mean(x, [1, 2], name='avgpool') + x = fc(x, classes, scope='logits') + x = softmax(x, name='probs') + return x + + +@var_scope('tinydarknet19') +@set_args(__args__) +def tinydarknet19(x, is_training=False, classes=1000, + stem=False, scope=None, reuse=None): + x = conv(x, 16, 3, scope='conv1') + x = pool(x, 2, stride=2, scope='pool1') + x = conv(x, 32, 3, scope='conv2') + x = pool(x, 2, stride=2, scope='pool2') + x = conv(x, 64, 3, scope='conv3') + x = pool(x, 2, stride=2, scope='pool3') + x = conv(x, 128, 3, scope='conv4') + x = pool(x, 2, stride=2, scope='pool4') + x = conv(x, 256, 3, scope='conv5') + x = pool(x, 2, stride=2, scope='pool5') + x = conv(x, 512, 3, scope='conv6') + if stem: return x + + x = reduce_mean(x, [1, 2], name='avgpool') + x = fc(x, classes, scope='logits') + x = softmax(x, name='probs') + return x + + +# Simple alias. +Darknet19 = darknet19 +TinyDarknet19 = tinydarknet19 diff --git a/tensornets/detections.py b/tensornets/detections.py new file mode 100644 index 0000000..f7618ec --- /dev/null +++ b/tensornets/detections.py @@ -0,0 +1,167 @@ +"""Collection of generic object detection models + +The reference papers: + +1. YOLOv2 + - YOLO9000: Better, Faster, Stronger, CVPR 2017 (Best Paper Honorable Mention) + - Joseph Redmon, Ali Farhadi + - https://arxiv.org/abs/1612.08242 +2. Faster R-CNN + - Faster R-CNN: Towards Real-Time Object Detection + with Region Proposal Networks, NIPS 2015 + - Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun + - https://arxiv.org/abs/1506.01497 + +The reference implementations: + +1. Darknet + - https://pjreddie.com/darknet/yolo/ +2. darkflow + - https://github.com/thtrieu/darkflow +3. Caffe and Python utils + - https://github.com/rbgirshick/py-faster-rcnn +4. RoI pooling in TensorFlow + - https://github.com/deepsense-ai/roi-pooling +""" +from __future__ import absolute_import +from __future__ import division + +import tensorflow as tf + +from .layers import batch_norm +from .layers import bias_add +from .layers import conv2d +from .layers import darkconv +from .layers import dropout +from .layers import flatten +from .layers import fc +from .layers import max_pool2d + +from .ops import * +from .utils import remove_head +from .utils import set_args +from .utils import var_scope + +from .references.yolos import get_boxes as yolo_boxes +from .references.yolos import local_flatten +from .references.yolos import opts +from .references.rcnns import get_boxes as rcnn_boxes +from .references.rcnns import roi_pool2d +from .references.rcnns import rp_net + + +def __args_yolo__(is_training): + return [([batch_norm], {'is_training': is_training}), + ([bias_add, conv2d], {}), + ([max_pool2d], {'padding': 'SAME'})] + + +def __args_rcnn__(is_training): + return [([conv2d], {'activation_fn': None, 'scope': 'conv'}), + ([dropout], {'is_training': is_training}), + ([fc], {'activation_fn': None, 'scope': 'fc'})] + + +@var_scope('genYOLOv2') +@set_args(__args_yolo__) +def yolov2(x, stem_fn, stem_out=None, is_training=False, classes=21, + scope=None, reuse=None): + def get_boxes(*args, **kwargs): + return yolo_boxes(opts('yolov2' + data_name(classes)), *args, **kwargs) + + x = stem_fn(x, is_training, stem=True, scope='stem') + p = x.p + stem_name = x.model_name + + if stem_out is not None: + x = remove_head(stem_out) + + x = darkconv(x, 1024, 3, scope='conv7') + x = darkconv(x, 1024, 3, scope='conv8') + + p = darkconv(p, 64, 1, scope='conv5a') + p = local_flatten(p, scope='flat5a') + + x = concat([p, x], axis=3, name='concat') + x = darkconv(x, 1024, 3, scope='conv9') + x = darkconv(x, 125 if classes == 21 else 425, 1, + onlyconv=True, scope='linear') + x.aliases = [] + x.get_boxes = get_boxes + x.stem_name = stem_name + return x + + +def data_name(classes): + return 'voc' if classes == 21 else '' + + +@var_scope('genTinyYOLOv2') +@set_args(__args_yolo__) +def tinyyolov2(x, stem_fn, stem_out=None, is_training=False, classes=21, + scope=None, reuse=None): + def get_boxes(*args, **kwargs): + return yolo_boxes(opts('tinyyolov2' + data_name(classes)), + *args, **kwargs) + + x = stem_fn(x, is_training, stem=True, scope='stem') + stem_name = x.model_name + + if stem_out is not None: + x = remove_head(stem_out) + + x = max_pool2d(x, 2, stride=1, scope='pool6') + x = darkconv(x, 1024, 3, scope='conv7') + x = darkconv(x, 1024 if classes == 21 else 512, 3, scope='conv8') + x = darkconv(x, 125 if classes == 21 else 425, 1, + onlyconv=True, scope='linear') + x.aliases = [] + x.get_boxes = get_boxes + x.stem_name = stem_name + return x + + +@var_scope('genFasterRCNN') +@set_args(__args_rcnn__) +def fasterrcnn(x, stem_fn, stem_out=None, is_training=False, classes=21, + scope=None, reuse=None): + def roi_pool_fn(x, filters, kernel_size): + rois = rp_net(x, filters, height, width, scales) + x = roi_pool2d(x, kernel_size, rois) + return x, rois[0] / scales + + scales = tf.placeholder(tf.float32, [None]) + height = tf.cast(tf.shape(x)[1], dtype=tf.float32) + width = tf.cast(tf.shape(x)[2], dtype=tf.float32) + + x = stem_fn(x, is_training, stem=True, scope='stem') + stem_name = x.model_name + + if stem_out is not None: + x = remove_head(stem_out) + + if 'zf' in stem_name: + x, rois = roi_pool_fn(x, 256, 6) + else: + x, rois = roi_pool_fn(x, 512, 7) + + x = flatten(x) + x = fc(x, 4096, scope='fc6') + x = relu(x, name='relu6') + x = dropout(x, keep_prob=0.5, scope='drop6') + x = fc(x, 4096, scope='fc7') + x = relu(x, name='relu7') + x = dropout(x, keep_prob=0.5, scope='drop7') + x = concat([softmax(fc(x, classes, scope='logits'), name='probs'), + fc(x, 4 * classes, scope='boxes'), + rois], axis=1, name='out') + x.get_boxes = rcnn_boxes + x.stem_name = stem_name + x.scales = scales + return x + + +# Simple alias. +YOLOv2 = yolov2 +TinyYOLOv2 = tinyyolov2 +FasterRCNN = fasterrcnn diff --git a/tensornets/preprocess.py b/tensornets/preprocess.py index 14e1489..90eca54 100644 --- a/tensornets/preprocess.py +++ b/tensornets/preprocess.py @@ -171,4 +171,7 @@ def faster_rcnn_preprocess(x): 'REFtinyyolov2voc': darknet_preprocess, 'REFfasterrcnnZFvoc': faster_rcnn_preprocess, 'REFfasterrcnnVGG16voc': faster_rcnn_preprocess, + 'genYOLOv2': darknet_preprocess, + 'genTinyYOLOv2': darknet_preprocess, + 'genFasterRCNN': faster_rcnn_preprocess, } diff --git a/tensornets/pretrained.py b/tensornets/pretrained.py index 09a1b33..e2d63e7 100644 --- a/tensornets/pretrained.py +++ b/tensornets/pretrained.py @@ -74,9 +74,24 @@ def assign(scopes): def direct(model_name, scope): + if model_name.startswith('gen'): + fun = load_nothing + if 'FasterRCNN' in model_name: + if 'vgg16' in scope.stem_name: + fun = load_ref_faster_rcnn_vgg16_voc + elif 'zf' in scope.stem_name: + fun = load_ref_faster_rcnn_zf_voc + elif 'TinyYOLOv2' in model_name: + if 'tinydarknet19' in scope.stem_name: + fun = load_ref_tiny_yolo_v2_voc + elif 'YOLOv2' in model_name: + if 'darknet19' in scope.stem_name: + fun = load_ref_yolo_v2_voc + else: + fun = __load_dict__[model_name] + def _direct(): - return __load_dict__[model_name](scope, - return_fn=pretrained_initializer) + return fun(scope, return_fn=pretrained_initializer) return _direct @@ -652,6 +667,9 @@ def load_ref_faster_rcnn_vgg16_voc(scopes, return_fn=_assign): 'mobilenet75': load_mobilenet75, 'mobilenet100': load_mobilenet100, 'squeezenet': load_squeezenet, + 'zf': load_nothing, + 'darknet19': load_nothing, + 'tinydarknet19': load_nothing, 'REFyolov2': load_ref_yolo_v2, 'REFyolov2voc': load_ref_yolo_v2_voc, 'REFtinyyolov2voc': load_ref_tiny_yolo_v2_voc, diff --git a/tensornets/utils.py b/tensornets/utils.py index 727e8aa..f3d8582 100644 --- a/tensornets/utils.py +++ b/tensornets/utils.py @@ -382,6 +382,22 @@ def parse_torch_weights(weights_path, move_rules=None): return values +def remove_head(name): + _scope = "%s/stem" % tf.get_variable_scope().name + g = tf.get_default_graph() + for x in g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, + scope=_scope)[::-1]: + if name in x.name: + break + g.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES).pop() + + for x in g.get_collection(__outputs__, scope=_scope)[::-1]: + if name in x.name: + break + g.get_collection_ref(__outputs__).pop() + return x + + def remove_utils(module_name, exceptions): import sys from . import utils diff --git a/tensornets/zf.py b/tensornets/zf.py new file mode 100644 index 0000000..fb38a7d --- /dev/null +++ b/tensornets/zf.py @@ -0,0 +1,65 @@ +"""ZF net embedded in Faster RCNN + +The reference paper: + + - Faster R-CNN: Towards Real-Time Object Detection + with Region Proposal Networks, NIPS 2015 + - Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun + - https://arxiv.org/abs/1506.01497 + +The reference implementation: + +1. Caffe and Python utils + - https://github.com/rbgirshick/py-faster-rcnn +""" +from __future__ import absolute_import +from __future__ import division + +import tensorflow as tf + +from .layers import conv2d +from .layers import fc +from .layers import max_pool2d +from .layers import convrelu as conv + +from .ops import * +from .utils import pad_info +from .utils import set_args +from .utils import var_scope + + +def __args__(is_training): + return [([conv2d], {'padding': 'SAME', 'activation_fn': None, + 'scope': 'conv'}), + ([fc], {'activation_fn': None, 'scope': 'fc'}), + ([max_pool2d], {'scope': 'pool'})] + + +@var_scope('zf') +@set_args(__args__) +def zf(x, is_training=False, classes=1000, stem=False, scope=None, reuse=None): + x = pad(x, pad_info(7), name='pad1') + x = conv(x, 96, 7, stride=2, padding='VALID', scope='conv1') + x = srn(x, depth_radius=3, alpha=0.00005, beta=0.75, name='srn1') + x = pad(x, pad_info(3, symmetry=False), name='pad2') + x = max_pool2d(x, 3, stride=2, padding='VALID', scope='pool1') + + x = pad(x, pad_info(5), name='pad3') + x = conv(x, 256, 5, stride=2, padding='VALID', scope='conv2') + x = srn(x, depth_radius=3, alpha=0.00005, beta=0.75, name='srn2') + x = pad(x, pad_info(3, symmetry=False), name='pad4') + x = max_pool2d(x, 3, stride=2, padding='VALID', scope='pool2') + + x = conv(x, 384, 3, scope='conv3') + x = conv(x, 384, 3, scope='conv4') + x = conv(x, 256, 3, scope='conv5') + if stem: return x + + x = reduce_mean(x, [1, 2], name='avgpool') + x = fc(x, classes, scope='logits') + x = softmax(x, name='probs') + return x + + +# Simple alias. +ZF = zf