Skip to content

Commit

Permalink
Add generic object detection models
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee committed Mar 26, 2018
1 parent 0f6f3ec commit 67915e6
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 2 deletions.
9 changes: 9 additions & 0 deletions tensornets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@
from .references import FasterRCNN_ZF_VOC
from .references import FasterRCNN_VGG16_VOC

from .darknets import Darknet19
from .darknets import TinyDarknet19

from .zf import ZF

from .detections import YOLOv2
from .detections import TinyYOLOv2
from .detections import FasterRCNN

from .preprocess import preprocess
from .pretrained import assign as pretrained

Expand Down
95 changes: 95 additions & 0 deletions tensornets/darknets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Darknet19 embedded in YOLO
The reference paper:
- YOLO9000: Better, Faster, Stronger, CVPR 2017 (Best Paper Honorable Mention)
- Joseph Redmon, Ali Farhadi
- https://arxiv.org/abs/1612.08242
The reference implementation:
1. Darknet
- https://pjreddie.com/darknet/yolo/
"""
from __future__ import absolute_import
from __future__ import division

import tensorflow as tf

from .layers import batch_norm
from .layers import bias_add
from .layers import conv2d
from .layers import darkconv as conv
from .layers import max_pool2d as pool

from .ops import *
from .utils import set_args
from .utils import var_scope


def __args__(is_training):
return [([batch_norm], {'is_training': is_training}),
([bias_add, conv2d], {}),
([pool], {'padding': 'SAME'})]


@var_scope('stack')
def _stack(x, filters, blocks, scope=None):
for i in range(1, blocks+1):
if i % 2 > 0:
x = conv(x, filters, 3, scope=str(i))
else:
x = conv(x, filters // 2, 1, scope=str(i))
return x


@var_scope('darknet19')
@set_args(__args__)
def darknet19(x, is_training=False, classes=1000,
stem=False, scope=None, reuse=None):
x = _stack(x, 32, 1, scope='conv1')
x = pool(x, 2, stride=2, scope='pool1')
x = _stack(x, 64, 1, scope='conv2')
x = pool(x, 2, stride=2, scope='pool2')
x = _stack(x, 128, 3, scope='conv3')
x = pool(x, 2, stride=2, scope='pool3')
x = _stack(x, 256, 3, scope='conv4')
x = pool(x, 2, stride=2, scope='pool4')
x = p = _stack(x, 512, 5, scope='conv5')
x = pool(x, 2, stride=2, scope='pool5')
x = _stack(x, 1024, 5, scope='conv6')
x.p = p
if stem: return x

x = reduce_mean(x, [1, 2], name='avgpool')
x = fc(x, classes, scope='logits')
x = softmax(x, name='probs')
return x


@var_scope('tinydarknet19')
@set_args(__args__)
def tinydarknet19(x, is_training=False, classes=1000,
stem=False, scope=None, reuse=None):
x = conv(x, 16, 3, scope='conv1')
x = pool(x, 2, stride=2, scope='pool1')
x = conv(x, 32, 3, scope='conv2')
x = pool(x, 2, stride=2, scope='pool2')
x = conv(x, 64, 3, scope='conv3')
x = pool(x, 2, stride=2, scope='pool3')
x = conv(x, 128, 3, scope='conv4')
x = pool(x, 2, stride=2, scope='pool4')
x = conv(x, 256, 3, scope='conv5')
x = pool(x, 2, stride=2, scope='pool5')
x = conv(x, 512, 3, scope='conv6')
if stem: return x

x = reduce_mean(x, [1, 2], name='avgpool')
x = fc(x, classes, scope='logits')
x = softmax(x, name='probs')
return x


# Simple alias.
Darknet19 = darknet19
TinyDarknet19 = tinydarknet19
167 changes: 167 additions & 0 deletions tensornets/detections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Collection of generic object detection models
The reference papers:
1. YOLOv2
- YOLO9000: Better, Faster, Stronger, CVPR 2017 (Best Paper Honorable Mention)
- Joseph Redmon, Ali Farhadi
- https://arxiv.org/abs/1612.08242
2. Faster R-CNN
- Faster R-CNN: Towards Real-Time Object Detection
with Region Proposal Networks, NIPS 2015
- Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun
- https://arxiv.org/abs/1506.01497
The reference implementations:
1. Darknet
- https://pjreddie.com/darknet/yolo/
2. darkflow
- https://github.com/thtrieu/darkflow
3. Caffe and Python utils
- https://github.com/rbgirshick/py-faster-rcnn
4. RoI pooling in TensorFlow
- https://github.com/deepsense-ai/roi-pooling
"""
from __future__ import absolute_import
from __future__ import division

