diff --git a/docs/_book.yaml b/docs/_book.yaml
index 2738d272e45..4afefa3cf55 100644
--- a/docs/_book.yaml
+++ b/docs/_book.yaml
@@ -39,6 +39,8 @@ upper_tabs:
path: /tensorboard/r2/hyperparameter_tuning_with_hparams
- title: "What-If Tool"
path: /tensorboard/r2/what_if_tool
+ - title: "Profiling Tool"
+ path: /tensorboard/r2/tensorboard_profiling_keras
- title: "TensorBoard in notebooks"
path: /tensorboard/r2/tensorboard_in_notebooks
diff --git a/docs/r2/images/profiler-blocking-runtime.png b/docs/r2/images/profiler-blocking-runtime.png
new file mode 100644
index 00000000000..04c3c0dc395
Binary files /dev/null and b/docs/r2/images/profiler-blocking-runtime.png differ
diff --git a/docs/r2/images/profiler-capture.png b/docs/r2/images/profiler-capture.png
new file mode 100644
index 00000000000..fd71c3ed6bd
Binary files /dev/null and b/docs/r2/images/profiler-capture.png differ
diff --git a/docs/r2/images/profiler-download-logdir.png b/docs/r2/images/profiler-download-logdir.png
new file mode 100644
index 00000000000..efeb9f2b44d
Binary files /dev/null and b/docs/r2/images/profiler-download-logdir.png differ
diff --git a/docs/r2/images/profiler-idle-gpu.png b/docs/r2/images/profiler-idle-gpu.png
new file mode 100644
index 00000000000..a6e854686ce
Binary files /dev/null and b/docs/r2/images/profiler-idle-gpu.png differ
diff --git a/docs/r2/images/profiler-input-cpu.png b/docs/r2/images/profiler-input-cpu.png
new file mode 100644
index 00000000000..21062158682
Binary files /dev/null and b/docs/r2/images/profiler-input-cpu.png differ
diff --git a/docs/r2/images/profiler-notebook-settings.png b/docs/r2/images/profiler-notebook-settings.png
new file mode 100644
index 00000000000..e2100d88ea7
Binary files /dev/null and b/docs/r2/images/profiler-notebook-settings.png differ
diff --git a/docs/r2/images/profiler-prefetch-runtime.png b/docs/r2/images/profiler-prefetch-runtime.png
new file mode 100644
index 00000000000..99b8de1724d
Binary files /dev/null and b/docs/r2/images/profiler-prefetch-runtime.png differ
diff --git a/docs/r2/images/profiler-trace-viewer-select.png b/docs/r2/images/profiler-trace-viewer-select.png
new file mode 100644
index 00000000000..6bb7df322f7
Binary files /dev/null and b/docs/r2/images/profiler-trace-viewer-select.png differ
diff --git a/docs/r2/images/profiler-trace-viewer.png b/docs/r2/images/profiler-trace-viewer.png
new file mode 100644
index 00000000000..8a66fed9e12
Binary files /dev/null and b/docs/r2/images/profiler-trace-viewer.png differ
diff --git a/docs/r2/tensorboard_profiling_keras.ipynb b/docs/r2/tensorboard_profiling_keras.ipynb
new file mode 100644
index 00000000000..c8b46b5f4a9
--- /dev/null
+++ b/docs/r2/tensorboard_profiling_keras.ipynb
@@ -0,0 +1,1086 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "external tensorboard_profiling_keras.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": [],
+ "toc_visible": true,
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "metadata": {
+ "id": "djUvWu41mtXa",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "##### Copyright 2019 The TensorFlow Authors."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "su2RaORHpReL",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "NztQK2uFpXT-",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# TensorBoard Profile: Profiling basic training metrics in Keras\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "eDXRFe_qp5C3",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "## Overview\n",
+ "Performance is critical for machine learning. TensorFlow has a built-in profiler that allows you to record runtime of each ops with very little effort. Then you can visualize the profile result in TensorBoard's **Profile Plugin**. This tutorial focuses on GPU but the Profile Plugin can also be used with TPUs by following the [Cloud TPU Tools](https://cloud.google.com/tpu/docs/cloud-tpu-tools).\n",
+ "\n",
+ "This tutorial presents very basic examples to help you learn how to enable profiler when developing your Keras model. You will learn how to use the Keras TensorBoard callback to visualize profile result. **Profiler APIs** and **Profiler Server** mentioned in **Other ways for profiling** allow you to profile non-Keras TensorFlow job."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "dG-nnZK9qW9z",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Prerequisites\n",
+ "\n",
+ "\n",
+ "* Install latest [TensorBoard](https://www.tensorflow.org/tensorboard) on your local machine.\n",
+ "\n",
+ "* Select **GPU** in the Accelerator drop-down in Notebook Settings (Assuming you run this notebook on Colab).\n",
+ "\n",
+ ">![Notebook Settings](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-notebook-settings.png?raw=1\\)\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "DZhGh-G7KoKL",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "3U5gdCw_nSG3",
+ "colab_type": "code",
+ "outputId": "ecbc68d1-3e87-42d1-8fb8-a23abdc3d59c",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 119
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Ensure latest TensorFlow is installed.\n",
+ "!pip install -q tf-nightly-gpu-2.0-preview\n",
+ "# Load the TensorBoard notebook extension.\n",
+ "%load_ext tensorboard\n"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "\u001b[K 100% |████████████████████████████████| 345.7MB 61kB/s \n",
+ "\u001b[K 100% |████████████████████████████████| 3.1MB 6.6MB/s \n",
+ "\u001b[K 100% |████████████████████████████████| 430kB 10.2MB/s \n",
+ "\u001b[K 100% |████████████████████████████████| 61kB 29.0MB/s \n",
+ "\u001b[?25h Building wheel for wrapt (setup.py) ... \u001b[?25ldone\n",
+ "\u001b[31mthinc 6.12.1 has requirement wrapt<1.11.0,>=1.10.0, but you'll have wrapt 1.11.1 which is incompatible.\u001b[0m\n",
+ "\u001b[?25h"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "1qIKtOBrqc9Y",
+ "colab_type": "code",
+ "outputId": "cba1a5a5-da4e-4df1-b80c-c83f446f8178",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "from __future__ import absolute_import\n",
+ "from __future__ import division\n",
+ "from __future__ import print_function\n",
+ "\n",
+ "from datetime import datetime\n",
+ "from packaging import version\n",
+ "\n",
+ "import functools\n",
+ "import tensorflow as tf\n",
+ "import tensorflow_datasets as tfds\n",
+ "from tensorflow.python.keras import backend\n",
+ "from tensorflow.python.keras import layers\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "print(\"TensorFlow version: \", tf.__version__)\n"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "TensorFlow version: 2.0.0-dev20190424\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "8ZM-6NzYgPRn",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Confirm TensorFlow can see the GPU."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gp2p-MemgAIh",
+ "colab_type": "code",
+ "outputId": "d459de3e-c9cb-4336-874a-d544f6323ffe",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "device_name = tf.test.gpu_device_name()\n",
+ "if device_name != '/device:GPU:0':\n",
+ " raise SystemError('GPU device not found')\n",
+ "print('Found GPU at: {}'.format(device_name))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Found GPU at: /device:GPU:0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "6YDAoNCN3ZNS",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Run a simple model with TensorBoard callback\n",
+ "\n",
+ "You're now going to use Keras to build a simple model for classifying [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) images using ResNet56 (Reference: [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)).\n",
+ "\n",
+ "Following RestNet model code is copied from [TensorFlow models garden](https://github.com/tensorflow/models/blob/master/official/resnet/keras/resnet_cifar_model.py).\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "ImCFrQ74eerE",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "BATCH_NORM_DECAY = 0.997\n",
+ "BATCH_NORM_EPSILON = 1e-5\n",
+ "L2_WEIGHT_DECAY = 2e-4\n",
+ "\n",
+ "\n",
+ "def identity_building_block(input_tensor,\n",
+ " kernel_size,\n",
+ " filters,\n",
+ " stage,\n",
+ " block,\n",
+ " training=None):\n",
+ " \"\"\"The identity block is the block that has no conv layer at shortcut.\n",
+ "\n",
+ " Arguments:\n",
+ " input_tensor: input tensor\n",
+ " kernel_size: default 3, the kernel size of\n",
+ " middle conv layer at main path\n",
+ " filters: list of integers, the filters of 3 conv layer at main path\n",
+ " stage: integer, current stage label, used for generating layer names\n",
+ " block: current block label, used for generating layer names\n",
+ " training: Only used if training keras model with Estimator. In other\n",
+ " scenarios it is handled automatically.\n",
+ "\n",
+ " Returns:\n",
+ " Output tensor for the block.\n",
+ " \"\"\"\n",
+ " filters1, filters2 = filters\n",
+ " if tf.keras.backend.image_data_format() == 'channels_last':\n",
+ " bn_axis = 3\n",
+ " else:\n",
+ " bn_axis = 1\n",
+ " conv_name_base = 'res' + str(stage) + block + '_branch'\n",
+ " bn_name_base = 'bn' + str(stage) + block + '_branch'\n",
+ "\n",
+ " x = tf.keras.layers.Conv2D(filters1, kernel_size,\n",
+ " padding='same',\n",
+ " kernel_initializer='he_normal',\n",
+ " kernel_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " bias_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " name=conv_name_base + '2a')(input_tensor)\n",
+ " x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n",
+ " name=bn_name_base + '2a',\n",
+ " momentum=BATCH_NORM_DECAY,\n",
+ " epsilon=BATCH_NORM_EPSILON)(\n",
+ " x, training=training)\n",
+ " x = tf.keras.layers.Activation('relu')(x)\n",
+ "\n",
+ " x = tf.keras.layers.Conv2D(filters2, kernel_size,\n",
+ " padding='same',\n",
+ " kernel_initializer='he_normal',\n",
+ " kernel_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " bias_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " name=conv_name_base + '2b')(x)\n",
+ " x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n",
+ " name=bn_name_base + '2b',\n",
+ " momentum=BATCH_NORM_DECAY,\n",
+ " epsilon=BATCH_NORM_EPSILON)(\n",
+ " x, training=training)\n",
+ "\n",
+ " x = tf.keras.layers.add([x, input_tensor])\n",
+ " x = tf.keras.layers.Activation('relu')(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "def conv_building_block(input_tensor,\n",
+ " kernel_size,\n",
+ " filters,\n",
+ " stage,\n",
+ " block,\n",
+ " strides=(2, 2),\n",
+ " training=None):\n",
+ " \"\"\"A block that has a conv layer at shortcut.\n",
+ "\n",
+ " Arguments:\n",
+ " input_tensor: input tensor\n",
+ " kernel_size: default 3, the kernel size of\n",
+ " middle conv layer at main path\n",
+ " filters: list of integers, the filters of 3 conv layer at main path\n",
+ " stage: integer, current stage label, used for generating layer names\n",
+ " block: current block label, used for generating layer names\n",
+ " strides: Strides for the first conv layer in the block.\n",
+ " training: Only used if training keras model with Estimator. In other\n",
+ " scenarios it is handled automatically.\n",
+ "\n",
+ " Returns:\n",
+ " Output tensor for the block.\n",
+ "\n",
+ " Note that from stage 3,\n",
+ " the first conv layer at main path is with strides=(2, 2)\n",
+ " And the shortcut should have strides=(2, 2) as well\n",
+ " \"\"\"\n",
+ " filters1, filters2 = filters\n",
+ " if tf.keras.backend.image_data_format() == 'channels_last':\n",
+ " bn_axis = 3\n",
+ " else:\n",
+ " bn_axis = 1\n",
+ " conv_name_base = 'res' + str(stage) + block + '_branch'\n",
+ " bn_name_base = 'bn' + str(stage) + block + '_branch'\n",
+ "\n",
+ " x = tf.keras.layers.Conv2D(filters1, kernel_size, strides=strides,\n",
+ " padding='same',\n",
+ " kernel_initializer='he_normal',\n",
+ " kernel_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " bias_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " name=conv_name_base + '2a')(input_tensor)\n",
+ " x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n",
+ " name=bn_name_base + '2a',\n",
+ " momentum=BATCH_NORM_DECAY,\n",
+ " epsilon=BATCH_NORM_EPSILON)(\n",
+ " x, training=training)\n",
+ " x = tf.keras.layers.Activation('relu')(x)\n",
+ "\n",
+ " x = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same',\n",
+ " kernel_initializer='he_normal',\n",
+ " kernel_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " bias_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " name=conv_name_base + '2b')(x)\n",
+ " x = tf.keras.layers.BatchNormalization(axis=bn_axis,\n",
+ " name=bn_name_base + '2b',\n",
+ " momentum=BATCH_NORM_DECAY,\n",
+ " epsilon=BATCH_NORM_EPSILON)(\n",
+ " x, training=training)\n",
+ "\n",
+ " shortcut = tf.keras.layers.Conv2D(filters2, (1, 1), strides=strides,\n",
+ " kernel_initializer='he_normal',\n",
+ " kernel_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " bias_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " name=conv_name_base + '1')(input_tensor)\n",
+ " shortcut = tf.keras.layers.BatchNormalization(\n",
+ " axis=bn_axis, name=bn_name_base + '1',\n",
+ " momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON)(\n",
+ " shortcut, training=training)\n",
+ "\n",
+ " x = tf.keras.layers.add([x, shortcut])\n",
+ " x = tf.keras.layers.Activation('relu')(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "def resnet_block(input_tensor,\n",
+ " size,\n",
+ " kernel_size,\n",
+ " filters,\n",
+ " stage,\n",
+ " conv_strides=(2, 2),\n",
+ " training=None):\n",
+ " \"\"\"A block which applies conv followed by multiple identity blocks.\n",
+ "\n",
+ " Arguments:\n",
+ " input_tensor: input tensor\n",
+ " size: integer, number of constituent conv/identity building blocks.\n",
+ " A conv block is applied once, followed by (size - 1) identity blocks.\n",
+ " kernel_size: default 3, the kernel size of\n",
+ " middle conv layer at main path\n",
+ " filters: list of integers, the filters of 3 conv layer at main path\n",
+ " stage: integer, current stage label, used for generating layer names\n",
+ " conv_strides: Strides for the first conv layer in the block.\n",
+ " training: Only used if training keras model with Estimator. In other\n",
+ " scenarios it is handled automatically.\n",
+ "\n",
+ " Returns:\n",
+ " Output tensor after applying conv and identity blocks.\n",
+ " \"\"\"\n",
+ "\n",
+ " x = conv_building_block(input_tensor, kernel_size, filters, stage=stage,\n",
+ " strides=conv_strides, block='block_0',\n",
+ " training=training)\n",
+ " for i in range(size - 1):\n",
+ " x = identity_building_block(x, kernel_size, filters, stage=stage,\n",
+ " block='block_%d' % (i + 1), training=training)\n",
+ " return x\n",
+ "\n",
+ "def resnet(num_blocks, classes=10, training=None):\n",
+ " \"\"\"Instantiates the ResNet architecture.\n",
+ "\n",
+ " Arguments:\n",
+ " num_blocks: integer, the number of conv/identity blocks in each block.\n",
+ " The ResNet contains 3 blocks with each block containing one conv block\n",
+ " followed by (layers_per_block - 1) number of idenity blocks. Each\n",
+ " conv/idenity block has 2 convolutional layers. With the input\n",
+ " convolutional layer and the pooling layer towards the end, this brings\n",
+ " the total size of the network to (6*num_blocks + 2)\n",
+ " classes: optional number of classes to classify images into\n",
+ " training: Only used if training keras model with Estimator. In other\n",
+ " scenarios it is handled automatically.\n",
+ "\n",
+ " Returns:\n",
+ " A Keras model instance.\n",
+ " \"\"\"\n",
+ "\n",
+ " input_shape = (32, 32, 3)\n",
+ " img_input = layers.Input(shape=input_shape)\n",
+ "\n",
+ " if backend.image_data_format() == 'channels_first':\n",
+ " x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),\n",
+ " name='transpose')(img_input)\n",
+ " bn_axis = 1\n",
+ " else: # channel_last\n",
+ " x = img_input\n",
+ " bn_axis = 3\n",
+ "\n",
+ " x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)\n",
+ " x = tf.keras.layers.Conv2D(16, (3, 3),\n",
+ " strides=(1, 1),\n",
+ " padding='valid',\n",
+ " kernel_initializer='he_normal',\n",
+ " kernel_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " bias_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " name='conv1')(x)\n",
+ " x = tf.keras.layers.BatchNormalization(axis=bn_axis, name='bn_conv1',\n",
+ " momentum=BATCH_NORM_DECAY,\n",
+ " epsilon=BATCH_NORM_EPSILON)(\n",
+ " x, training=training)\n",
+ " x = tf.keras.layers.Activation('relu')(x)\n",
+ "\n",
+ " x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16],\n",
+ " stage=2, conv_strides=(1, 1), training=training)\n",
+ "\n",
+ " x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[32, 32],\n",
+ " stage=3, conv_strides=(2, 2), training=training)\n",
+ "\n",
+ " x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64],\n",
+ " stage=4, conv_strides=(2, 2), training=training)\n",
+ "\n",
+ " x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)\n",
+ " x = tf.keras.layers.Dense(classes, activation='softmax',\n",
+ " kernel_initializer='he_normal',\n",
+ " kernel_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " bias_regularizer=\n",
+ " tf.keras.regularizers.l2(L2_WEIGHT_DECAY),\n",
+ " name='fc10')(x)\n",
+ "\n",
+ " inputs = img_input\n",
+ " # Create model.\n",
+ " model = tf.keras.models.Model(inputs, x, name='resnet56')\n",
+ "\n",
+ " return model\n",
+ "\n",
+ "\n",
+ "resnet20 = functools.partial(resnet, num_blocks=3)\n",
+ "resnet32 = functools.partial(resnet, num_blocks=5)\n",
+ "resnet56 = functools.partial(resnet, num_blocks=9)\n",
+ "resnet10 = functools.partial(resnet, num_blocks=110)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "1lAek-Lye8_q",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Download CIFAR-10 data from [TensorFlow Datasets](https://www.tensorflow.org/datasets)."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "H8A67-bNXzsx",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "cifar_builder = tfds.builder('cifar10')\n",
+ "cifar_builder.download_and_prepare()\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "21jm6LOSq9EN",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Build data input pipeline and compile ResNet56 model."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "j-ryO6OxnQH_",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "HEIGHT = 32\n",
+ "WIDTH = 32\n",
+ "NUM_CHANNELS = 3\n",
+ "NUM_CLASSES = 10\n",
+ "BATCH_SIZE = 128\n",
+ "\n",
+ "def preprocess_data(record):\n",
+ " image = record['image']\n",
+ " label = record['label']\n",
+ " \n",
+ " # Resize the image to add four extra pixels on each side.\n",
+ " image = tf.image.resize_with_crop_or_pad(\n",
+ " image, HEIGHT + 8, WIDTH + 8)\n",
+ "\n",
+ " # Randomly crop a [HEIGHT, WIDTH] section of the image.\n",
+ " image = tf.image.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])\n",
+ "\n",
+ " # Randomly flip the image horizontally.\n",
+ " image = tf.image.random_flip_left_right(image)\n",
+ "\n",
+ " # Subtract off the mean and divide by the variance of the pixels.\n",
+ " image = tf.image.per_image_standardization(image)\n",
+ " \n",
+ " label = tf.compat.v1.sparse_to_dense(label, (NUM_CLASSES,), 1)\n",
+ " return image, label\n",
+ "\n",
+ "train_data = cifar_builder.as_dataset(split=tfds.Split.TRAIN)\n",
+ "train_data = train_data.repeat()\n",
+ "train_data = train_data.map(\n",
+ " lambda value: preprocess_data(value))\n",
+ "train_data = train_data.shuffle(1024)\n",
+ "\n",
+ "train_data = train_data.batch(BATCH_SIZE)\n",
+ "\n",
+ "model = resnet56(classes=NUM_CLASSES)\n",
+ "\n",
+ "model.compile(optimizer='SGD',\n",
+ " loss='categorical_crossentropy',\n",
+ " metrics=['categorical_accuracy'])\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "_5llFQBKHFmA",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "When creating TensorBoard callback, you can specify the batch num you want to profile. By default, TensorFlow will profile the second batch, because many one time graph optimizations run on the first batch. You can modify it by setting **profile_batch**. You can also turn off profiling by setting it to 0.\n",
+ "\n",
+ "This time, you will profile on the third batch."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "WmY-2znGJxNY",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "log_dir=\"logs/profile/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
+ "\n",
+ "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch = 3)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "ylDhh7zlJ273",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Start training by calling [Model.fit()](https://https://www.tensorflow.org/api_docs/python/tf/keras/models/Model#fit). "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "LEb_1HETJ_tX",
+ "colab_type": "code",
+ "outputId": "5c22c9f5-6901-4ff6-a5fd-f2700010e85c",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 275
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "model.fit(train_data,\n",
+ " steps_per_epoch=20,\n",
+ " epochs=5, \n",
+ " callbacks=[tensorboard_callback])"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/5\n",
+ " 1/20 [>.............................] - ETA: 14:27 - loss: 5.4251 - categorical_accuracy: 0.0859"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "W0425 21:14:50.396199 140078590396288 callbacks.py:238] Method (on_train_batch_end) is slow compared to the batch update (0.317050). Check your callbacks.\n"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r 2/20 [==>...........................] - ETA: 6:58 - loss: 5.5955 - categorical_accuracy: 0.0781 "
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "W0425 21:14:50.954807 140078590396288 callbacks.py:238] Method (on_train_batch_end) is slow compared to the batch update (0.268180). Check your callbacks.\n"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r 3/20 [===>..........................] - ETA: 4:26 - loss: 5.7003 - categorical_accuracy: 0.0911"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "W0425 21:14:51.180765 140078590396288 callbacks.py:238] Method (on_train_batch_end) is slow compared to the batch update (0.134130). Check your callbacks.\n"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "20/20 [==============================] - 51s 3s/step - loss: 5.3766 - categorical_accuracy: 0.1004\n",
+ "Epoch 2/5\n",
+ "20/20 [==============================] - 5s 227ms/step - loss: 4.8007 - categorical_accuracy: 0.0988\n",
+ "Epoch 3/5\n",
+ "20/20 [==============================] - 5s 242ms/step - loss: 4.3439 - categorical_accuracy: 0.0980\n",
+ "Epoch 4/5\n",
+ "20/20 [==============================] - 5s 247ms/step - loss: 3.9405 - categorical_accuracy: 0.1074\n",
+ "Epoch 5/5\n",
+ "20/20 [==============================] - 5s 225ms/step - loss: 3.6195 - categorical_accuracy: 0.1176\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 12
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "042k7GMERVkx",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Visualizing profile result using TensorBoard\n",
+ "\n",
+ "Unfortunately, due to [#1913](https://github.com/tensorflow/tensorboard/issues/1913), you cannot use TensorBoard in Colab to visualize profile result. You are going to download the logdir and start TensorBoard on your local machine.\n",
+ "\n",
+ "Compress logdir:\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "6pck56gKReON",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "!tar -zcvf logs.tar.gz logs/profile/"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "TZOf_K4L-Nkv",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Download **logdir.tar.gz** by right-clicking it in **Files** tab.\n",
+ "\n",
+ "![Download](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-download-logdir.png?raw=1\\)\n",
+ "\n",
+ "Please make sure you have the latest [TensorBoard](https://www.tensorflow.org/tensorboard) installed on you local machine as well. Execute following commands on your local machine:\n",
+ "\n",
+ "```\n",
+ "> cd download/directory\n",
+ "> tar -zxvf logs.tar.gz\n",
+ "> tensorboard --logdir=logs/ --port=6006\n",
+ "\n",
+ "```\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "ciSIRibhRi6N",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Open a new tab in your Chrome browser and navigate to [localhost:6006](http://localhost:6006) and then click **Profile** tab. You may see the profile result like this:\n",
+ "\n",
+ "![Trace View](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-trace-viewer.png?raw=1\\)\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "roE94vH9mJ6k",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Trace Viewer\n",
+ "Once you click the profile tab, you will see Trace Viewer. The page displays a timeline of different events that happened on the CPU and the accelerator during the collection period.\n",
+ "\n",
+ "The Trace Viewer shows multiple **event groups** on the vertical axis. Each event group has multiple horizontal **tracks**, filled with trace events. The **track** is basically an event timeline for events executed on a thread or a GPU stream. Individual events are the colored, rectangular blocks on the timeline tracks. Time moves from left to right.\n",
+ "\n",
+ "You can navigate through the result using **w** (zoom in), **s** (zoom out), **a** (scroll left), **d** (scroll right).\n",
+ "\n",
+ "A single rectangle represents a **trace event**: when it began, and when it ended. To study an individual rectangle, you can click on it after selecting the mouse cursor icon in the floating tool bar. This will display information about the rectangle, such as its Start time and Duration.\n",
+ "\n",
+ "In addition to clicking, you can drag the mouse to select a rectangle covering a group of trace events. This will give you a list of events that intersect that rectangle and summarize them for you. The **m** key can be used to measure the time duration of the selected events.\n",
+ "\n",
+ "![List of Events](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-trace-viewer-select.png?raw=1\\)\n",
+ "\n",
+ "The trace events are collected from three sources:\n",
+ "\n",
+ "\n",
+ "* **CPU**: CPU events are under event group named **/host:CPU**. Each track represents a thread on CPU. E.g. input pipeline events, GPU op scheduling events, CPU ops execution events, etc.\n",
+ "* **GPU**: GPU events are under event groups prefixed by **/device:GPU:***. Except **stream:all**, each event group represents one stream on GPU. **stream::all** aggregates all events on one GPU. E.g. Memory copy events, Kernel execution events, etc.\n",
+ "* **TensorFlow Runtime**: Runtime events are under event groups prefixed by **/job:***. Runtime events represent the TensorFlow ops invoked by python program. E.g. tf.function execution events, etc.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "XAcO9sj4B2DK",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Debug Performance\n",
+ "Now, you're going to use the Trace Viewer to improve your model's performance. \n",
+ "\n",
+ "Let's go back to the profile result you have just captured.\n",
+ "\n",
+ "![GPU kernel](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-idle-gpu.png?raw=1\\)\n",
+ "\n",
+ "GPU events show that GPU has nothing to do at all in the first harf of the step.\n",
+ "\n",
+ "![CPU events](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-input-cpu.png?raw=1\\)\n",
+ "\n",
+ "CPU events show that CPU is occupied by data input pipeline in the beginning of this step.\n",
+ "\n",
+ "![Runtime](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-blocking-runtime.png?raw=1\\)\n",
+ "\n",
+ "In TensorFlow runtime, there is a big block named **Iterator::GetNextSync**, which is a blocking call to get the next batch from data input pipeline. And it blocks the training step. So if you could prepare the input data for step **s** in **s-1** step, you can probably train this model faster.\n",
+ "\n",
+ "You can achieve it by using [tf.data.prefetch](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch)."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "JZ6UeYx9TT2T",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "train_data = cifar_builder.as_dataset(split=tfds.Split.TRAIN)\n",
+ "train_data = train_data.repeat()\n",
+ "train_data = train_data.map(\n",
+ " lambda value: preprocess_data(value))\n",
+ "train_data = train_data.shuffle(1024)\n",
+ "train_data = train_data.batch(BATCH_SIZE)\n",
+ "\n",
+ "# It will prefetch the data in (s-1) step\n",
+ "train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "EfD6pnhgT7q3",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Re-run the model."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "tgFqaHYBUADP",
+ "colab_type": "code",
+ "outputId": "a2e5ab1a-1390-4637-a630-c84035de3879",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 204
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "log_dir=\"logs/profile/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
+ "\n",
+ "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch = 3)\n",
+ "\n",
+ "model.fit(train_data,\n",
+ " steps_per_epoch=20,\n",
+ " epochs=5, \n",
+ " callbacks=[tensorboard_callback])"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/5\n",
+ "20/20 [==============================] - 5s 265ms/step - loss: 3.4081 - categorical_accuracy: 0.1055\n",
+ "Epoch 2/5\n",
+ "20/20 [==============================] - 4s 205ms/step - loss: 3.3122 - categorical_accuracy: 0.1141\n",
+ "Epoch 3/5\n",
+ "20/20 [==============================] - 4s 200ms/step - loss: 3.2795 - categorical_accuracy: 0.1199\n",
+ "Epoch 4/5\n",
+ "20/20 [==============================] - 4s 204ms/step - loss: 3.2237 - categorical_accuracy: 0.1469\n",
+ "Epoch 5/5\n",
+ "20/20 [==============================] - 4s 201ms/step - loss: 3.1888 - categorical_accuracy: 0.1465\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 14
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "LFtVDt-9UVkn",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Woohoo! You have just improvd training performance from **~235ms/step** to **~200ms/step**. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "if5LuLl_pgna",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "!tar -zcvf logs.tar.gz logs/profile/"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "aBBKSVJVp4yk",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "Download **logs** directory again to see the new profile result in TensorBoard.\n",
+ "\n",
+ "![TF Runtime](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-prefetch-runtime.png?raw=1\\)\n",
+ "\n",
+ "The big **Iterator::GetNextSync** block is not there anymore.\n",
+ "\n",
+ "Good job!\n",
+ "\n",
+ "Apparently, this is still not the best performance yet. Please try by yourself to see if you can have further improvements.\n",
+ "\n",
+ "Some useful references for performance tuning:\n",
+ "\n",
+ "\n",
+ "* [TensorFlow Performance](https://www.tensorflow.org/guide/performance/overview)\n",
+ "* [Data input pipeline](https://www.tensorflow.org/guide/performance/datasets)\n",
+ "* [Training Performance: A user’s guide to converge faster (TensorFlow Dev Summit 2018)](https://www.youtube.com/watch?v=SxOsJPaxHME)\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "pLfa4vMn626q",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Other ways for profiling\n",
+ "In addition to TensorBoard callback, TensorFlow also provides two additional way to trigger profiler manually: **Profiler APIs** and **Profiler Service**.\n",
+ "\n",
+ "**NOTE**: Please don't run multiple profilers at the same time. If you want to use either Profiler APIs or Profiler Service with TensorBoard callback, ensure the **profile_batch** parameter is set to 0.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gt9Dm8PkL1FI",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "### Profiler APIs"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "VYywGzC2GQ8w",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "# Context manager APIs\n",
+ "with tf.python.eager.profiler.Profiler('logdir_path'):\n",
+ " # do your training here\n",
+ " pass\n",
+ "\n",
+ "\n",
+ "# Function APIs\n",
+ "tf.python.eager.profiler.start()\n",
+ "# do your training here\n",
+ "profiler_result = tf.python.eager.profiler.stop()\n",
+ "tf.python.eager.profiler.save('logdir_path', profiler_result)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "NSHEq0rIHHBs",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "### Profiler Service\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "USTAe02KHcql",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "# This API will start a gRPC server with your TensorFlow job which can receive\n",
+ "# on-demand profiling request.\n",
+ "tf.python.eager.profiler.start_profiler_server(6009)\n",
+ "\n",
+ "# Your TensorFlow program here"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "AgIro3xQIXUa",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Then you can send profiling request to profiler server to perform on-demand profiling on TensorBoard by clicking **CAPTURE PROFILE** button:\n",
+ "\n",
+ "![CAPTURE PROFILE](https://github.com/tensorflow/tensorboard/blob/master/docs/r2/images/profiler-capture.png?raw=1\\)\n",
+ "\n",
+ "A message will show up after successfully captured. Then you can refresh TensorBoard to visualize the result."
+ ]
+ }
+ ]
+}
\ No newline at end of file