Intro colab with Transformer.
PiperOrigin-RevId: 289761952
Lukasz Kaiser authored and copybara-github committed Jan 15, 2020
1 parent c061226 commit d2c5b84
Showing 1 changed file with 333 additions and 0 deletions: trax/intro.ipynb
@@ -0,0 +1,333 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7yuytuIllsv1"
},
"source": [
"# Trax Quick Intro\n",
"\n",
"We train **Trax Transformer** on a simple copy problem and run inference.\n",
"* See how to create your inputs from python.\n",
"* Learn how to run the trainer.\n",
"* Run fast inference with Transformer.\n",
"\n",
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "BIl27504La0G"
},
"source": [
"## General Setup\n",
"Execute the following few cells (once) before running any of the code samples in this notebook."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "oILRLCWN_16u"
},
"outputs": [],
"source": [
"#@title\n",
"# Copyright 2020 Google LLC.\n",
"\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\n",
"import os\n",
"import numpy as np\n",
"\n",
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "vlGjGoGMTt-D"
},
"outputs": [],
"source": [
"#@title\n",
"# Import Trax\n",
"\n",
"! pip install -q -U trax\n",
"! pip install -q tensorflow\n",
"\n",
"import trax"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-LQ89rFFsEdk"
},
"source": [
"# Transformer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"height": 68
},
"colab_type": "code",
"executionInfo": {
"elapsed": 318,
"status": "ok",
"timestamp": 1578963024402,
"user": {
"displayName": "Lukasz Kaiser",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64",
"userId": "13267693649565518272"
},
"user_tz": 480
},
"id": "djTiSLcaNFGa",
"outputId": "610b5f32-e47e-4afd-971f-3e33b03adc0f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inputs[0]: [ 0 6 13 29 22 0 6 13 29 22]\n",
"Targets[0]: [ 0 6 13 29 22 0 6 13 29 22]\n",
"Mask[0]: [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]\n"
]
}
],
"source": [
"# Construct inputs, see one batch\n",
"def copy_task(batch_size, vocab_size, length):\n",
" \"\"\"This task is to copy a random string w, so the input is 0w0w.\"\"\"\n",
" while True:\n",
" assert length % 2 == 0\n",
" w_length = (length // 2) - 1\n",
" w = np.random.randint(low=1, high=vocab_size-1,\n",
" size=(batch_size, w_length))\n",
" zero = np.zeros([batch_size, 1], np.int32)\n",
" loss_weights = np.concatenate([np.zeros((batch_size, w_length)),\n",
" np.ones((batch_size, w_length+2))], axis=1)\n",
" x = np.concatenate([zero, w, zero, w], axis=1)\n",
" yield (x, x, loss_weights) # Here inputs and targets are the same.\n",
"copy_inputs = trax.supervised.Inputs(lambda _: copy_task(16, 32, 10))\n",
"\n",
"# Peek into the inputs.\n",
"data_stream = copy_inputs.train_stream(1)\n",
"inputs, targets, mask = next(data_stream)\n",
"print(\"Inputs[0]: %s\" % str(inputs[0]))\n",
"print(\"Targets[0]: %s\" % str(targets[0]))\n",
"print(\"Mask[0]: %s\" % str(mask[0]))"
]
},
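{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "batch-shapes-md"
},
"source": [
"Each array in the batch has shape `(batch_size, length)`, here `(16, 10)`; a quick check:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "batch-shapes-code"
},
"outputs": [],
"source": [
"# Confirm the shapes of one batch from the copy task.\n",
"print(inputs.shape, targets.shape, mask.shape)"
]
},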
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"height": 629
},
"colab_type": "code",
"executionInfo": {
"elapsed": 28950,
"status": "ok",
"timestamp": 1578963053368,
"user": {
"displayName": "Lukasz Kaiser",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64",
"userId": "13267693649565518272"
},
"user_tz": 480
},
"id": "kSauPt0NUl_o",
"outputId": "e5cf66a1-5deb-43f3-fb0a-a5fa479c7cbf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Step 500: Ran 500 train steps in 16.51 secs\n",
"Step 500: Evaluation\n",
"Step 500: train accuracy | 0.53125000\n",
"Step 500: train loss | 1.83887446\n",
"Step 500: train neg_log_perplexity | -1.83887446\n",
"Step 500: train weights_per_batch_per_core | 80.00000000\n",
"Step 500: eval accuracy | 0.52500004\n",
"Step 500: eval loss | 1.92791247\n",
"Step 500: eval neg_log_perplexity | -1.92791247\n",
"Step 500: eval weights_per_batch_per_core | 80.00000000\n",
"Step 500: Finished evaluation\n",
"\n",
"Step 1000: Ran 500 train steps in 2.54 secs\n",
"Step 1000: Evaluation\n",
"Step 1000: train accuracy | 1.00000000\n",
"Step 1000: train loss | 0.00707983\n",
"Step 1000: train neg_log_perplexity | -0.00707983\n",
"Step 1000: train weights_per_batch_per_core | 80.00000000\n",
"Step 1000: eval accuracy | 1.00000000\n",
"Step 1000: eval loss | 0.01029818\n",
"Step 1000: eval neg_log_perplexity | -0.01029818\n",
"Step 1000: eval weights_per_batch_per_core | 80.00000000\n",
"Step 1000: Finished evaluation\n",
"\n",
"Step 1500: Ran 500 train steps in 2.46 secs\n",
"Step 1500: Evaluation\n",
"Step 1500: train accuracy | 1.00000000\n",
"Step 1500: train loss | 0.00037777\n",
"Step 1500: train neg_log_perplexity | -0.00037777\n",
"Step 1500: train weights_per_batch_per_core | 80.00000000\n",
"Step 1500: eval accuracy | 1.00000000\n",
"Step 1500: eval loss | 0.00037660\n",
"Step 1500: eval neg_log_perplexity | -0.00037660\n",
"Step 1500: eval weights_per_batch_per_core | 80.00000000\n",
"Step 1500: Finished evaluation\n"
]
}
],
"source": [
"# Transformer LM\n",
"def tiny_transformer_lm(mode):\n",
" return trax.models.TransformerLM( # You can try trax_models.ReformerLM too.\n",
" d_model=32, d_ff=128, n_layers=2, vocab_size=32, mode=mode)\n",
"\n",
"# Train tiny model with Trainer.\n",
"output_dir = os.path.expanduser('~/train_dir/')\n",
"!rm -f ~/train_dir/model.pkl # Remove old model.\n",
"trainer = trax.supervised.Trainer(\n",
" model=tiny_transformer_lm,\n",
" loss_fn=trax.layers.CrossEntropyLossScalar,\n",
" optimizer=trax.optimizers.Adafactor, # Change optimizer params here.\n",
" lr_schedule=trax.lr.MultifactorSchedule, # Change lr schedule here.\n",
" inputs=copy_inputs,\n",
" output_dir=output_dir,\n",
" has_weights=True) # Because we have loss mask, this API may change.\n",
"\n",
"# Train for 3 epochs each consisting of 500 train batches, eval on 2 batches.\n",
"n_epochs = 3\n",
"train_steps = 500\n",
"eval_steps = 2\n",
"for _ in range(n_epochs):\n",
" trainer.train_epoch(train_steps, eval_steps)"
]
},
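{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "model-structure-md"
},
"source": [
"As a quick sanity check, you can print the model to see its layer structure (the exact printout depends on your Trax version):"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "model-structure-code"
},
"outputs": [],
"source": [
"# Print the nested layer structure of the tiny model. Trax layers have a\n",
"# readable repr, so this works even before initialization.\n",
"print(tiny_transformer_lm(mode='train'))"
]
},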
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"height": 34
},
"colab_type": "code",
"executionInfo": {
"elapsed": 1190,
"status": "ok",
"timestamp": 1578963686769,
"user": {
"displayName": "Lukasz Kaiser",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64",
"userId": "13267693649565518272"
},
"user_tz": 480
},
"id": "cqjYoxPEu8PG",
"outputId": "d5a99b97-843b-4761-a711-ec337098e30e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]\n"
]
}
],
"source": [
"# Initialize model for inference.\n",
"predict_model = tiny_transformer_lm(mode='predict')\n",
"predict_signature = trax.shapes.ShapeDtype((1,1), dtype=np.int32)\n",
"predict_model.init(predict_signature)\n",
"predict_model.init_from_file(os.path.join(output_dir, \"model.pkl\"),\n",
" weights_only=True)\n",
"# You can also do: predict_model.weights = trainer.model_weights\n",
"\n",
"# Run inference\n",
"prefix = [0, 1, 2, 3, 4, 0] # Change non-0 digits to see if it's copying\n",
"cur_input = np.array([[0]])\n",
"result = []\n",
"for i in range(10):\n",
" logits = predict_model(cur_input)\n",
" next_input = np.argmax(logits[0, 0, :], axis=-1)\n",
" if i \u003c len(prefix) - 1:\n",
" next_input = prefix[i]\n",
" cur_input = np.array([[next_input]])\n",
" result.append(int(next_input)) # Append to the result\n",
"print(result)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
"kind": "private"
},
"name": "Trax Quick Intro",
"provenance": [
{
"file_id": "1v1GvTkEFjMH_1c-bdS7JzNS70u9RUEHV",
"timestamp": 1578964243645
},
{
"file_id": "1SplqILjJr_ZqXcIUkNIk0tSbthfhYm07",
"timestamp": 1572044421118
},
{
"file_id": "intro.ipynb",
"timestamp": 1571858674399
},
{
"file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt",
"timestamp": 1569980697572
},
{
"file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5",
"timestamp": 1563927451951
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