import tensorflow as tf

from .layers import batch_norm
from .layers import bias_add
from .layers import conv2d
from .layers import darkconv
from .layers import dropout
from .layers import flatten
from .layers import fc
from .layers import max_pool2d

from .ops import *
from .utils import remove_head
from .utils import set_args
from .utils import var_scope

from .references.yolos import get_boxes as yolo_boxes
from .references.yolos import local_flatten
from .references.yolos import opts
from .references.rcnns import get_boxes as rcnn_boxes
from .references.rcnns import roi_pool2d
from .references.rcnns import rp_net


def __args_yolo__(is_training):
return [([batch_norm], {'is_training': is_training}),
([bias_add, conv2d], {}),
([max_pool2d], {'padding': 'SAME'})]


def __args_rcnn__(is_training):
return [([conv2d], {'activation_fn': None, 'scope': 'conv'}),
([dropout], {'is_training': is_training}),
([fc], {'activation_fn': None, 'scope': 'fc'})]


@var_scope('genYOLOv2')
@set_args(__args_yolo__)
def yolov2(x, stem_fn, stem_out=None, is_training=False, classes=21,
scope=None, reuse=None):
def get_boxes(*args, **kwargs):
return yolo_boxes(opts('yolov2' + data_name(classes)), *args, **kwargs)

x = stem_fn(x, is_training, stem=True, scope='stem')
p = x.p
stem_name = x.model_name

if stem_out is not None:
x = remove_head(stem_out)

x = darkconv(x, 1024, 3, scope='conv7')
x = darkconv(x, 1024, 3, scope='conv8')

p = darkconv(p, 64, 1, scope='conv5a')
p = local_flatten(p, scope='flat5a')

x = concat([p, x], axis=3, name='concat')
x = darkconv(x, 1024, 3, scope='conv9')
x = darkconv(x, 125 if classes == 21 else 425, 1,
onlyconv=True, scope='linear')
x.aliases = []
x.get_boxes = get_boxes
x.stem_name = stem_name
return x


def data_name(classes):
return 'voc' if classes == 21 else ''


@var_scope('genTinyYOLOv2')
@set_args(__args_yolo__)
def tinyyolov2(x, stem_fn, stem_out=None, is_training=False, classes=21,
scope=None, reuse=None):
def get_boxes(*args, **kwargs):
return yolo_boxes(opts('tinyyolov2' + data_name(classes)),
*args, **kwargs)

x = stem_fn(x, is_training, stem=True, scope='stem')
stem_name = x.model_name

if stem_out is not None:
x = remove_head(stem_out)

x = max_pool2d(x, 2, stride=1, scope='pool6')
x = darkconv(x, 1024, 3, scope='conv7')
x = darkconv(x, 1024 if classes == 21 else 512, 3, scope='conv8')
x = darkconv(x, 125 if classes == 21 else 425, 1,
onlyconv=True, scope='linear')
x.aliases = []
x.get_boxes = get_boxes
x.stem_name = stem_name
return x


@var_scope('genFasterRCNN')
@set_args(__args_rcnn__)
def fasterrcnn(x, stem_fn, stem_out=None, is_training=False, classes=21,
scope=None, reuse=None):
def roi_pool_fn(x, filters, kernel_size):
rois = rp_net(x, filters, height, width, scales)
x = roi_pool2d(x, kernel_size, rois)
return x, rois[0] / scales

