diff --git a/example/multi-task/README.md b/example/multi-task/README.md
index 9034814c3b50..b7756fe378a7 100644
--- a/example/multi-task/README.md
+++ b/example/multi-task/README.md
@@ -1,10 +1,34 @@
 # Multi-task learning example
-This is a simple example to show how to use mxnet for multi-task learning. It uses MNIST as an example and mocks up the multi-label task.
+This is a simple example to show how to use mxnet for multi-task learning. It uses MNIST as an example and jointly predicts the digit and whether that digit is odd or even.
 
-## Usage
-First, you need to write a multi-task iterator on your own. The iterator needs to generate multiple labels according to your applications, and the label names should be specified in the `provide_label` function, which needs to be consist with the names of output layers.
+For example:
 
-Then, if you want to show metrics of different tasks separately, you need to write your own metric class and specify the `num` parameter. In the `update` function of metric, calculate the metrics separately for different tasks.
+![](https://camo.githubusercontent.com/ed3cf256f47713335dc288f32f9b0b60bf1028b7/68747470733a2f2f7777772e636c61737365732e63732e756368696361676f2e6564752f617263686976652f323031332f737072696e672f31323330302d312f70612f7061312f64696769742e706e67)
 
-The example script uses gpu as device by default, if gpu is not available for your environment, you can change `device` to be `mx.cpu()`.
+The digit above should be jointly classified as 4 and even.
+
+In this example we don't expect the two tasks to contribute much to each other, but multi-task learning has been applied successfully to other domains, such as image captioning. In [A Multi-task Learning Approach for Image Captioning](https://www.ijcai.org/proceedings/2018/0168.pdf), Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo, and Yu Qiao train a network to jointly classify images and generate text captions.
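+
+The gist of the approach is a single shared featurizer with one output head per task. Below is a minimal sketch of the idea using the Gluon API (names here are illustrative; the notebook's actual network uses a few more shared layers):
+
+```python
+import mxnet as mx
+from mxnet import gluon
+
+class MultiTaskNetwork(gluon.HybridBlock):
+    def __init__(self):
+        super(MultiTaskNetwork, self).__init__()
+        # Layers shared by both tasks
+        self.shared = gluon.nn.Dense(64, activation='relu')
+        # One output head per task
+        self.digit_head = gluon.nn.Dense(10)  # 10-way digit classification
+        self.parity_head = gluon.nn.Dense(1)  # binary odd/even
+
+    def hybrid_forward(self, F, x):
+        y = self.shared(x)
+        return self.digit_head(y), self.parity_head(y)
+```
+
+Each head gets its own loss, and the two are combined into a single objective, `l_combined = (1 - alpha) * l_digits + alpha * l_odd_even`, with `alpha` in `[0, 1]` balancing the tasks.
+
+Please refer to the notebook for a fully worked example.
diff --git a/example/multi-task/example_multi_task.py b/example/multi-task/example_multi_task.py
deleted file mode 100644
index 9e898494a14b..000000000000
--- a/example/multi-task/example_multi_task.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.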
-
-# pylint: skip-file
-import mxnet as mx
-from mxnet.test_utils import get_mnist_iterator
-import numpy as np
-import logging
-import time
-
-logging.basicConfig(level=logging.DEBUG)
-
-def build_network():
-    data = mx.symbol.Variable('data')
-    fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-    fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-    fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
-    sm1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax1')
-    sm2 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax2')
-
-    softmax = mx.symbol.Group([sm1, sm2])
-
-    return softmax
-
-class Multi_mnist_iterator(mx.io.DataIter):
-    '''multi label mnist iterator'''
-
-    def __init__(self, data_iter):
-        super(Multi_mnist_iterator, self).__init__()
-        self.data_iter = data_iter
-        self.batch_size = self.data_iter.batch_size
-
-    @property
-    def provide_data(self):
-        return self.data_iter.provide_data
-
-    @property
-    def provide_label(self):
-        provide_label = self.data_iter.provide_label[0]
-        # Different labels should be used here for actual application
-        return [('softmax1_label', provide_label[1]), \
-                ('softmax2_label', provide_label[1])]
-
-    def hard_reset(self):
-        self.data_iter.hard_reset()
-
-    def reset(self):
-        self.data_iter.reset()
-
-    def next(self):
-        batch = self.data_iter.next()
-        label = batch.label[0]
-
-        return mx.io.DataBatch(data=batch.data, label=[label, label], \
-                pad=batch.pad, index=batch.index)
-
-class Multi_Accuracy(mx.metric.EvalMetric):
-    """Calculate accuracies of multi label"""
-
-    def __init__(self, num=None):
-        self.num = num
-        super(Multi_Accuracy, self).__init__('multi-accuracy')
-
-    def reset(self):
-        """Resets the internal evaluation result to initial state."""
-        self.num_inst = 0 if self.num is None else [0] * self.num
-        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num
-
-    def update(self, labels, preds):
-        mx.metric.check_label_shapes(labels, preds)
-
-        if self.num is not None:
-            assert len(labels) == self.num
-
-        for i in range(len(labels)):
-            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
-            label = labels[i].asnumpy().astype('int32')
-
-            mx.metric.check_label_shapes(label, pred_label)
-
-            if self.num is None:
-                self.sum_metric += (pred_label.flat == label.flat).sum()
-                self.num_inst += len(pred_label.flat)
-            else:
-                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
-                self.num_inst[i] += len(pred_label.flat)
-
-    def get(self):
-        """Gets the current evaluation result.
-
-        Returns
-        -------
-        names : list of str
-           Name of the metrics.
-        values : list of float
-           Value of the evaluations.
-        """
-        if self.num is None:
-            return super(Multi_Accuracy, self).get()
-        else:
-            return zip(*(('%s-task%d'%(self.name, i), float('nan') if self.num_inst[i] == 0
-                          else self.sum_metric[i] / self.num_inst[i])
-                          for i in range(self.num)))
-
-    def get_name_value(self):
-        """Returns zipped name and value pairs.
-
-        Returns
-        -------
-        list of tuples
-            A (name, value) tuple list.
- """ - if self.num is None: - return super(Multi_Accuracy, self).get_name_value() - name, value = self.get() - return list(zip(name, value)) - - -batch_size=100 -num_epochs=100 -device = mx.gpu(0) -lr = 0.01 - -network = build_network() -train, val = get_mnist_iterator(batch_size=batch_size, input_shape = (784,)) -train = Multi_mnist_iterator(train) -val = Multi_mnist_iterator(val) - - -model = mx.mod.Module( - context = device, - symbol = network, - label_names = ('softmax1_label', 'softmax2_label')) - -model.fit( - train_data = train, - eval_data = val, - eval_metric = Multi_Accuracy(num=2), - num_epoch = num_epochs, - optimizer_params = (('learning_rate', lr), ('momentum', 0.9), ('wd', 0.00001)), - initializer = mx.init.Xavier(factor_type="in", magnitude=2.34), - batch_end_callback = mx.callback.Speedometer(batch_size, 50)) - diff --git a/example/multi-task/multi-task-learning.ipynb b/example/multi-task/multi-task-learning.ipynb new file mode 100644 index 000000000000..6e03e2b61f8c --- /dev/null +++ b/example/multi-task/multi-task-learning.ipynb @@ -0,0 +1,454 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multi-Task Learning Example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is a simple example to show how to use mxnet for multi-task learning.\n", + "\n", + "The network is jointly going to learn whether a number is odd or even and to actually recognize the digit.\n", + "\n", + "\n", + "For example\n", + "\n", + "- 1 : 1 and odd\n", + "- 2 : 2 and even\n", + "- 3 : 3 and odd\n", + "\n", + "etc\n", + "\n", + "In this example we don't expect the tasks to contribute to each other much, but for example multi-task learning has been successfully applied to the domain of image captioning. In [A Multi-task Learning Approach for Image Captioning](https://www.ijcai.org/proceedings/2018/0168.pdf) by Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo, Yu Qiao, they train a network to jointly classify images and generate text captions" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import random\n", + "import time\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import mxnet as mx\n", + "from mxnet import gluon, nd, autograd\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 128\n", + "epochs = 5\n", + "ctx = mx.gpu() if len(mx.test_utils.list_gpus()) > 0 else mx.cpu()\n", + "lr = 0.01" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data\n", + "\n", + "We get the traditionnal MNIST dataset and add a new label to the existing one. 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "![](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dataset = gluon.data.vision.MNIST(train=True)\n",
+    "test_dataset = gluon.data.vision.MNIST(train=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def transform(x,y):\n",
+    "    x = x.transpose((2,0,1)).astype('float32')/255.\n",
+    "    y1 = y\n",
+    "    y2 = y % 2  # odd or even\n",
+    "    return x, np.float32(y1), np.float32(y2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We assign the transform to the original dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dataset_t = train_dataset.transform(transform)\n",
+    "test_dataset_t = test_dataset.transform(transform)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We load the datasets into DataLoaders"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_data = gluon.data.DataLoader(train_dataset_t, shuffle=True, last_batch='rollover', batch_size=batch_size, num_workers=5)\n",
+    "test_data = gluon.data.DataLoader(test_dataset_t, shuffle=False, last_batch='rollover', batch_size=batch_size, num_workers=5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Input shape: (28, 28, 1), Target Labels: (5.0, 1.0)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Input shape: {}, Target Labels: {}\".format(train_dataset[0][0].shape, train_dataset_t[0][1:]))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Multi-task Network\n",
+    "\n",
+    "The output of the featurization is passed to two different output layers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 135,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MultiTaskNetwork(gluon.HybridBlock):\n",
+    "    \n",
+    "    def __init__(self):\n",
+    "        super(MultiTaskNetwork, self).__init__()\n",
+    "        \n",
+    "        self.shared = gluon.nn.HybridSequential()\n",
+    "        with self.shared.name_scope():\n",
+    "            self.shared.add(\n",
+    "                gluon.nn.Dense(128, activation='relu'),\n",
+    "                gluon.nn.Dense(64, activation='relu'),\n",
+    "                gluon.nn.Dense(10, activation='relu')\n",
+    "            )\n",
+    "        self.output1 = gluon.nn.Dense(10)  # Digit recognition\n",
+    "        self.output2 = gluon.nn.Dense(1)  # odd or even\n",
+    "\n",
+    "    \n",
+    "    def hybrid_forward(self, F, x):\n",
+    "        y = self.shared(x)\n",
+    "        output1 = self.output1(y)\n",
+    "        output2 = self.output2(y)\n",
+    "        return output1, output2"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can use two different losses, one for each output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 136,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_digits = gluon.loss.SoftmaxCELoss()\n",
+    "loss_odd_even = gluon.loss.SigmoidBCELoss()"
+   ]
+  },
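+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (illustrative only, with made-up logits and labels), we can call the two loss functions on dummy data to see the per-sample loss shapes they return:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative only: dummy logits and labels to check the loss signatures\n",
+    "dummy_digit_logits = nd.random.randn(4, 10)  # batch of 4, 10 digit classes\n",
+    "dummy_digit_labels = nd.array([0, 3, 7, 9])  # digit class indices\n",
+    "dummy_parity_logits = nd.random.randn(4, 1)  # one logit per sample\n",
+    "dummy_parity_labels = nd.array([0, 1, 1, 1]).reshape(-1, 1)\n",
+    "\n",
+    "print(loss_digits(dummy_digit_logits, dummy_digit_labels).shape)  # (4,)\n",
+    "print(loss_odd_even(dummy_parity_logits, dummy_parity_labels).shape)  # (4,)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We create and initialize the network"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 137,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mx.random.seed(42)\n",
+    "random.seed(42)"
+   ]
+  },
+  {
+   "cell_type": "code",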
+ "execution_count": 138, + "metadata": {}, + "outputs": [], + "source": [ + "net = MultiTaskNetwork()" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [], + "source": [ + "net.initialize(mx.init.Xavier(), ctx=ctx)\n", + "net.hybridize() # hybridize for speed" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':lr})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate Accuracy\n", + "We need to evaluate the accuracy of each task separately" + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_accuracy(net, data_iterator):\n", + " acc_digits = mx.metric.Accuracy(name='digits')\n", + " acc_odd_even = mx.metric.Accuracy(name='odd_even')\n", + " \n", + " for i, (data, label_digit, label_odd_even) in enumerate(data_iterator):\n", + " data = data.as_in_context(ctx)\n", + " label_digit = label_digit.as_in_context(ctx)\n", + " label_odd_even = label_odd_even.as_in_context(ctx).reshape(-1,1)\n", + "\n", + " output_digit, output_odd_even = net(data)\n", + " \n", + " acc_digits.update(label_digit, output_digit.softmax())\n", + " acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 0.5)\n", + " return acc_digits.get(), acc_odd_even.get()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Loop" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We need to balance the contribution of each loss to the overall training and do so by tuning this alpha parameter within [0,1]." + ] + }, + { + "cell_type": "code", + "execution_count": 142, + "metadata": {}, + "outputs": [], + "source": [ + "alpha = 0.5 # Combine losses factor" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [0], Acc Digits 0.8945 Loss Digits 0.3409\n", + "Epoch [0], Acc Odd/Even 0.9561 Loss Odd/Even 0.1152\n", + "Epoch [0], Testing Accuracies (('digits', 0.9487179487179487), ('odd_even', 0.9770633012820513))\n", + "Epoch [1], Acc Digits 0.9576 Loss Digits 0.1475\n", + "Epoch [1], Acc Odd/Even 0.9804 Loss Odd/Even 0.0559\n", + "Epoch [1], Testing Accuracies (('digits', 0.9642427884615384), ('odd_even', 0.9826722756410257))\n", + "Epoch [2], Acc Digits 0.9681 Loss Digits 0.1124\n", + "Epoch [2], Acc Odd/Even 0.9852 Loss Odd/Even 0.0418\n", + "Epoch [2], Testing Accuracies (('digits', 0.9580328525641025), ('odd_even', 0.9846754807692307))\n", + "Epoch [3], Acc Digits 0.9734 Loss Digits 0.0961\n", + "Epoch [3], Acc Odd/Even 0.9884 Loss Odd/Even 0.0340\n", + "Epoch [3], Testing Accuracies (('digits', 0.9670472756410257), ('odd_even', 0.9839743589743589))\n", + "Epoch [4], Acc Digits 0.9762 Loss Digits 0.0848\n", + "Epoch [4], Acc Odd/Even 0.9894 Loss Odd/Even 0.0310\n", + "Epoch [4], Testing Accuracies (('digits', 0.9652887658227848), ('odd_even', 0.9858583860759493))\n" + ] + } + ], + "source": [ + "for e in range(epochs):\n", + " # Accuracies for each task\n", + " acc_digits = mx.metric.Accuracy(name='digits')\n", + " acc_odd_even = mx.metric.Accuracy(name='odd_even')\n", + " # Accumulative losses\n", + " l_digits_ = 0.\n", + " l_odd_even_ = 0. 
\n", + " \n", + " for i, (data, label_digit, label_odd_even) in enumerate(train_data):\n", + " data = data.as_in_context(ctx)\n", + " label_digit = label_digit.as_in_context(ctx)\n", + " label_odd_even = label_odd_even.as_in_context(ctx).reshape(-1,1)\n", + " \n", + " with autograd.record():\n", + " output_digit, output_odd_even = net(data)\n", + " l_digits = loss_digits(output_digit, label_digit)\n", + " l_odd_even = loss_odd_even(output_odd_even, label_odd_even)\n", + "\n", + " # Combine the loss of each task\n", + " l_combined = (1-alpha)*l_digits + alpha*l_odd_even\n", + " \n", + " l_combined.backward()\n", + " trainer.step(data.shape[0])\n", + " \n", + " l_digits_ += l_digits.mean()\n", + " l_odd_even_ += l_odd_even.mean()\n", + " acc_digits.update(label_digit, output_digit.softmax())\n", + " acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 0.5)\n", + " \n", + " print(\"Epoch [{}], Acc Digits {:.4f} Loss Digits {:.4f}\".format(\n", + " e, acc_digits.get()[1], l_digits_.asscalar()/(i+1)))\n", + " print(\"Epoch [{}], Acc Odd/Even {:.4f} Loss Odd/Even {:.4f}\".format(\n", + " e, acc_odd_even.get()[1], l_odd_even_.asscalar()/(i+1)))\n", + " print(\"Epoch [{}], Testing Accuracies {}\".format(e, evaluate_accuracy(net, test_data)))\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing" + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "metadata": {}, + "outputs": [], + "source": [ + "def get_random_data():\n", + " idx = random.randint(0, len(test_dataset))\n", + "\n", + " img = test_dataset[idx][0]\n", + " data, _, _ = test_dataset_t[idx]\n", + " data = data.as_in_context(ctx).expand_dims(axis=0)\n", + "\n", + " plt.imshow(img.squeeze().asnumpy(), cmap='gray')\n", + " \n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": 152, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicted digit: [9.], odd: [1.]\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADeVJREFUeJzt3X+MFPX9x/HXG6QGAQ3aiBdLpd9Ga6pBak5joqk01caaRuAfUhMbjE2viTUpEVFCNT31Dxu1rdWYJldLCk2/QhUb+KPWWuKP1jQNIKiotFJC00OEkjNBEiNyvPvHzdlTbz6zzs7uzPF+PpLL7e57Z+ad5V7M7H5m9mPuLgDxTKq7AQD1IPxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4I6oZsbMzNOJwQ6zN2tlee1tec3s6vM7O9mtsvMVrSzLgDdZWXP7TezyZL+IelKSYOSNku61t1fSyzDnh/osG7s+S+WtMvdd7v7EUlrJS1oY30Auqid8J8p6d9j7g9mj32ImfWZ2RYz29LGtgBUrOMf+Ln7gKQBicN+oEna2fPvlTR7zP3PZI8BmADaCf9mSWeb2efM7FOSvilpYzVtAei00of97n7UzG6S9JSkyZJWufurlXUGoKNKD/WV2hjv+YGO68pJPgAmLsIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCKj1FtySZ2R5J70galnTU3XuraApA57UV/sxX3P1gBesB0EUc9gNBtRt+l/RHM9tqZn1VNASgO9o97L/M3fea2emSnjazne7+/NgnZP8p8B8D0DDm7tWsyKxf0mF3vz/xnGo2BiCXu1srzyt92G9m08xsxuhtSV+TtKPs+gB0VzuH/bMk/c7MRtfz/+7+h0q6AtBxlR32t7QxDvuBjuv4YT+AiY3wA0ERfiAowg8ERfiBoAg/EFQVV/WhwaZPn56sL1++vK3lb7755mT97bffzq3deeedyWUffvjhZP3o0aPJOtLY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUFzSOwFMnTo1WV+xYkVurWgcftq0acl69n0NuTr591M0zr9s2bJk/ciRI1W2M2FwSS+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIpx/i4oGqe//PLLk/Vbb701WZ8/f/4nballQ0NDbdWnTJmSWzvrrLNK9TTqySefTNafe+653NoDDzyQXHYinyPAOD+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCKpwnN/MVkn6hqQD7n5+9tipktZJmiNpj6TF7p7/Be3/W9dxOc5/0kknJesPPvhgsn7DDTdU2c6H7NixI1m/5557kvVt27Yl6zt37kzWZ8yYkVt76qmnkstecsklyXo7zjnnnGR9165dHdt2p1U5zv8rSVd95LEVkja5+9mSNmX3AUwgheF39+clffQ0rgWSVme3V0taWHFfADqs7Hv+We6+L7v9lqRZFfUDoEvanqvP3T31Xt7M+iT1tbsdANUqu+ffb2Y9kpT9PpD3RHcfcPded+8tuS0AHVA2/BslLcluL5G0oZp2AHRLYfjN7FFJf5X0BTMbNLNvS/qRpCvN7A1JV2T3AUwghe/53f3anNJXK+5lwrriiiuS9XbH8Q8ePJisr1u3Lrd2yy23JJd97733SvXUqp6entq2jTTO8AOCIvxAUIQfCIrwA0ERfiAowg8E1fbpvVGkprJevnx5R7f9yCOPJOsrV67s2LZPOCH9J7Jo0aJk/aGHHsqtnX766aV6atUzzzyTW9u7d29Htz0RsOcHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY52/RHXfckVu79NJL21p30Tj+3Xff3db6U84999xkfenSpcl6X19zv6Ht3nvvza29++67XeykmdjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPO3qJPXnq9ZsyZZLxqTTk03XTROv3jx4mT9tNNOS9aLpnjvpNR3BUjSs88+251GJij2/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVOE4v5mtkvQNSQfc/fzssX5J35H0n+xpK939951qsgk2b96cW7v++uvbWveGDRuS9SNHjiTrU6dOza2dfPLJpXoa9f777yfr1113XbKemlNg7ty5pXoa9dhjjyXrTAGe1sqe/1eSrhrn8Z+6+7zs57gOPnA8Kgy/uz8vaagLvQDoonbe899kZi+b2Sozm1lZRwC6omz4fy7p85LmSdon6cd5TzSzPjPbYmZbSm4LQAeUCr+773f3YXc/JukXki5OPHfA3XvdvbdskwCqVyr8ZtYz5u4iSTuqaQdAt7Qy1PeopPmSPm1mg5J+KGm+mc2T5JL2SPpuB3sE0AHWzeuxzay+i7/bNGlS/kHS448/nlx24cKFVbdTmRdeeCFZv+uuu5L1ovMIisbiU4p6mz9/frI+PDxcetsTmbtbK8/jDD8gKMIPBEX4gaAIPxAU4QeCIvxAUHx1d4uOHTuWW7vxxhuTy+7fvz9ZL7osdufOncn6E088kVsr+nrrw4cPJ+snnnhisl40HGeWP+qUek0ladOmTcl61KG8qrDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGguKQXSWeccUay/uabb5Ze9/bt25P1Cy+8sPS6I+OSXgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFNfzI6m/v7+t5VNTfK9du7atdaM97PmBoAg/EBThB4Ii/EBQhB8IivADQRF+IKjC6/nNbLakNZJmSXJJA+7+MzM7VdI6SXMk7ZG02N3fLlgX1/M3zKJFi5L11JwAklT093Pffffl1m677bbksiinyuv5j0pa5u5flHSJpO+Z2RclrZC0yd3PlrQpuw9ggigMv7vvc/cXs9vvSHpd0pmSFkhanT1ttaSFnWoSQPU+0Xt+M5sj6UuS/iZplrvvy0pvaeRtAYAJouVz+81suqT1kpa6+6Gxc7C5u+e9nzezPkl97TYKoFot7fnNbIpGgv8bdx/9BGi/mfVk9R5JB8Zb1t0H3L3X3XuraBhANQrDbyO7+F9Ket3dfzKmtFHSkuz2Ekkbqm8PQKe0MtR3maQ/S3pF0uicyis18r7/t5I+K+lfGhnqGypYF0N9DfPSSy8l63Pnzk3Wh4aS/+S64IILcmuDg4PJZVFOq0N9he/53f0vkvJW9tVP0hSA5uAMPyAowg8ERfiBoAg/EBThB4Ii/EBQfHX3ca7ostnzzjsvWR8eHk7Wb7/99mSdsfzmYs8PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0EVXs9f6ca4nr8j5syZk1vbtm1bctlTTjklWd+6dWuyftFFFyXr6L4qv7obwHGI8ANBEX4gKMIPBEX4gaAIPxAU4QeC4nr+48DSpUtza0Xj+EX6+/vbWh7NxZ4fCIrwA0ERfiAowg8ERfiBoAg/EB
ThB4IqvJ7fzGZLWiNpliSXNODuPzOzfknfkfSf7Kkr3f33Beviev4SrrnmmmR9/fr1ubXJkye3te1Jk9g/TDStXs/fykk+RyUtc/cXzWyGpK1m9nRW+6m731+2SQD1KQy/u++TtC+7/Y6ZvS7pzE43BqCzPtExnZnNkfQlSX/LHrrJzF42s1VmNjNnmT4z22JmW9rqFEClWg6/mU2XtF7SUnc/JOnnkj4vaZ5Gjgx+PN5y7j7g7r3u3ltBvwAq0lL4zWyKRoL/G3d/QpLcfb+7D7v7MUm/kHRx59oEULXC8JuZSfqlpNfd/SdjHu8Z87RFknZU3x6ATmnl0/5LJX1L0itmtj17bKWka81snkaG//ZI+m5HOoR2796drB86dCi3NnPmuB/FfOD++xmsiaqVT/v/Imm8ccPkmD6AZuMMDiAowg8ERfiBoAg/EBThB4Ii/EBQTNENHGeYohtAEuEHgiL8QFCEHwiK8ANBEX4gKMIPBNXtKboPSvrXmPufzh5roqb21tS+JHorq8rezmr1iV09yedjGzfb0tTv9mtqb03tS6K3surqjcN+ICjCDwRVd/gHat5+SlN7a2pfEr2VVUtvtb7nB1Cfuvf8AGpSS/jN7Coz+7uZ7TKzFXX0kMfM9pjZK2a2ve4pxrJp0A6Y2Y4xj51qZk+b2RvZ7/R3c3e3t34z25u9dtvN7OqaepttZs+Y2Wtm9qqZfT97vNbXLtFXLa9b1w/7zWyypH9IulLSoKTNkq5199e62kgOM9sjqdfdax8TNrMvSzosaY27n589dq+kIXf/UfYf50x3v60hvfVLOlz3zM3ZhDI9Y2eWlrRQ0vWq8bVL9LVYNbxudez5L5a0y913u/sRSWslLaihj8Zz9+clDX3k4QWSVme3V2vkj6frcnprBHff5+4vZrffkTQ6s3Str12ir1rUEf4zJf17zP1BNWvKb5f0RzPbamZ9dTczjlnZtOmS9JakWXU2M47CmZu76SMzSzfmtSsz43XV+MDv4y5z9wslfV3S97LD20bykfdsTRquaWnm5m4ZZ2bpD9T52pWd8bpqdYR/r6TZY+5/JnusEdx9b/b7gKTfqXmzD+8fnSQ1+32g5n4+0KSZm8ebWVoNeO2aNON1HeHfLOlsM/ucmX1K0jclbayhj48xs2nZBzEys2mSvqbmzT68UdKS7PYSSRtq7OVDmjJzc97M0qr5tWvcjNfu3vUfSVdr5BP/f0r6QR095PT1f5Jeyn5erbs3SY9q5DDwfY18NvJtSadJ2iTpDUl/knRqg3r7taRXJL2skaD11NTbZRo5pH9Z0vbs5+q6X7tEX7W8bpzhBwTFB35AUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4L6L4bahh5ke9v1AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "data = get_random_data()\n", + "\n", + "digit, odd_even = net(data)\n", + "\n", + "digit = digit.argmax(axis=1)[0].asnumpy()\n", + "odd_even = (odd_even.sigmoid()[0] > 0.5).asnumpy()\n", + "\n", + "print(\"Predicted digit: {}, odd: {}\".format(digit, odd_even))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}