diff --git a/docs/frontend/tensorflow.md b/docs/frontend/tensorflow.md index acafbb5bb93e..d47923bdd938 100644 --- a/docs/frontend/tensorflow.md +++ b/docs/frontend/tensorflow.md @@ -21,7 +21,7 @@ instructions to generate protobuf from checkpoint. ### Add Shapes: While freezing of protobuf add additional option ```add_shapes=True``` to embed output shapes of each node into graph. -You may use ```nnvm.testing.tf.AddShapesToGraphDef``` from nnvm for the same. +You may use ```tvm.relay.testing.tf.AddShapesToGraphDef``` from nnvm for the same. Please refer to [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py). ### Explicit Shape: diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 0ea92248f0f5..f4ec61979527 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -21,7 +21,7 @@ from tensorflow.python.ops import init_ops from tensorflow.core.framework import graph_pb2 -import nnvm.testing.tf +import tvm.relay.testing.tf as tf_testing ####################################################################### # Generic run functions for TVM & tensorflow @@ -784,9 +784,9 @@ def test_forward_pad(): def test_forward_inception_v3(): '''test inception V3 model''' with tf.Graph().as_default(): - graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb') + graph_def = tf_testing.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb') # Call the utility to import the graph definition into default graph. 
- graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') @@ -801,9 +801,9 @@ def test_forward_inception_v3(): def test_forward_inception_v1(): '''test inception V1 model''' with tf.Graph().as_default(): - graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb") + graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb") # Call the utility to import the graph definition into default graph. - graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) # Build an image from random data. from PIL import Image @@ -838,18 +838,18 @@ def test_forward_mobilenet(): '''test mobilenet model''' # MobilenetV2 with tf.Graph().as_default(): - graph_def = nnvm.testing.tf.get_workload( + graph_def = tf_testing.get_workload( "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz", "mobilenet_v2_1.4_224_frozen.pb") # Call the utility to import the graph definition into default graph. - graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') out_node = 'MobilenetV2/Predictions/Reshape_1' with tf.Session() as sess: # Add shapes to the graph. 
- graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node) + graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0') tvm_output = run_tvm_graph(graph_def, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) @@ -861,9 +861,9 @@ def test_forward_resnetv2(): '''test resnet model''' if is_gpu_available(): with tf.Graph().as_default(): - graph_def = nnvm.testing.tf.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb") + graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb") # Call the utility to import the graph definition into default graph. - graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32') out_node = 'ArgMax' @@ -879,7 +879,7 @@ def test_forward_resnetv2(): dir(tf.contrib) def test_forward_ptb(): '''test ptb model''' - config = nnvm.testing.tf.get_config() + config = tf_testing.get_config() num_steps = config.num_steps num_hidden = config.hidden_size num_layers = config.num_layers @@ -936,7 +936,7 @@ def _get_sample(data, state): "float32")).asnumpy() state_output = model.get_output(1, tvm.nd.empty(out_state_shape, "float32")).asnumpy() - sample = nnvm.testing.tf.pick_from_weight(tvm_output[0]) + sample = tf_testing.pick_from_weight(tvm_output[0]) return sample, state_output @@ -956,10 +956,10 @@ def _get_sample(data, state): return samples, state with tf.Graph().as_default(): - word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb() + word_to_id, id_to_word, graph_def = tf_testing.get_workload_ptb() vocab_size = len(word_to_id) # Call the utility to import the graph definition into default graph. 
- graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) sess = tf.Session() #TVM graph module creation @@ -975,7 +975,7 @@ def _get_sample(data, state): for word in seed_for_sample], in_state, params, cnt_sample) tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word) - tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess, + tf_samples, tf_state = tf_testing.do_tf_sample(sess, [word_to_id[word] for word in seed_for_sample], in_state, cnt_sample) tf_sample_str = _pretty_print(tf_samples, False, id_to_word) diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index d582e02e5cc7..dee3999ad3f1 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -13,3 +13,4 @@ from .tflite import from_tflite from .coreml import from_coreml from .caffe2 import from_caffe2 +from .tensorflow import from_tensorflow diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py new file mode 100644 index 000000000000..82b4c5b9ca37 --- /dev/null +++ b/python/tvm/relay/frontend/tensorflow.py @@ -0,0 +1,1540 @@ +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition +"""TF: Tensorflow frontend.""" +from __future__ import absolute_import as _abs +from __future__ import print_function + +import logging +# Numpy support +import numpy as np + +import tvm +from topi.util import get_const_tuple +from .. import ir_pass +from .. import expr as _expr +from .. import op as _op + +__all__ = ['from_tensorflow'] + +def _get_relay_op(op_name): + try: + op = getattr(_op, op_name) + except AttributeError: + try: + op = getattr(_op.nn, op_name) + except AttributeError: + op = getattr(_op.image, op_name) + + if not op: + raise RuntimeError("Unable to map op_name {} to relay".format(op_name)) + return op + +class AttrCvt(object): + """Common attribute conveter. 
An AttrCvt instance is a callable: + ``` + attr_converter = AttrCvt(op_name, transforms={'a':'b', 'c':('d', 1)}) + relay_expr = attr_converter(inputs, attrs) + ``` + + Parameters + ---------- + op_name : str or callable + If set as str, returned operator name is the str. + If set as callable, returned operator is the str returned by calling: + `op_name = func(attr)` + transforms : dict of `new_name, or (new_name, default_value, transform function)` + If only a new_name is provided, it's like renaming the attribute name. + If default_value is provided, then the attribute is considered as optional. + If transform function is provided, the original attribute value is handled + by transform function. + excludes : list + A list of excluded attributes that should `NOT` appear. + Raise NotImplementedError if one occurs. + disables : list + A list of attributes that are disabled in relay. Log warnings. + ignores : list + A list of attributes that are ignored in relay. Debug level logging. + extras : dict + A series of additional attributes that should be added anyway to the returned + attribute dict. + custom_check : tuple of (callable, str) + A (check function, message) pair; the function takes the attributes and returns True/False. + Raise RuntimeError with the message if the check does not return a truthy value. 
+ """ + + def __init__(self, op_name, transforms=None, + excludes=None, disables=None, ignores=None, + extras=None, custom_check=None): + self._op_name = op_name + self._transforms = transforms if transforms else {} + self._excludes = excludes if excludes else [] + self._disables = disables if disables else [] + self._ignores = ignores if ignores else [] + self._extras = extras if extras else {} + self._custom_check = custom_check + + def __call__(self, inputs, attrs, *args): + self._ignores.append('_output_shapes') + self._ignores.append('_input_shapes') + self._ignores.append('T') + self._ignores.append('use_cudnn_on_gpu') + self._ignores.append('_node_name') + self._ignores.append('is_training') + self._ignores.append('_target_layout') + + # apply custom check + if self._custom_check: + func, msg = self._custom_check + if not func(attrs): + raise RuntimeError("Check failed: {}".format(msg)) + # get new op_name + if isinstance(self._op_name, str): + op_name = self._op_name + else: + assert callable(self._op_name), "op_name can either be string or callable" + op_name = self._op_name(attrs) + # convert attributes + new_attrs = {} + for k in attrs.keys(): + if k in self._excludes: + raise NotImplementedError("Attribute {} not supported yet.".format(k)) + elif k in self._disables: + logging.warning("Attribute %s is disabled in relay.%s", k, op_name) + elif k in self._ignores: + logging.debug("Attribute %s is ignored in relay.%s", k, op_name) + elif k in self._transforms: + new_name, defaults, transform = self._parse_default(self._transforms[k]) + if defaults is None: + new_attr = self._required_attr(attrs, k) + else: + new_attr = attrs.get(k, None) + if new_attr is None: + new_attrs[new_name] = defaults + else: + new_attrs[new_name] = transform(new_attr) + else: + # copy + new_attrs[k] = attrs[k] + # add extras + new_attrs.update(self._extras) + return _get_relay_op(op_name)(*inputs, **new_attrs) + + def _parse_default(self, target): + """Helper function to parse 
default values.""" + if not isinstance(target, (list, tuple)): + k, v, t = target, None, lambda x: x + elif len(target) == 1: + k, v, t = target[0], None, lambda x: x + elif len(target) == 2: + k, v, t = target[0], target[1], lambda x: x + elif len(target) > 2: + k, v, t = target[0], target[1], target[2] + else: + k = None # should raise + if not isinstance(k, str): + msg = "{} is not a valid target, (name, default) expected.".format(target) + raise ValueError(msg) + return k, v, t + + def _parse_bool(self, value): + """Helper function to parse default boolean values.""" + if isinstance(value, str): + return value.strip().lower() in ['true', '1', 't', 'y', 'yes'] + return bool(value) + + def _required_attr(self, attr, key): + """Wrapper for getting required attributes.""" + assert isinstance(attr, dict) + if key not in attr: + raise AttributeError("Required attribute {} not found.".format(key)) + return attr[key] + +def _get_pad_pair(input1d, kernel1d, stride1d): + if input1d % stride1d == 0: + pad = max(kernel1d - stride1d, 0) + else: + pad = max(kernel1d - (input1d % stride1d), 0) + + pad_before = pad // 2 + pad_after = pad - pad_before + + return [pad_before, pad_after] + +def _get_name_hint(node): + name = '' + if hasattr(node, "name_hint"): + name = node.name_hint + return name + +def _math_name_picker(surfix): + def _impl(attr): + return 'broadcast_' + surfix + return _impl + +def _dimension_picker(prefix, surfix=''): + def _impl(attr): + kernel = attr['kernel_shape'] + if len(kernel) == 2: + return prefix + '2d' + surfix + else: + raise NotImplementedError("Only 2d kernel supported.") + return _impl + +def _dimension_constraint(): + def _dim_check(attrs): + if len(attrs['kernel_shape']) == 2: + return True + return False + return _dim_check, "Only 2d kernel supported." + +def _infer_channels(inputs, params, transpose=False): + """A hack for getting 'channles' or 'units' since tensorflow don't provide + these attributes. 
We check the shape of weights provided to get the number. + """ + out_type = ir_pass.infer_type(inputs) + out_shapes = [get_const_tuple(out_type.checked_type.shape)] + channels = out_shapes[0][0] if not transpose else out_shapes[0][1] + return channels + +def _rsqrt(): + def _impl(inputs, attr, *args): + inputs.append(tvm.relay.const(-0.5, attr['T'].name)) + return AttrCvt(op_name="power")(inputs, attr) + return _impl + +def _argx(func, func_name): + """ A common wrapper for argmin and argmax operations """ + def _impl(inputs, attr, params): + try: + # In Tensorflow, `axis` argument is a Tensor, not attribute. We + # support the case where it inputs from a scalar constant. + axis_input_name = inputs[1].name_hint + axis_input_vlaue = [params[axis_input_name].asnumpy()[0]] + except (IndexError, KeyError): + raise TypeError( \ + "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) + return func(inputs[0], axis=axis_input_vlaue, keepdims=False) + return _impl + +def _elemwise(name): + def _impl(inputs, attr, *args): + assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) + return _get_relay_op(name)(*inputs) + return _impl + +def _pooling(name): + def _impl(inputs, attr, params): + + attr['data_format'] = attr['data_format'].decode("utf-8") + flip_layout = False + + input_shape = attr['_input_shapes'][inputs[0]][0] + + if attr['data_format'] == 'NHWC': + attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2]) + attr['strides'] = (attr['strides'][1], attr['strides'][2]) + elif attr['data_format'] == 'NCHW': + attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) + attr['strides'] = (attr['strides'][2], attr['strides'][3]) + else: + raise TypeError("Unsupported data_format type : {}".format(attr['data_format'])) + + if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": + tmp_shape = attr['_input_shapes'][inputs[0]][0] + input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] + inputs[0] = 
_op.transpose(inputs[0], axes=(0, 3, 1, 2)) + attr['data_format'] = "NCHW" + flip_layout = True + + # Fix padding + attr['padding'] = attr['padding'].decode("utf-8") + + if attr['padding'] == 'VALID': + attr['padding'] = [0, 0] + elif attr['padding'] == 'SAME': + stride_h, stride_w = attr['strides'] + kernel_h, kernel_w = attr['kernel_shape'] + if attr['data_format'] == 'NHWC': + in_h = input_shape[1] + in_w = input_shape[2] + else: + in_h = input_shape[2] + in_w = input_shape[3] + + pad_v = _get_pad_pair(in_h, kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, kernel_w, stride_w) + + attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] + else: + raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + + if name == "avg_pool": + attr['count_include_pad'] = False + + out = AttrCvt( + op_name=_dimension_picker(name), + transforms={ + 'kernel_shape':'pool_size', + 'data_format':'layout'}, + ignores=['ksize'], + extras={'ceil_mode': False}, + custom_check=_dimension_constraint())(inputs, attr) + + if flip_layout: + out = _op.transpose(out, axes=(0, 2, 3, 1)) + + return out + return _impl + +def _conv(opname): + def _impl(inputs, attr, params): + attr['data_format'] = attr['data_format'].decode("utf-8") + flip_layout = False + + # NCHW Layout require weights transpose + if attr['data_format'] == 'NCHW': + tmp_shape = attr['_input_shapes'][inputs[1]][0] + tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] + inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) + attr['_input_shapes'][inputs[1]] = [tmp_shape] + + input_shape = attr['_input_shapes'][inputs[0]][0] + weights_shape = attr['_input_shapes'][inputs[1]][0] + + if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": + input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] + inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) + if opname == 'conv': + weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)] + inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) + 
else: + weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)] + inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) + + attr['data_format'] = "NCHW" + attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)] + flip_layout = True + + if attr['data_format'] == 'NHWC': + kernel_h, kernel_w, _, depth_mult = weights_shape + attr['kernel_shape'] = (weights_shape[0], weights_shape[1]) + if opname == 'conv': + attr['channels'] = weights_shape[3] + else: + attr['channels'] = input_shape[3] * depth_mult + + if 'dilations' in attr: + attr['dilations'] = (attr['dilations'][0], attr['dilations'][1]) + attr['strides'] = (attr['strides'][1], attr['strides'][2]) + elif attr['data_format'] == 'NCHW': + depth_mult, _, kernel_h, kernel_w = weights_shape + attr['kernel_shape'] = (weights_shape[2], weights_shape[3]) + if opname == 'conv': + attr['channels'] = weights_shape[0] + else: + attr['channels'] = input_shape[0] * depth_mult + if attr['channels'] < 0: + attr['channels'] *= -1 + + if 'dilations' in attr: + attr['dilations'] = (attr['dilations'][2], attr['dilations'][3]) + attr['strides'] = (attr['strides'][2], attr['strides'][3]) + else: + raise TypeError("Unsupported data format type : {}".format(attr['data_format'])) + + + if opname == 'depthwise': + attr['groups'] = attr['channels'] + + # Fix padding + attr['padding'] = attr['padding'].decode("utf-8") + + if attr['padding'] == 'VALID': + attr['padding'] = [0, 0] + elif attr['padding'] == 'SAME': + stride_h, stride_w = attr['strides'] + kernel_h, kernel_w = attr['kernel_shape'] + if attr['data_format'] == 'NHWC': + in_h = input_shape[1] + in_w = input_shape[2] + else: + in_h = input_shape[2] + in_w = input_shape[3] + + pad_v = _get_pad_pair(in_h, kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, kernel_w, stride_w) + + if attr['data_format'] == 'NHWC': + inputs[0] = _op.nn.pad(data=inputs[0], + pad_width=((0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]), + (0, 0))) + else: + inputs[0] = 
_op.nn.pad(data=inputs[0], + pad_width=((0, 0), + (0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]))) + + attr['padding'] = [0, 0] + + else: + raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + + if 'kernel_layout' not in attr: + if opname == 'conv': + attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW' + else: + attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW' + + use_bias = len(inputs) == 3 + channel_axis = 1 if attr['data_format'] == "NCHW" else 3 + + out = AttrCvt( + op_name=_dimension_picker('conv'), + transforms={ + 'kernel_shape': 'kernel_size', + 'data_format': 'data_layout', + 'dilations': ('dilation', (0, 0)), + 'group': ('groups', 1)}, + custom_check=_dimension_constraint())([inputs[0], inputs[1]], attr) + + if use_bias: + out = _op.nn.bias_add(out, inputs[2], axis=channel_axis) + + if flip_layout: + out = _op.transpose(out, axes=(0, 2, 3, 1)) + + return out + return _impl + +def _decode_image(): + def _impl(inputs, attr, params): + # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. 
+ print("DecodeJpeg: It's a pass through, please handle preprocessing before input") + return inputs[0] + return _impl + +def _cast(): + def _impl(inputs, attr, params): + return inputs[0].astype(attr['DstT'].name) + return _impl + +def _expand_dims(): + def _impl(inputs, attr, params): + dim_input = inputs.pop(1) + axis = params[dim_input.name_hint] + params.pop(dim_input.name_hint) + return AttrCvt(op_name="expand_dims", ignores=['Tdim'], + extras={'axis': int(axis.asnumpy()[0])})(inputs, attr) + return _impl + +def _resize_bilinear(): + def _impl(inputs, attr, params): + attr['size'] = attr['_output_shapes'][0][1:3] + inputs.pop(1) + # NHWC + attr['layout'] = 'NHWC' + + return AttrCvt(op_name="resize", + ignores=['Tdim'], + extras={'method': "BILINEAR"})(inputs, attr) + return _impl + +def _check_numerics(): + def _impl(inputs, attr, params): + # Making a copy node assuming no need to verify + return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr) + return _impl + + +def _matmul(): + def _impl(inputs, attr, params): + channels = _infer_channels(inputs[1], params, not attr['transpose_b']) + if attr['transpose_a']: + inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) + if not attr['transpose_b']: + inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) + return AttrCvt(op_name="dense", + extras={'units': channels}, + ignores=['transpose_a', 'transpose_b', 'T'])(inputs, attr) + + return _impl + +def _identity(): + def _impl(inputs, attr, params): + return inputs[0] + return _impl + +def _concatV2(): + def _impl(inputs, attr, params): + pop_node = inputs.pop(len(inputs)-1) + axis = params[pop_node.name_hint] + params.pop(pop_node.name_hint) + return AttrCvt( + op_name="concatenate", ignores=['T', 'N', 'Tidx'], + extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) + return _impl + +def _concat(): + def _impl(inputs, attr, params): + pop_node = inputs.pop(0) + axis = params[pop_node.name_hint] + params.pop(pop_node.name_hint) + return AttrCvt( + 
op_name="concatenate", ignores=['N'], + extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) + return _impl + +def _pack(): + def _impl(inputs, attr, params): + axis = int(attr["axis"]) + inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] + return _op.concatenate(inputs_reshaped, axis) + return _impl + +def _reshape(): + def _impl(inputs, attr, params): + try: + pop_node = inputs[1] + shape_arg = params.pop(pop_node.name_hint) + inputs.pop(1) + + return AttrCvt( + op_name="reshape", + extras={'newshape':tuple(shape_arg.asnumpy())}, + ignores=['Tshape'])(inputs, attr) + except KeyError: + # Shape operator is already pruned, hence + # try to infer shape by precompute prune if possible. + if all(in_node in params for in_node in inputs[1].list_input_names()): + func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) + with tvm.relay.build_config(opt_level=0): + graph, lib, params = tvm.relay.build(func, target="llvm", params=params) + ctx = tvm.context("llvm", 0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + m.run() + params_new = m.get_output(0) + inputs.pop(1) + return AttrCvt( + op_name="reshape", + extras={'newshape':tuple(params_new.asnumpy().flatten())}, + ignores=['Tshape'])(inputs, attr) + else: + raise RuntimeError("Reshape with dynamic shape input not supported yet.") + return _impl + +def _bias_add(): + def _impl(inputs, attr, params): + return _op.add(inputs[0], inputs[1]) + return _impl + +def _squeeze(): + def _impl(inputs, attr, params): + if len(attr['squeeze_dims']) == 0: + attr['squeeze_dims'] = None + return AttrCvt( + op_name="squeeze", + transforms={'squeeze_dims':'axis'}, + ignores=['T'])(inputs, attr) + return _impl + +def _fused_batch_norm(): + def _impl(inputs, attr, params): + # Tensorflow: (data, gamma, beta, moving_mean, moving_variance) + # Relay: (data, gamma, beta, moving_mean, moving_varience) + axis = 3 + need_cast = False + + if 
'data_format' in attr: + attr['data_format'] = attr['data_format'].decode("utf-8") + if attr['data_format'] == 'NCHW': + axis = 1 + if 'U' in attr: + need_cast = True + inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name) + + out = AttrCvt(op_name='batch_norm', + transforms={'scale_after_normalization':'scale', + 'variance_epsilon':'epsilon'}, + extras={'axis': axis}, + ignores=['data_format', 'U'], + disables=['momentum'])(inputs, attr) + + if need_cast: + out = _op.cast(out, dtype=attr['T'].name) + return out + return _impl + +def _batch_norm(): + def _impl(inputs, attr, params): + # Rearrange inputs from + # (data, moving_mean, moving_variance, beta, gamma) + # to + # (data, gamma, beta, moving_mean, moving_var) + new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]] + + axis = 3 + if 'data_format' in attr: + attr['data_format'] = attr['data_format'].decode("utf-8") + if attr['data_format'] == 'NCHW': + axis = 1 + + return AttrCvt( + op_name='batch_norm', + transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'}, + extras={'axis': axis}, + ignores=['data_format'], + disables=['momentum'])(new_inputs, attr) + return _impl + +def _relu6(): + def _impl(inputs, attr, params): + return _op.clip(inputs[0], a_min=0, a_max=6) + return _impl + +def _shape(): + def _impl(inputs, attr, params): + return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32') + return _impl + +def _fill(): + def _impl(inputs, attr, params): + fill_arg = params.pop(inputs.pop(1).name_hint) + return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name), + attr['_output_shapes'][0], attr['T'].name) + return _impl + +def _lrn(): + def _impl(inputs, attr, params): + attr_new = {} + depth_radius = attr.get('depth_radius', 5) + size = (depth_radius * 2) + 1 + attr_new['axis'] = 3 # Fix axis, NHWC format + attr_new['size'] = size + attr_new['bias'] = attr.get('bias', 1) + attr_new['alpha'] = attr.get('alpha', 1) * size + attr_new['beta'] = 
attr.get('beta', 0.5) + return AttrCvt(op_name='lrn')(inputs, attr_new) + return _impl + +def _sum(): + def _impl(inputs, attr, params): + axis = params.pop(inputs[1].name_hint).asnumpy() + # convert to tuple for preventing invalid parameter format error + axis = tuple(axis) + return AttrCvt( + op_name='sum', + extras={'axis': axis}, + transforms={'keep_dims':'keepdims'}, + ignores=['name', 'Tidx'])([inputs[0]], attr) + return _impl + +def _square(): + def _impl(inputs, attr, params): + return _op.multiply(inputs[0], inputs[0]) + return _impl + +def _gather_v2(): + "Tensorflow now support only gatherv2" + def _impl(inputs, attr, params): + axis = params[inputs.pop(2).name_hint].asnumpy()[0] + new_input = [] + new_input.append(inputs.pop(0)) + new_input.append(inputs.pop(0)) + return AttrCvt(op_name="take", + extras={'axis': tvm.const(axis, 'int32')}, + ignores=['Tindices', 'Tparams', 'validate_indices', \ + 'Taxis', '_class'])(new_input, attr) + return _impl + +def _infer_out_shapes(inputs, params): + """A method to get the output shape of an intermediate node in the relay graph.""" + out_type = ir_pass.infer_type(inputs) + out_shapes = [get_const_tuple(out_type.checked_type.shape)] + return out_shapes + +def _stridedSlice(): + def _impl(inputs, attr, params): + """Strided Slice. 
+ Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice + Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ + tensorflow/core/util/strided_slice_op.cc#L147-L368 + """ + begin = params.pop(inputs[1].name_hint).asnumpy().tolist() + end = params.pop(inputs[2].name_hint).asnumpy().tolist() + stride = params.pop(inputs[3].name_hint).asnumpy().tolist() + begin_mask = int(attr.get('begin_mask', 0)) + end_mask = int(attr.get('end_mask', 0)) + ellipsis_mask = int(attr.get('ellipsis_mask', 0)) + new_axis_mask = int(attr.get('new_axis_mask', 0)) + shrink_axis_mask = int(attr.get('shrink_axis_mask', 0)) + data_shape = attr['_input_shapes'][inputs[0]] + data_dim = len(data_shape[0]) + stride_dim = len(stride) + + def _transform_mask(stride_dim, ellipsis_mask): + """Handle mask inputs to create new begin, end, stride and output shape""" + m_begin = [0] * data_dim + m_end = [0] * data_dim + m_stride = [0] * data_dim + fshape_indices = [] + #Count new axis after ellipsis_mask, consider while applying ellipsis_mask. + ellipsis_seen = False + new_axes_after_ellipsis = 0 + for i in range(stride_dim): + mask = 1 << i + if ellipsis_seen and (mask & new_axis_mask) != 0: + new_axes_after_ellipsis += 1 + if (mask & ellipsis_mask) != 0: + ellipsis_seen = True + if not ellipsis_seen: + #Used later for extending the stride attributes in the below loop. 
+ ellipsis_mask |= (1 << stride_dim) + stride_dim += 1 + final_index = 0 + for index in range(stride_dim): + mask = 1 << index + if mask & ellipsis_mask: + #Identify the end index for applying ellipsis_mask + to_index = min(((data_dim - (stride_dim-index)) + 1 \ + + new_axes_after_ellipsis), data_dim) + for i in range(final_index, to_index): + m_begin[final_index] = 0 + m_end[final_index] = data_shape[0][final_index] + m_stride[final_index] = 1 + fshape_indices.append(final_index) + final_index += 1 + elif mask &new_axis_mask: + fshape_indices.append(-1) + elif not mask & new_axis_mask: + if final_index == len(m_begin): + break + if mask & begin_mask: + m_begin[final_index] = data_shape[0][final_index] \ + if stride[index] < 0 else 0 + elif begin[index]: + m_begin[final_index] = begin[index] + if mask & end_mask: + m_end[final_index] = 0 if stride[index] < 0 \ + else data_shape[0][final_index] + elif end[index]: + m_end[final_index] = end[index] + m_stride[final_index] = stride[index] + if mask & shrink_axis_mask: + #Tensorflow make axis with shrink_axis_mask as dimension 1 + m_begin[final_index] = data_shape[0][final_index] + begin[index] \ + if begin[index] < 0 else begin[index] + m_end[final_index] = begin[index] + 1 + m_stride[final_index] = 1 + fshape_indices.append(-2) + else: + fshape_indices.append(final_index) + + final_index += 1 + return m_begin, m_end, m_stride, fshape_indices + + fshape_indices = None + if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: + begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) + out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) + out_shape = _infer_out_shapes(out, params)[0] + if not fshape_indices: + fshape_indices = range(len(out_shape)) + + #Create final output shape. 
+ final_output = [] + for gather_index in fshape_indices: + if gather_index == -1: + final_output.append(1) + elif gather_index == -2: + pass + else: + final_output.append(out_shape[gather_index]) + return _op.reshape(out, newshape=tuple(final_output)) + return _impl + +def _pad(name): + def _impl(inputs, attr, params): + padlist_key = inputs[1].name_hint + if padlist_key in params: + padlist = params.pop(padlist_key).asnumpy() + else: + raise RuntimeError("Required parameter {} not fount.".format(padlist_key)) + paddings = tuple([tuple(l) for l in padlist]) + attr['pad_width'] = paddings + attr['pad_value'] = 0 + new_inputs = [inputs[0]] + if name == 'PadV2': + constant_values = params.pop(inputs[2].name_hint).asnumpy() + attr['pad_value'] = constant_values[0] + return AttrCvt( + op_name='pad', + ignores=['Tpaddings'],)(new_inputs, attr) + return _impl + + +def _transpose(): + def _impl(inputs, attr, params): + # If perm is not specified, axes is left empty, + # otherwise its value is get from params + param_name = _get_name_hint(inputs[1]) + if param_name in params: + axes = tuple(params.get(param_name).asnumpy()) + else: + axes = None + return _op.transpose(inputs[0], axes=axes) + return _impl + +def _rank(): + def _impl(inputs, attr, params): + input_shapes = attr['_input_shapes'][inputs[0]] + assert len(inputs) == 1 + + name = attr["_node_name"] + params[name] = tvm.nd.array([len(input_shapes[0])]) + return [_expr.var(name, + shape=params[name].shape, + dtype='int32')] + + return _impl + +def _range(): + def _impl(inputs, attr, params): + start = params.pop(inputs[0].name_hint).asnumpy()[0] + limit = params.pop(inputs[1].name_hint).asnumpy()[0] + delta = params.pop(inputs[2].name_hint).asnumpy()[0] + + name = attr["_node_name"] + params[name] = tvm.nd.array([start, limit, delta]) + return [_expr.var(name, + shape=params[name].shape, + dtype='int32')] + return _impl + +def _elu(): + def _impl(inputs, attr, params): + alpha = tvm.relay.const(-1.0, 
attr['T'].name) + return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \ + - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]) + return _impl + +def _selu(): + def _impl(inputs, attr, params): + alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name) + gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name) + return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \ + - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])) + return _impl + +def _mean(): + def _impl(inputs, attr, params): + axis = params.pop(inputs[1].name_hint) + return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'], + transforms={'keep_dims': 'keepdims'}, + extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr) + return _impl + +def _broadcast(name): + def _impl(inputs, attr, params): + return AttrCvt( + op_name=name, + ignores=['name', 'Tidx'] + )(inputs, attr) + return _impl + +def _softmax(): + def _impl(inputs, attr, params): + return AttrCvt(op_name='softmax', + transforms={'axis': ('axis', 1)})([inputs[0]], attr) + return _impl + +# compatible operators that do NOT require any conversion. +_identity_list = [] + +# _convert_map defines maps of name to converter functor(callable) +# for 1 to 1 mapping, use Renamer if nothing but name is different +# use AttrCvt if attributes need to be converted +# for 1 to N mapping(composed), use custom callable functions +# for N to 1 mapping, currently not supported(?) 
_convert_map = {
    'ArgMax'                            : _argx(_op.argmax, 'argmax'),
    'ArgMin'                            : _argx(_op.argmin, 'argmin'),
    'AvgPool'                           : _pooling('avg_pool'),
    'BatchNormWithGlobalNormalization'  : _batch_norm(),
    'BiasAdd'                           : _bias_add(),
    'Cast'                              : _cast(),
    'Ceil'                              : AttrCvt('ceil'),
    'CheckNumerics'                     : _check_numerics(),
    'Concat'                            : _concat(),
    'ConcatV2'                          : _concatV2(),
    'Conv2D'                            : _conv('conv'),
    'DecodeJpeg'                        : _decode_image(),
    'Elu'                               : _elu(),
    'ExpandDims'                        : _expand_dims(),
    'Floor'                             : AttrCvt('floor'),
    'Identity'                          : _identity(),
    'MatMul'                            : _matmul(),
    'MaxPool'                           : _pooling('max_pool'),
    'Add'                               : _elemwise('add'),
    'Sub'                               : _elemwise('subtract'),
    'Mul'                               : _elemwise('multiply'),
    'Maximum'                           : _elemwise('maximum'),
    'Minimum'                           : _elemwise('minimum'),
    'Sum'                               : _sum(),
    'Square'                            : _square(),
    'Pack'                              : _pack(),
    'LeakyRelu'                         : AttrCvt('leaky_relu'),
    'Relu'                              : AttrCvt('relu'),
    'Reshape'                           : _reshape(),
    'ResizeBilinear'                    : _resize_bilinear(),
    'Selu'                              : _selu(),
    'Softmax'                           : _softmax(),
    'Rsqrt'                             : _rsqrt(),
    'Squeeze'                           : _squeeze(),
    'FusedBatchNorm'                    : _fused_batch_norm(),
    'FusedBatchNormV2'                  : _fused_batch_norm(),
    'Relu6'                             : _relu6(),
    'DepthwiseConv2dNative'             : _conv('depthwise'),
    'Shape'                             : _shape(),
    'Sigmoid'                           : AttrCvt('sigmoid'),
    'Fill'                              : _fill(),
    'GatherV2'                          : _gather_v2(),
    'StridedSlice'                      : _stridedSlice(),
    'LRN'                               : _lrn(),
    'Pad'                               : _pad('Pad'),
    'PadV2'                             : _pad('PadV2'),
    'Range'                             : _range(),
    'Rank'                              : _rank(),
    'Transpose'                         : _transpose(),
    'Tanh'                              : AttrCvt('tanh'),
    'Mean'                              : _mean(),
    'Less'                              : _broadcast('less'),
    'Greater'                           : _broadcast('greater'),
    'LessEqual'                         : _broadcast('less_equal'),
    'GreaterEqual'                      : _broadcast('greater_equal'),
    'Equal'                             : _broadcast('equal'),
    'NotEqual'                          : _broadcast('not_equal'),
}

def _LSTMBlockCell():
    """Build the converter for a single TensorFlow LSTMBlockCell node."""
    def _impl(inputs, in_state_c, in_state_h, attr, params):
        """LSTM Block cell.

        Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114

        Parameters
        ----------
        inputs : list of relay.Expr
            Inputs of the LSTMBlockCell node. Index 0 is read as the data,
            index 3 as the weight matrix and index 7 as the bias.
        in_state_c : relay.Expr
            Cell state of the current layer; multiplied element-wise with
            the forget gate.
        in_state_h : relay.Expr
            Hidden state of the current layer; concatenated with the data
            along axis 1.
        attr : dict
            Dict of operator attributes. Must provide 'forget_bias', 'T' and
            '_input_shapes'. 'forget_bias' is removed from the dict.
        params : dict
            Dict of pretrained weights and bias (not read by this converter).

        Returns
        -------
        sym : relay.Expr
            The next hidden state.
        output : relay.Expr
            Next cell state and next hidden state, reshaped to
            (2, batch_size, num_hidden).
        """
        in_data = inputs[0]
        in_weight = inputs[3]
        in_bias = inputs[7]
        forget_bias = attr.pop('forget_bias')
        input_shape = attr['_input_shapes'][inputs[0]]
        weight_shape = attr['_input_shapes'][inputs[3]]
        batch_size, input_size = input_shape[0][0], input_shape[0][1]
        # The second weight dimension holds the four gates side by side.
        num_gate_units = weight_shape[0][1]
        num_hidden = num_gate_units // 4

        in_data = _op.reshape(in_data,
                              newshape=(batch_size, input_size))
        ixh = _op.concatenate([in_data, in_state_h], axis=1)
        # Swap the weight axes into the layout passed to dense below.
        in_weight = _op.transpose(in_weight, axes=None)
        gates = _op.nn.dense(ixh, in_weight,
                             units=num_gate_units)
        gates_bias = _op.add(gates, in_bias)
        # Gate order after the split: input, cell candidate, forget, output.
        gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
        in_gate = _op.sigmoid(gate_list[0])
        in_transform = _op.tanh(gate_list[1])
        # TensorFlow adds forget_bias to the pre-activation, i.e.
        # f = sigmoid(f + forget_bias). Adding it after the sigmoid yields
        # wrong states whenever forget_bias is non-zero.
        forget_gate = _op.add(gate_list[2],
                              tvm.relay.const(forget_bias, attr['T'].name))
        forget_gate = _op.sigmoid(forget_gate)
        out_gate = _op.sigmoid(gate_list[3])
        next_c = _op.add(_op.multiply(forget_gate, in_state_c),
                         _op.multiply(in_gate, in_transform))
        next_h = out_gate * _op.tanh(next_c)
        out_state = _op.concatenate([next_c, next_h], axis=1)
        out_state = _op.reshape(out_state,
                                newshape=(2, batch_size, num_hidden))
        return next_h, out_state
    return _impl

# _convert_map_rnn defines maps of rnn operator name to
# converter functor(callable) for 1 to 1 mapping.
+_convert_map_rnn = { + 'LSTMBlockCell' : _LSTMBlockCell(), +} + +class RecurrentNetworks(object): + """Recurrent network layer handlers. + + Handle Layer operations. + ToDo: Operators like RNN/GRU layer concepts also can be handled here + + Parameters + ---------- + nodes : list + list of graph nodes used for tensorflow parsing. + + out_rnn : list + List of RecurrentNetwork outputs. This output will be appended to the + 'head' nodes of the graph. + + graph : tensorflow graph definition object + The loaded tensorflow GraphDef + + convert_map : dict + Dict of name : callable, where name is the op's name that + require conversion to relay, callable are functions which + take attrs and return (new_op_name, new_attrs) + """ + def __init__(self, nodes, out_rnn, graph, convert_map): + self._graph = graph + self._convert_map = convert_map + self._nodes = nodes + self._out_rnn = out_rnn + self._cur_lstm_layer = 0 + self._layer_name_list = [] + self._recurrent_ops_layer_map = { + 'LSTMBlockCell' : self._LSTMBlockCellLayer(), + } + + def _LSTMBlockCellLayer(self): + """LSTMBlockCell layer handler. + + Parameters + ---------- + op_name : str + Operator name, eg:LSTMBlockCell + + layer_name : str list + Layer name is used for creating the state input placeholder. 
+ + inputs : relay.Expr + Input data + + attrs : dict + Dict of operator attributes + + params : dict + List of pretrained weights and bias + + num_layers : int + Total number of LSTM layer presented in the graph + + Returns + ------- + sym : relay.Expr + The returned relay Expr + """ + def _impl(op_name, layer_name, inputs, attrs, params, num_layers): + in_state_c_name = layer_name+'_c' + in_state_h_name = layer_name+'_h' + + def _init_state(num_layers, batch_size, num_hidden): + """Create the initial states for the first layer in the graph.""" + in_state_c = [_expr.var(in_state_c_name, + shape=(num_layers, batch_size, num_hidden), + dtype='float32')] + + in_state_h = [_expr.var(in_state_h_name, + shape=(num_layers, batch_size, num_hidden), + dtype='float32')] + return in_state_c, in_state_h + + def _get_cur_input_state(in_state_c, in_state_h, num_layers, + layer, batch_size, num_hidden): + """Select the appropriate states for the current layer""" + in_state_c_tup = _op.split(in_state_c[0], + indices_or_sections=num_layers, axis=0) + in_state_h_tup = _op.split(in_state_h[0], + indices_or_sections=num_layers, axis=0) + cur_in_state_c = _op.reshape(in_state_c_tup[layer], + newshape=(batch_size, num_hidden)) + cur_in_state_h = _op.reshape(in_state_h_tup[layer], + newshape=(batch_size, num_hidden)) + return cur_in_state_c, cur_in_state_h + + def _LSTMBlockCellWrapper(inputs, attr, params, + num_layers, layer): + """LSTM cell warapper to prepare the inputs""" + input_shape = attr['_input_shapes'][inputs[0]] + weight_shape = attr['_input_shapes'][inputs[3]] + + batch_size = input_shape[0][0] + num_hidden = weight_shape[0][1] // 4 + + if layer == 0: + #Create initial states placeholder in case of first layer + in_state_c, in_state_h = _init_state(num_layers, + batch_size, num_hidden) + else: + in_state_c = self._nodes[in_state_c_name] + in_state_h = self._nodes[in_state_h_name] + + cur_in_state_c, cur_in_state_h = _get_cur_input_state( \ + in_state_c, in_state_h, + 
num_layers, layer, + batch_size, num_hidden) + output, out_state = self._convert_map[op_name](inputs, cur_in_state_c, + cur_in_state_h, + attr, params) + return output, out_state, in_state_c, in_state_h + + sym, cur_out_state, in_state_c, in_state_h = \ + _LSTMBlockCellWrapper(inputs, attrs, params, + num_layers, self._cur_lstm_layer) + self._nodes[in_state_c_name] = in_state_c + self._nodes[in_state_h_name] = in_state_h + cur_out_state = _op.expand_dims(cur_out_state, axis=0, num_newaxis=1) + self._out_rnn.append(cur_out_state) + self._cur_lstm_layer += 1 + return sym + return _impl + + def process_op(self, op_name, inputs, attrs, params): + """Process recurrent layer operators. + + List '_recurrent_ops_layer_map' map each Layer based operators with its + layer handlers. Total number of layers are calculated to form the input + data shapes. + + Parameters + ---------- + op_name : str + Operator name, such as LSTMBlockCell + + inputs : relay.Expr + Input data + + attrs : dict + Dict of operator attributes + + params : dict + List of pretrained weights and bias + + Returns + ------- + sym : relay.Expr + Returns relay.Expr + """ + def _get_abs_layer_name(node): + """Identify the layer name is already handled. Return the absolute name + """ + if not self._layer_name_list: + self._layer_name_list.append(node.name) + return node.name + + for _name in self._layer_name_list: + if _name in node.name: + abs_name = _name + else: + self._layer_name_list.append(node.name) + abs_name = node.name + return abs_name + + #Find number of layers of this same operator node in the graph + #and also read the inputs name for the current op. 
+ num_layers = 0 + for _, node in enumerate(self._graph.node): + if node.op == op_name: + layer_name = _get_abs_layer_name(node) + num_layers += 1 + + sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs, + params, num_layers) + return sym + +class GraphProto(object): + """ A helper class for handling relay graph copying from Tensorflow GraphDef. + Definition: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto + """ + def __init__(self): + self._nodes = {} + self._params = {} + self._output_shapes = {} + self._num_param = 0 + self._num_rnn_layer = False + + def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): + """Construct relay nodes from tensorflow graph definition - GraphDef. + + Follow the tensorflow graph definition to parse and convert it to Relay. + Some of the assumptions listed below. + + -> All Placeholders are considered as graph input. + -> All Const nodes are params. + -> Last node is assumed as graph output. + -> _output_shapes : Graph should be frozen with add_shapes=True. + Or user can pass input shape dictionaly optionally. + -> DecodeJpeg, ResizeBilinear: These are dummy operators. + Hence user should handle preprocessing outside. + -> CheckNumerics: No implementation as of now for this. + Just copies input to output. + + Parameters + ---------- + graph : tensorflow graph definition object + The loaded tensorflow GraphDef + + layout : target layout to be used (Optional) + NCHW only supported now to enable NHWC models on GPU. + + shape : Dictionary of input dimensions (Optional) + Graph level input shape dictionary. 
+ + Returns + ------- + sym : relay.op + The returned relay operator + params : dict + A dict of name: tvm.nd.array pairs, used as pretrained weights + """ + + try: + from tensorflow.python.framework import tensor_util + except ImportError as e: + raise ImportError( + "Unable to import tensorflow which is required {}".format(e)) + + missing_operators = self._parse_import_prerequisites(graph) + + if missing_operators: + raise NotImplementedError( \ + "The following operators are not implemented: {}".format(missing_operators)) + + # Parse the nodes to re-create TF graph using Relay operators. + for node in graph.node: + # Tensorflow doesn't have seperate list for params extraction. + # Operator name 'Const' is treated as a parameter to build params dict. + + input_shapes = {} + attr = self._parse_attr(node.attr) + + #Variable converted to Const will not have only value attr + if 'value' in attr and node.op == 'Const': + tensor_value = attr['value'] + self._output_shapes[node.name] = \ + [tensor_util.TensorShapeProtoToList( \ + tensor_value.tensor_shape)] + elif '_output_shapes' in attr: + self._output_shapes[node.name] = \ + [tensor_util.TensorShapeProtoToList(tshape) \ + for tshape in attr['_output_shapes']] + elif shape: + # Keep the list indexable to avoid key error. + # Actual value will be filled after node creation. 
+ self._output_shapes[node.name] = [None] + else: + raise NotImplementedError( \ + "Please freeze the graph with add_shapes=True") + + if node.op == "Placeholder": + self._output_shapes[node.name] = [shape[node.name]] + self._nodes[node.name] = [_expr.var(node.name, + shape=self._output_shapes[node.name][0], + dtype=attr['dtype'].name)] + + elif node.op == "Const": + # All Const nodes are Param nodes, lets parse + self._num_param += 1 + for key, value in node.attr.items(): + self._parse_param(key, value, node.name, shape) + if node.name not in self._nodes: + raise NotImplementedError( \ + "Const {} couldn't be converted to Param.".format(node.name)) + + attr = self._parse_attr(node.attr) + + else: + # Pass the parsed shapes instead + attr["_output_shapes"] = self._output_shapes[node.name] + + # Pass the node name too in attr + attr["_node_name"] = node.name + + # Pass the target layout + attr["_target_layout"] = layout + + #ToDo: Some of the tensorflow operators internaly maintain + #execution layers and its output name will the layer number along with + #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the + #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case, + #the digit has to be ignored. 
+ if ":" in node.input[0]: + in_name, _ = node.input[0].split(':') + node.input[0] = in_name + + # Fill shapes for all inputs in a list + inputs = [] + for i in node.input: + if i in self._nodes: + inputs.append(self._nodes[i][0]) + input_shapes[self._nodes[i][0]] = self._output_shapes[i] + attr['_input_shapes'] = input_shapes + + op = self._convert_operator(node.op, inputs, attr, graph) + + # Check is op is converted to param + if isinstance(op, np.ndarray): + self._params[node.name] = tvm.nd.array(op) + op = [_expr.var(node.name, + shape=self._params[node.name].shape, + dtype=self._params[node.name].dtype)] + + elif isinstance(op, (_expr.TupleWrapper, tuple, list)): + pass + elif isinstance(op, _expr.Expr): + op = [op] + else: + raise RuntimeError("unexpected type %s" % type(op)) + + self._nodes[node.name] = op + + # Infer shapes if passed explicitely + node_output = self._nodes[node.name] + out_type = ir_pass.infer_type(node_output[0]) + self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)] + + + out = [] + if outputs is None: + out = op + else: + out = [self._nodes[out_name][0] for out_name in outputs] + + #Add the RNN outputs also with 'head' nodes of the relay graph + if self._num_rnn_layer: + if len(self._out_rnn) == 1: + out.append(self._out_rnn[0]) + else: + out_rnn = _op.concatenate(self._out_rnn, axis=0) + out.append(out_rnn) + + out = out[0] if len(out) == 1 else _expr.Tuple(out) + func = _expr.Function(ir_pass.free_vars(out), out) + + return func, self._params + + def _parse_import_prerequisites(self, graph): + """ Calculate the named preconditions from TensorFlow `graph`. + Return prerequisites for parsing: + a. Set of operator names which don't have their mapping in TVM, i.e. 
+ which are not supported + """ + missing_operators = set() + for node in graph.node: + if node.op == "Placeholder": + pass + elif node.op == "Const": + pass + else: + if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]): + pass + else: + missing_operators.add(node.op) + + return missing_operators + + def _parse_param(self, key, value, name, shape): + try: + from tensorflow.python.framework import tensor_util + except ImportError as e: + raise ImportError( + "Unable to import tensorflow which is required {}".format(e)) + + if key == 'value': + np_array = tensor_util.MakeNdarray(value.tensor) + + if np_array.dtype == np.dtype(object): + # Object types are generally tensorflow DT_STRING (DecodeJpeg op). + # Just leave it as placeholder. + self._nodes[name] = [_expr.var(name, shape=shape[name], dtype='uint8')] + + return + + array_ndim = len(np_array.shape) + if array_ndim == 0: + new_array = np.empty([1], dtype=np_array.dtype) + new_array[0] = np_array + self._params[name] = tvm.nd.array(new_array) + else: + self._params[name] = tvm.nd.array(np_array) + + self._nodes[name] = [_expr.var(name, + shape=self._params[name].shape, + dtype=self._params[name].dtype)] + else: + if key != 'dtype' and key != '_output_shapes' and key != '_class': + raise NotImplementedError \ + ("Other attributes for a Const(param) Node {} ? .".format(key)) + + def _get_attr(self, buf): + """Returns the value of the attr of this buf with the given `name`. + + Args: + buf: attrvalue protobuf. + + Returns: + The value of the attr, as a Python object. + + Raises: + ValueError: If this op does not have an attr with the given `name`. + """ + fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] + + x = buf + + ret = [] + + try: + from tensorflow.python.framework import dtypes + except ImportError as e: + raise ImportError( + "Unable to import tensorflow which is required {}".format(e)) + + # Treat an empty oneof value as an empty list. 
+ if not x.WhichOneof("value"): + return ret + if x.HasField("list"): + for f in fields: + if getattr(x.list, f): + if f == "type": + ret += [dtypes.as_dtype(x) for x in list(getattr(x.list, f))] + else: + ret += list(getattr(x.list, f)) + else: + for f in fields: + if x.HasField(f): + if f == "type": + ret = dtypes.as_dtype(getattr(x, f)) + else: + ret = getattr(x, f) + return ret + + def _parse_attr(self, attr_proto): + """Convert a list of AttributeProto to a dict, with names as keys.""" + attrs = {} + for key, value in attr_proto.items(): + attrs[key] = self._get_attr(value) + + return attrs + + def _convert_rnn_operator(self, op_name, inputs, + attrs, params, graph, convert_map): + """Convert RNN and its variant operators to Relay operators. + This converter read the input states of each layers and + also maintain the output states of each layer in a list. + + Parameters + ---------- + op_name : str + Operator name, such as LSTMBlockCell + inputs : list of relay.Expr + List of input symbols. + attrs : dict + Dict of operator attributes + params : dict + List of pretrained weights and bias + graph : Tensorflow graph object + Graph is to find the number of upcoming same operator to + calculate the number of layers. + convert_map : dict + Dict of name : callable, where name is the op's name that + require conversion to relay, callable are functions which + take attrs and return (new_op_name, new_attrs) + + Returns + ------- + sym : relay.Expr + Converted relay.Expr + """ + if not self._num_rnn_layer: + self._out_rnn = [] + self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map) + self._num_rnn_layer = True + sym = self.rnn.process_op(op_name, inputs, attrs, params) + return sym + + def _convert_operator(self, op_name, inputs, attrs, + graph, identity_list=None, convert_map=None): + """Convert from Tensorflow operator to relay operator. 
+ The converter must specify conversions explicity for incompatible name, and + apply handlers to operator attributes. + + Parameters + ---------- + op_name : str + Operator name, such as Conv2D, AvgPool + inputs : list of relay.op + List of input symbols. + attrs : dict + Dict of operator attributes + identity_list : list + List of operators that don't require conversion + convert_map : dict + Dict of name : callable, where name is the op's name that + require conversion to relay, callable are functions which + take attrs and return (new_op_name, new_attrs) + + Returns + ------- + sym : relay.op + Converted relay operator + """ + identity_list = identity_list if identity_list else _identity_list + convert_map = convert_map if convert_map else _convert_map + convert_map_rnn = _convert_map_rnn + if op_name in identity_list: + sym = _get_relay_op(op_name)(*inputs, **attrs) + elif op_name in convert_map: + sym = convert_map[op_name](inputs, attrs, self._params) + elif op_name in convert_map_rnn: + sym = self._convert_rnn_operator(op_name, inputs, attrs, + self._params, graph, + convert_map_rnn) + else: + raise NotImplementedError("Operator {} not implemented.".format(op_name)) + return sym + + +def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): + """ Load tensorflow graph which is a python tensorflow graph object into relay. + The companion parameters will be handled automatically. 
+ + Parameters + ---------- + graph : GraphDef object + Tensorflow GraphDef + + Returns + ------- + sym : relay.op + Compatible relay operator + + params : dict of str to tvm.ndarray + Dict of converted parameters stored in tvm.ndarray format + """ + g = GraphProto() + sym, params = g.from_tensorflow(graph, layout, shape, outputs) + return sym, params diff --git a/nnvm/python/nnvm/testing/tf.py b/python/tvm/relay/testing/tf.py similarity index 100% rename from nnvm/python/nnvm/testing/tf.py rename to python/tvm/relay/testing/tf.py diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py new file mode 100644 index 000000000000..0db6952d837d --- /dev/null +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -0,0 +1,1113 @@ +# pylint: disable=import-self, invalid-name, unused-argument +""" +Tensorflow testcases +==================== +This article is a test script to test tensorflow operator with Relay. +""" +from __future__ import print_function +import numpy as np +import tvm +from tvm import relay +import tensorflow as tf +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import graph_util +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.ops import init_ops +from tensorflow.core.framework import graph_pb2 + +import tvm.relay.testing.tf as tf_testing + +####################################################################### +# Generic run functions for TVM & tensorflow +# ------------------------------------------ +def convert_to_list(x): + if not isinstance(x, list): + x = [x] + return x + +def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None): 
+ """ Generic function to compile on relay and execute on tvm """ + input_data = convert_to_list(input_data) + input_node = convert_to_list(input_node) + + layout = None + if target == "cuda": + layout = "NCHW" + target_host = 'llvm' + + if isinstance(input_data, list): + shape_dict = {} + dtype_dict = {} + for i, e in enumerate(input_node): + shape_dict[e] = input_data[i].shape + dtype_dict[e] = input_data[i].dtype + else: + shape_dict = {input_node: input_data.shape} + dtype_dict = {input_node: input_data.dtype} + + sym, params = relay.frontend.from_tensorflow(graph_def, + layout=layout, + shape=shape_dict, + outputs=out_names) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(sym, target, params=params) + + ctx = tvm.context(target, 0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + # set inputs + for i, e in enumerate(input_node): + m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + + m.set_input(**params) + # execute + m.run() + # get outputs + assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format( + out_names, num_output) + tvm_output_list = [] + for i in range(0, num_output): + tvm_output = m.get_output(i) + tvm_output_list.append(tvm_output.asnumpy()) + return tvm_output_list + +def run_tf_graph(sess, input_data, input_node, output_node): + """ Generic function to execute tensorflow """ + input_data = convert_to_list(input_data) + input_node = convert_to_list(input_node) + output_node = convert_to_list(output_node) + + tensor = [0] * len(output_node) + for i in range(len(output_node)): + tensor[i] = sess.graph.get_tensor_by_name(output_node[i]) + + input_dict = {} + for i, e in enumerate(input_node): + input_dict[e] = input_data[i] + + output_data = sess.run(tensor, input_dict) + return output_data + + +def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False): + """Generic function to generate 
and compare tensorflow and TVM output""" + + out_name = convert_to_list(out_name) + out_node = [0]*len(out_name) + for i in range(len(out_name)): + out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i] + + in_data = convert_to_list(in_data) + in_name = convert_to_list(in_name) + in_node = [0]*len(in_name) + for i in range(len(in_name)): + in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i] + + with tf.Session() as sess: + if init_global_variables: + sess.run(variables.global_variables_initializer()) + final_graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + out_node, + ) + tf_output = run_tf_graph(sess, in_data, in_name, out_name) + + for device in ["llvm", "cuda"]: + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + if no_gpu and device == 'cuda': + continue + + tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device) + # since the names from tensorflow and relay runs are not exactly same, + # first len(tf_output) will be compared + for i in range(len(tf_output)): + tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) + + sess.close() + +def is_gpu_available(): + from tensorflow.python.client import device_lib + local_device_protos = device_lib.list_local_devices() + gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU'] + if len(gpu_list) > 0: + print("Tensorflow GPU:", gpu_list) + return True + else: + return False + +####################################################################### +# Pooling +# ------- +def _test_pooling_iteration(input_shape, **kwargs): + """ One iteration of pool operation with given shapes and attributes """ + + x = -np.arange( + np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1 + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=input_shape, dtype='float32') + 
nn_ops.pool(in_data, **kwargs) + + if kwargs['pooling_type'] == 'MAX': + out_name = 'max_pool:0' + else: + out_name = 'avg_pool:0' + + compare_tf_with_tvm(x, 'Placeholder:0', out_name) + +def _test_pooling(input_shape, **kwargs): + _test_pooling_iteration(input_shape, **kwargs) + + if is_gpu_available(): + input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] + kwargs['data_layout'] = 'NCHW' + _test_pooling_iteration(input_shape, **kwargs) + +def test_forward_pooling(): + """ Pooling """ + + for pool_type in ['AVG', 'MAX']: + _test_pooling(input_shape=[2, 9, 10, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type=pool_type, + dilation_rate=[1, 1], + strides=[1, 1]) + + _test_pooling(input_shape=[2, 10, 9, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type=pool_type, + dilation_rate=[1, 1], + strides=[1, 1]) + + _test_pooling(input_shape=[2, 9, 10, 2], + window_shape=[2, 1], + padding='SAME', + pooling_type=pool_type, + dilation_rate=[1, 1], + strides=[1, 1]) + + _test_pooling(input_shape=[2, 10, 9, 2], + window_shape=[2, 3], + padding='SAME', + pooling_type=pool_type, + dilation_rate=[1, 1], + strides=[2, 1]) + +####################################################################### +# Convolution +# ----------- + +def _test_convolution(tensor_in_sizes, filter_in_sizes, + dilations, strides, padding, data_format): + """ One iteration of convolution with given shapes and attributes """ + + total_size_1 = 1 + total_size_2 = 1 + for s in tensor_in_sizes: + total_size_1 *= s + for s in filter_in_sizes: + total_size_2 *= s + # Initializes the input tensor with array containing incrementing + # numbers from 1. 
+ data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] + filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32') + in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32') + strides = [1] + strides + [1] + dilations = [1] + dilations + [1] + + nn_ops.conv2d(in_data, + in_filter, + strides=strides, + padding=padding, + data_format=data_format) + + compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), + 'Placeholder:0', 'Conv2D:0') + +def test_forward_convolution(): + if is_gpu_available(): + _test_convolution([4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW') + _test_convolution([4, 19, 17, 17], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NCHW') + _test_convolution([4, 124, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NCHW') + _test_convolution([4, 12, 17, 17], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NCHW') + + _test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') + _test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC') + _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC') + _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC') + +####################################################################### +# Reshape +# ------- + +def _test_reshape(data, out_shape): + """ One iteration of reshape operation with given data and out shape """ + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + array_ops.reshape(in_data, out_shape) + + compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0') + +def test_forward_reshape(): + _test_reshape(np.arange(6.0), [2, 3]) + _test_reshape(np.arange(6), [-1, 2]) + _test_reshape(np.arange(6), [3, -1]) + _test_reshape(np.arange(6), [-1]) + 
+####################################################################### +####################################################################### +# Squeeze +# ------- + +def _test_squeeze(data, squeeze_dims=None): + """ One iteration of squeeze """ + + if squeeze_dims is None: + squeeze_dims = [] + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + + if squeeze_dims: + array_ops.squeeze(in_data, squeeze_dims) + else: + array_ops.squeeze(in_data) + + compare_tf_with_tvm(data, 'Placeholder:0', 'Squeeze:0') + +def test_forward_squeeze(): + """ Squeeze """ + + # Nothing to squeeze. + _test_squeeze(np.arange(2).reshape((2))) + _test_squeeze(np.arange(6).reshape((2, 3))) + + # Squeeze the middle element away. + _test_squeeze(np.arange(4).reshape((2, 1, 2))) + + # Squeeze on both ends. + _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1))) + + # Positive squeeze dim index. + _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0]) + _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [2, 4]) + _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0, 4, 2]) + + # Negative squeeze dim index. 
+ _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-1]) + _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5]) + _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1]) + +####################################################################### +# ConcatV2 +# -------- + +def _test_concat_v2(data, dim): + """ One iteration of ConcatV2 """ + + with tf.Graph().as_default(): + gen_array_ops._concat_v2(data, dim) + + compare_tf_with_tvm(data, ['ConcatV2/values_0:0', 'ConcatV2/values_1:0'], + 'ConcatV2:0') + +def _test_forward_concat_v2(): + t1 = np.array([]) + t2 = np.array([]) + _test_concat_v2([t1, t2], 0) + + t1 = np.array([[1, 2, 3], [4, 5, 6]]) + t2 = np.array([[7, 8, 9], [10, 11, 12]]) + + _test_concat_v2([t1, t2], 1) + +####################################################################### +# Sigmoid +# ------- + +def _test_sigmoid(data): + """ One iteration of sigmoid """ + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + sigmoid_out = math_ops.sigmoid(in_data) + + compare_tf_with_tvm(data, 'Placeholder:0', 'Sigmoid:0') + +def test_forward_sigmoid(): + """ Sigmoid """ + + _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32')) + +####################################################################### +# Argmin/Argmax +# ------------- + +def _test_argx(func, data, **kwargs): + + with tf.Graph().as_default(): + inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0") + func(inp, name="argx0", output_type=tf.int32, **kwargs) + + compare_tf_with_tvm(data, 'c0:0', 'argx0:0') + +def test_forward_argminmax(): + for axis in [None,0,1,2]: + data = np.random.uniform(size=(8,4,9)).astype('float32') + _test_argx(tf.argmax, data=data, axis=axis) + _test_argx(tf.argmin, data=data, axis=axis) + +####################################################################### +# Reduce +# ------ + +def _test_reduce(func, data, **kwargs): + """ One iteration of a reduce 
def _test_variable(data):
    """Multiply `data` by a square variable and compare TensorFlow with TVM.

    Parameters
    ----------
    data : numpy.ndarray
        Input fed to the placeholder. The variable is created with shape
        ``[d, d]`` where ``d`` is the second dimension of the input.
    """
    tf.reset_default_graph()
    placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
    reshaped = array_ops.reshape(placeholder, data.shape)

    width = reshaped.shape.dims[1]
    with variable_scope.variable_scope("linear", reuse=None):
        weight = variable_scope.get_variable(
            "w", shape=[width, width], dtype=reshaped.dtype)

    # Created outside the "linear" scope: the output node is looked up
    # below under the unprefixed name 'MatMul'.
    math_ops.matmul(reshaped, weight)

    # The variable has to be initialised before the graph can be evaluated.
    compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0',
                        init_global_variables=True)
+ _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32') + _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8) + _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2) + _test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2) + _test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2) + _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5) + _test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4) + _test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=5) + _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2) + _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3) + _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3) + _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=2) + _test_stridedslice((3,4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2) + _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=2) + _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1, new_axis_mask=2) + _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=1) + _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], 'float32', shrink_axis_mask=5, new_axis_mask=1) + _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1], + 'float32', shrink_axis_mask=5, new_axis_mask=1, ellipsis_mask=2, begin_mask=8, end_mask=8) + _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1], + 
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
    """ One iteration of a Gather

    Parameters
    ----------
    ip_shape : tuple of int
        Shape of the data placeholder.
    indice_shape : tuple of int
        Shape of the int32 indices placeholder.
    indice_value : int or nested list of int
        Indices to gather; a bare int is wrapped into a one-element array.
    axis : int
        Axis passed to ``tf.gather``.
    dtype : str
        dtype of the data placeholder and of the generated input.
    """

    tf.reset_default_graph()
    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
    indices = tf.placeholder("int32", indice_shape, name="indices")
    tf.gather(in_data, indices, axis=axis)

    # Sample from [1, 10) rather than [0, 1): the latter truncates to all
    # zeros for integer dtypes, which would make every gather result equal
    # and hide indexing errors.
    np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)

    if isinstance(indice_value, int):
        np_indices = np.array([indice_value], dtype='int32')
    else:
        np_indices = np.asarray(indice_value, dtype='int32')

    compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0')
def test_forward_multi_output():
    """Check a graph with two output nodes (an add and a subtract)."""
    with tf.Graph().as_default():
        in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
        in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
        in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
        in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')

        tf.add(in1, in2, name='out1')
        tf.subtract(in3, in4, name='out2')
        in_data = np.arange(9, dtype='int32').reshape([3, 3])
        in_data = [in_data] * 4
        in_name = ['in1:0', 'in2:0', 'in3:0', 'in4:0']
        out_name = ['out1:0', 'out2:0']
        # Tensor name -> node name. str.strip(':0') would remove any leading
        # or trailing ':' / '0' characters (e.g. 'out10:0' -> 'out1'), so
        # split on the separator instead.
        out_node = [name.split(':')[0] for name in out_name]
        in_node = [name.split(':')[0] for name in in_name]

        with tf.Session() as sess:
            final_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph.as_graph_def(add_shapes=True), out_node,)
            tf_output = run_tf_graph(sess, in_data, in_name, out_name)
            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target='llvm',
                                       out_names=out_node, num_output=2)
            # zip() would silently drop a missing output, so check the count.
            assert len(tf_output) == len(tvm_output)
            for tf_out, tvm_out in zip(tf_output, tvm_output):
                tvm.testing.assert_allclose(tf_out, tvm_out, atol=1e-5, rtol=1e-5)
def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
    """ One iteration of a LSTM cell

    Parameters
    ----------
    batch_size : int
        Batch dimension of the input and of the cell states.
    num_hidden : int
        Hidden size; also used as the input size.
    num_layers : int
        Leading dimension of the state arrays handed to TVM.
    forget_bias : float
        Forwarded to ``tf.contrib.rnn.LSTMBlockCell``.
    dtype : str
        dtype of the placeholder and of the generated inputs.
    """

    tf.reset_default_graph()
    input_size = num_hidden
    input_data = np.full((batch_size, input_size), 1., dtype=dtype)
    in_state_c = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
    in_state_h = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)

    def _get_tensorflow_output():
        """Run the cell in TensorFlow; return (frozen graph_def, outputs)."""
        with tf.Session() as sess:
            with variable_scope.variable_scope(
                    "root", initializer=init_ops.constant_initializer(0.5)):
                m0 = array_ops.zeros([batch_size, num_hidden])
                m1 = array_ops.zeros([batch_size, num_hidden])
                x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
                g, ((out_m0, out_m1)) = \
                    tf.contrib.rnn.LSTMBlockCell(num_hidden,
                                                 forget_bias=forget_bias)(x, ((m0, m1)))
                sess.run([variables.global_variables_initializer()])
                # Feed the same array that is handed to TVM below. The
                # previous hard-coded [[1., 1.]] only matched the
                # placeholder for batch_size=1, num_hidden=2.
                res = sess.run([g, out_m0, out_m1], {
                    x.name: input_data,
                    m0.name: 0.1 * np.ones([batch_size, num_hidden]),
                    m1.name: 0.1 * np.ones([batch_size, num_hidden]),
                })
            graph_def = sess.graph.as_graph_def(add_shapes=True)
            final_graph_def = graph_util.convert_variables_to_constants(
                sess,
                graph_def,
                ['root/lstm_cell/LSTMBlockCell'])
            return final_graph_def, res

    graph_def, tf_out = _get_tensorflow_output()
    tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
                               ['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
                                'root/lstm_cell/LSTMBlockCell_h'], num_output=2)
    assert isinstance(tvm_output, list)

    out = tvm_output[0]
    out_state = tvm_output[1]
    # The state output is split in two along axis 1 and each half reshaped
    # to (batch_size, num_hidden). Only the cell output is compared below;
    # the split/reshape still fails loudly if the state shape is unexpected.
    # NOTE(review): the halves are presumably c and h -- confirm the order
    # before adding comparisons against tf_out[1] / tf_out[2].
    out_state_tup = np.split(out_state, indices_or_sections=2, axis=1)
    out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
    out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
    tvm_out = [out, out_state_c, out_state_h]
    tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3)
def test_forward_inception_v1():
    '''test inception V1 model'''
    with tf.Graph().as_default():
        graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
        # Call the utility to import the graph definition into default graph.
        graph_def = tf_testing.ProcessGraphDefParam(graph_def)

        # Build an image from random data.
        from PIL import Image
        from tvm.contrib import util

        # NOTE(review): uniform samples in [0, 1) truncate to 0 under the
        # uint8 cast, so this image is effectively all black -- confirm
        # that is intended.
        img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8")
        # tobytes() is the supported spelling of the deprecated tostring().
        img = Image.frombuffer('RGB', (600, 600), img_array.tobytes(), 'raw', 'RGB', 0, 1)
        temp = util.tempdir()
        img_path = temp.relpath("tf-test.jpg")
        img.save(img_path)

        if not tf.gfile.Exists(img_path):
            tf.logging.fatal('File does not exist %s', img_path)
        # Read the encoded JPEG bytes and close the handle right away.
        with tf.gfile.FastGFile(img_path, 'rb') as img_file:
            data = img_file.read()

        temp.remove()

        # Extract tensorflow decoded image frame for tvm input
        with tf.Session() as sess:
            tvm_data = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')

        with tf.Session() as sess:
            tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
            tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents')
            tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
def test_forward_resnetv2():
    '''test resnet model'''
    # The whole test is a no-op unless a GPU is reported as available.
    if not is_gpu_available():
        return

    with tf.Graph().as_default():
        workload = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
        # Call the utility to import the graph definition into default graph.
        graph_def = tf_testing.ProcessGraphDefParam(workload)

        batch = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
        out_node = 'ArgMax'

        with tf.Session() as sess:
            tf_output = run_tf_graph(sess, batch, 'input_tensor:0', out_node + ':0')
            # NOTE(review): other callers in this file index the result of
            # run_tf_graph (tf_output[0]) and pass num_output/target to
            # run_tvm_graph by keyword; confirm that `.shape` on tf_output
            # and the two extra positional arguments still match the
            # helpers' signatures.
            tvm_output = run_tvm_graph(graph_def, batch, 'input_tensor', tf_output.shape, 'float32')
            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]),
                                        np.squeeze(tf_output[0]),
                                        rtol=1e-5, atol=1e-5)
num_steps), + 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden), + 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)} + + sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict) + + dtype_dict = {'Model/Placeholder': 'int32', + 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32', + 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'} + target = 'llvm' + with relay.build_config(opt_level=0): + graph, lib, params = relay.build(sym, target, params=params) + from tvm.contrib import graph_runtime + ctx = tvm.cpu(0) + return params, graph_runtime.create(graph, lib, ctx) + + def _do_tvm_sample(model, data, in_states, params, num_samples): + """Sampled from the model""" + samples = [] + state = in_states + sample = None + def _get_sample(data, state): + input_data = np.full((batch_size, num_steps), data, dtype="int32") + in_state_tup = np.split(state, indices_or_sections=2, axis=1) + in_state_c = np.reshape(in_state_tup[0], (num_layers, batch_size, num_hidden)) + in_state_h = np.reshape(in_state_tup[1], (num_layers, batch_size, num_hidden)) + + model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32"))) + model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c', + tvm.nd.array(in_state_c.astype("float32"))) + model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h', + tvm.nd.array(in_state_h.astype("float32"))) + model.set_input(**params) + model.run() + tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape, + "float32")).asnumpy() + state_output = model.get_output(1, tvm.nd.empty(out_state_shape, + "float32")).asnumpy() + sample = tf_testing.pick_from_weight(tvm_output[0]) + + return sample, state_output + + for x in data: + sample, state = _get_sample(x, state) + + if sample is not None: + samples.append(sample) + else: + 
def _test_lrn(ishape, size, axis, bias, alpha, beta):
    """ testing local response normalization

    Parameters
    ----------
    ishape : tuple of int
        Shape of the random float32 input.
    size : int
        Window size; half of it (rounded down) is used as depth_radius.
    axis : int
        Not used by this helper; kept so existing callers keep working.
    bias, alpha, beta : float
        Forwarded to ``nn_ops.local_response_normalization``.
    """
    # Floor division keeps depth_radius an int on Python 3, where `/`
    # would produce a float (e.g. 3 / 2 == 1.5).
    lrn_depth_radius = size // 2

    inp_array = np.random.uniform(size=ishape).astype(np.float32)

    with tf.Graph().as_default():
        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
        nn_ops.local_response_normalization(in1,
                                            name="lrn",
                                            depth_radius=lrn_depth_radius,
                                            bias=bias,
                                            alpha=alpha,
                                            beta=beta)

        compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0')
def _test_forward_transpose(ishape, axes=None):
    """Transpose a random float32 tensor and compare TensorFlow with TVM.

    Parameters
    ----------
    ishape : tuple of int
        Shape of the random input.
    axes : tuple of int, optional
        Permutation passed as ``perm``; when ``None``, ``tf.transpose`` is
        called without a ``perm`` argument.
    """
    sample = np.random.uniform(size=ishape).astype(np.float32)

    # Only forward `perm` when the caller supplied a permutation.
    perm_kwargs = {} if axes is None else {'perm': axes}

    with tf.Graph().as_default():
        placeholder = tf.placeholder(shape=sample.shape, dtype=sample.dtype,
                                     name="transpose_data")
        tf.transpose(placeholder, **perm_kwargs)

        compare_tf_with_tvm(sample, 'transpose_data:0', 'transpose:0')
def test_forward_relu():
    """Compare tf.nn.relu between TensorFlow and TVM on random input."""
    # Samples are drawn from [-5, 5) so both signs are present.
    sample = np.random.uniform(-5, 5, size=(1, 3, 10, 10)).astype(np.float32)
    with tf.Graph().as_default():
        placeholder = tf.placeholder(shape=sample.shape, dtype=sample.dtype)
        tf.nn.relu(placeholder)
        compare_tf_with_tvm(sample, 'Placeholder:0', 'Relu:0')
def _test_forward_rel_op(data, func):
    """Apply a binary relational op to two inputs and compare TF with TVM.

    Parameters
    ----------
    data : sequence of two numpy.ndarray
        Left and right operands; each defines its placeholder's shape and
        dtype.
    func : callable
        Op builder called as ``func(in1, in2, name='op')``.
    """
    lhs, rhs = data[0], data[1]
    with tf.Graph().as_default():
        in1 = tf.placeholder(shape=lhs.shape, dtype=lhs.dtype, name='in1')
        in2 = tf.placeholder(shape=rhs.shape, dtype=rhs.dtype, name='in2')
        # The op result is cast to int32; the cast node 'out1' is the one
        # that gets compared.
        tf.cast(func(in1, in2, name='op'), tf.int32, name='out1')
        compare_tf_with_tvm([lhs, rhs], ['in1:0', 'in2:0'], 'out1:0')
test_forward_inception_v1() + test_forward_mobilenet() + test_forward_resnetv2() + test_forward_ptb() + + # RNN + test_forward_lstm() + + # Elementwise + test_forward_ceil() + test_forward_floor() + + # Relational ops + test_forward_rel_ops() + + diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index a929d4e33905..3c048435fba8 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -16,7 +16,7 @@ from tensorflow.python.ops import variables from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper -import nnvm.testing.tf +import tvm.relay.testing.tf as tf_testing ####################################################################### # Generic run functions for TVM & TFLite @@ -344,7 +344,7 @@ def test_forward_mobilenet(): '''test mobilenet v1 tflite model''' # MobilenetV1 temp = util.tempdir() - tflite_model_file = nnvm.testing.tf.get_workload_official( + tflite_model_file = tf_testing.get_workload_official( "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz", "mobilenet_v1_1.0_224.tflite", temp) tflite_model_buf = open(tflite_model_file, "rb").read() diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 880c35ee42e0..b4802da1c42a 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -42,6 +42,9 @@ python3 -m nose -v tests/python/frontend/onnx || exit -1 echo "Running relay CoreML frondend test..." python3 -m nose -v tests/python/frontend/coreml || exit -1 +echo "Running relay Tensorflow frontend test..." +python3 -m nose -v tests/python/frontend/tensorflow || exit -1 + echo "Running nnvm to relay frontend test..." python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1 @@ -50,4 +53,3 @@ python3 -m nose -v tests/python/frontend/tflite || exit -1 echo "Running relay caffe2 frondend test..." 
python3 -m nose -v tests/python/frontend/caffe2 || exit -1 - diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py new file mode 100644 index 000000000000..1f76db890ade --- /dev/null +++ b/tutorials/frontend/from_tensorflow.py @@ -0,0 +1,214 @@ +""" +Compile Tensorflow Models +========================= +This article is an introductory tutorial to deploy tensorflow models with TVM. + +For us to begin with, tensorflow python module is required to be installed. + +Please refer to https://www.tensorflow.org/install +""" + +# tvm, relay +import tvm +from tvm import relay + +# os and numpy +import numpy as np +import os.path + +# Tensorflow imports +import tensorflow as tf + +# Tensorflow utility functions +import tvm.relay.testing.tf as tf_testing + +# Base location for model related files. +repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' + +# Test image +img_name = 'elephant-299.jpg' +image_url = os.path.join(repo_base, img_name) + +###################################################################### +# Tutorials +# --------- +# Please refer docs/frontend/tensorflow.md for more details for various models +# from tensorflow. + +model_name = 'classify_image_graph_def-with_shapes.pb' +model_url = os.path.join(repo_base, model_name) + +# Image label map +map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt' +map_proto_url = os.path.join(repo_base, map_proto) + +# Human readable text for labels +lable_map = 'imagenet_synset_to_human_label_map.txt' +lable_map_url = os.path.join(repo_base, lable_map) + +# Target settings +# Use these commented settings to build for cuda. +#target = 'cuda' +#target_host = 'llvm' +#layout = "NCHW" +#ctx = tvm.gpu(0) +target = 'llvm' +target_host = 'llvm' +layout = None +ctx = tvm.cpu(0) + +###################################################################### +# Download required files +# ----------------------- +# Download files listed above. 
+from mxnet.gluon.utils import download + +download(image_url, img_name) +download(model_url, model_name) +download(map_proto_url, map_proto) +download(lable_map_url, lable_map) + +###################################################################### +# Import model +# ------------ +# Creates tensorflow graph definition from protobuf file. + +with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + graph = tf.import_graph_def(graph_def, name='') + # Call the utility to import the graph definition into default graph. + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + # Add shapes to the graph. + with tf.Session() as sess: + graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax') + +###################################################################### +# Decode image +# ------------ +# .. note:: +# +# tensorflow frontend import doesn't support preprocessing ops like JpegDecode. +# JpegDecode is bypassed (just return source node). +# Hence we supply decoded frame to TVM instead. +# + +from PIL import Image +image = Image.open(img_name).resize((299, 299)) + +x = np.array(image) + +###################################################################### +# Import the graph to Relay +# ------------------------- +# Import tensorflow graph definition to relay frontend. +# +# Results: +# sym: relay expr for given tensorflow protobuf. +# params: params converted from tensorflow params (tensor protobuf). +shape_dict = {'DecodeJpeg/contents': x.shape} +dtype_dict = {'DecodeJpeg/contents': 'uint8'} +sym, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict) + +print ("Tensorflow protobuf imported to relay frontend.") +###################################################################### +# Relay Build +# ----------- +# Compile the graph to llvm target with given input specification. +# +# Results: +# graph: Final graph after compilation. 
+# params: final params after compilation. +# lib: target library which can be deployed on target with tvm runtime. + +with relay.build_config(opt_level=3): + graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params) + +###################################################################### +# Execute the portable graph on TVM +# --------------------------------- +# Now we can try deploying the compiled model on target. + +from tvm.contrib import graph_runtime +dtype = 'uint8' +m = graph_runtime.create(graph, lib, ctx) +# set inputs +m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype))) +m.set_input(**params) +# execute +m.run() +# get outputs +tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32')) + +###################################################################### +# Process the output +# ------------------ +# Process the model output to human readable text for InceptionV1. +predictions = tvm_output.asnumpy() +predictions = np.squeeze(predictions) + +# Creates node ID --> English string lookup. +node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto), + uid_lookup_path=os.path.join("./", lable_map)) + +# Print top 5 predictions from TVM output. +top_k = predictions.argsort()[-5:][::-1] +for node_id in top_k: + human_string = node_lookup.id_to_string(node_id) + score = predictions[node_id] + print('%s (score = %.5f)' % (human_string, score)) + +###################################################################### +# Inference on tensorflow +# ----------------------- +# Run the corresponding model on tensorflow + +def create_graph(): + """Creates a graph from saved GraphDef file; returns nothing.""" + # Creates graph from saved graph_def.pb. 
+ with tf.gfile.FastGFile(model_name, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + graph = tf.import_graph_def(graph_def, name='') + # Call the utility to import the graph definition into default graph. + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + +def run_inference_on_image(image): + """Runs inference on an image. + + Parameters + ---------- + image: String + Image file name. + + Returns + ------- + Nothing + """ + if not tf.gfile.Exists(image): + tf.logging.fatal('File does not exist %s', image) + image_data = tf.gfile.FastGFile(image, 'rb').read() + + # Creates graph from saved GraphDef. + create_graph() + + with tf.Session() as sess: + softmax_tensor = sess.graph.get_tensor_by_name('softmax:0') + predictions = sess.run(softmax_tensor, + {'DecodeJpeg/contents:0': image_data}) + + predictions = np.squeeze(predictions) + + # Creates node ID --> English string lookup. + node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto), + uid_lookup_path=os.path.join("./", lable_map)) + + # Print top 5 predictions from tensorflow. + top_k = predictions.argsort()[-5:][::-1] + print ("===== TENSORFLOW RESULTS =======") + for node_id in top_k: + human_string = node_lookup.id_to_string(node_id) + score = predictions[node_id] + print('%s (score = %.5f)' % (human_string, score)) + +run_inference_on_image (img_name) diff --git a/tutorials/nnvm/from_tensorflow.py b/tutorials/nnvm/from_tensorflow.py index 92c287e4ade7..ac632b122e76 100644 --- a/tutorials/nnvm/from_tensorflow.py +++ b/tutorials/nnvm/from_tensorflow.py @@ -23,7 +23,7 @@ from tensorflow.python.framework import tensor_util # Tensorflow utility functions -import nnvm.testing.tf +import tvm.relay.testing.tf as tf_testing # Base location for model related files. 
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' @@ -87,10 +87,10 @@ graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') # Call the utility to import the graph definition into default graph. - graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) # Add shapes to the graph. with tf.Session() as sess: - graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, 'softmax') + graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax') ###################################################################### # Decode image @@ -157,7 +157,7 @@ predictions = np.squeeze(predictions) # Creates node ID --> English string lookup. -node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto), +node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto), uid_lookup_path=os.path.join("./", lable_map)) # Print top 5 predictions from TVM output. @@ -180,7 +180,7 @@ def create_graph(): graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') # Call the utility to import the graph definition into default graph. - graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) def run_inference_on_image(image): """Runs inference on an image. @@ -209,7 +209,7 @@ def run_inference_on_image(image): predictions = np.squeeze(predictions) # Creates node ID --> English string lookup. - node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto), + node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto), uid_lookup_path=os.path.join("./", lable_map)) # Print top 5 predictions from tensorflow.