scales = tf.placeholder(tf.float32, [None])
height = tf.cast(tf.shape(x)[1], dtype=tf.float32)
width = tf.cast(tf.shape(x)[2], dtype=tf.float32)

x = stem_fn(x, is_training, stem=True, scope='stem')
stem_name = x.model_name

if stem_out is not None:
x = remove_head(stem_out)

if 'zf' in stem_name:
x, rois = roi_pool_fn(x, 256, 6)
else:
x, rois = roi_pool_fn(x, 512, 7)

x = flatten(x)
x = fc(x, 4096, scope='fc6')
x = relu(x, name='relu6')
x = dropout(x, keep_prob=0.5, scope='drop6')
x = fc(x, 4096, scope='fc7')
x = relu(x, name='relu7')
x = dropout(x, keep_prob=0.5, scope='drop7')
x = concat([softmax(fc(x, classes, scope='logits'), name='probs'),
fc(x, 4 * classes, scope='boxes'),
rois], axis=1, name='out')
x.get_boxes = rcnn_boxes
x.stem_name = stem_name
x.scales = scales
return x


# Simple alias.
YOLOv2 = yolov2
TinyYOLOv2 = tinyyolov2
FasterRCNN = fasterrcnn
3 changes: 3 additions & 0 deletions tensornets/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,7 @@ def faster_rcnn_preprocess(x):
'REFtinyyolov2voc': darknet_preprocess,
'REFfasterrcnnZFvoc': faster_rcnn_preprocess,
'REFfasterrcnnVGG16voc': faster_rcnn_preprocess,
'genYOLOv2': darknet_preprocess,
'genTinyYOLOv2': darknet_preprocess,
'genFasterRCNN': faster_rcnn_preprocess,
}
22 changes: 20 additions & 2 deletions tensornets/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,24 @@ def assign(scopes):


def direct(model_name, scope):
if model_name.startswith('gen'):
fun = load_nothing
if 'FasterRCNN' in model_name:
if 'vgg16' in scope.stem_name:
fun = load_ref_faster_rcnn_vgg16_voc
elif 'zf' in scope.stem_name:
fun = load_ref_faster_rcnn_zf_voc
elif 'TinyYOLOv2' in model_name:
if 'tinydarknet19' in scope.stem_name:
fun = load_ref_tiny_yolo_v2_voc
elif 'YOLOv2' in model_name:
if 'darknet19' in scope.stem_name:
fun = load_ref_yolo_v2_voc
else:
fun = __load_dict__[model_name]

def _direct():
return __load_dict__[model_name](scope,
return_fn=pretrained_initializer)
return fun(scope, return_fn=pretrained_initializer)
return _direct


Expand Down Expand Up @@ -652,6 +667,9 @@ def load_ref_faster_rcnn_vgg16_voc(scopes, return_fn=_assign):
'mobilenet75': load_mobilenet75,
'mobilenet100': load_mobilenet100,
'squeezenet': load_squeezenet,
'zf': load_nothing,
'darknet19': load_nothing,
'tinydarknet19': load_nothing,
'REFyolov2': load_ref_yolo_v2,
'REFyolov2voc': load_ref_yolo_v2_voc,
'REFtinyyolov2voc': load_ref_tiny_yolo_v2_voc,
Expand Down
16 changes: 16 additions & 0 deletions tensornets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,22 @@ def parse_torch_weights(weights_path, move_rules=None):
return values


def remove_head(name):
_scope = "%s/stem" % tf.get_variable_scope().name
g = tf.get_default_graph()
for x in g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
scope=_scope)[::-1]:
if name in x.name:
break
g.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES).pop()

for x in g.get_collection(__outputs__, scope=_scope)[::-1]:
if name in x.name:
break
g.get_collection_ref(__outputs__).pop()
return x


def remove_utils(module_name, exceptions):
import sys
from . import utils
Expand Down
Loading

0 comments on commit 67915e6

Please sign in to comment.