-
Notifications
You must be signed in to change notification settings - Fork 823
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
1 changed file
with
333 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,333 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text", | ||
"id": "7yuytuIllsv1" | ||
}, | ||
"source": [ | ||
"# Trax Quick Intro\n", | ||
"\n", | ||
"We train **Trax Transformer** on a simple copy problem and run inference.\n", | ||
"* See how to create your inputs in Python.\n", | ||
"* Learn how to run the trainer.\n", | ||
"* Run fast inference with Transformer.\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text", | ||
"id": "BIl27504La0G" | ||
}, | ||
"source": [ | ||
"## General Setup\n", | ||
"Execute the following few cells (once) before running any of the code samples in this notebook." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"cellView": "both", | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "oILRLCWN_16u" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title\n", | ||
"# Copyright 2020 Google LLC.\n", | ||
"\n", | ||
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n", | ||
"# you may not use this file except in compliance with the License.\n", | ||
"# You may obtain a copy of the License at\n", | ||
"\n", | ||
"# https://www.apache.org/licenses/LICENSE-2.0\n", | ||
"\n", | ||
"# Unless required by applicable law or agreed to in writing, software\n", | ||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n", | ||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", | ||
"# See the License for the specific language governing permissions and\n", | ||
"# limitations under the License.\n", | ||
"\n", | ||
"import os\n", | ||
"import numpy as np\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"cellView": "both", | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "vlGjGoGMTt-D" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title\n", | ||
"# Import Trax\n", | ||
"\n", | ||
"! pip install -q -U trax\n", | ||
"! pip install -q tensorflow\n", | ||
"\n", | ||
"import trax" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text", | ||
"id": "-LQ89rFFsEdk" | ||
}, | ||
"source": [ | ||
"# Transformer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"colab": { | ||
"height": 68 | ||
}, | ||
"colab_type": "code", | ||
"executionInfo": { | ||
"elapsed": 318, | ||
"status": "ok", | ||
"timestamp": 1578963024402, | ||
"user": { | ||
"displayName": "Lukasz Kaiser", | ||
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64", | ||
"userId": "13267693649565518272" | ||
}, | ||
"user_tz": 480 | ||
}, | ||
"id": "djTiSLcaNFGa", | ||
"outputId": "610b5f32-e47e-4afd-971f-3e33b03adc0f" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Inputs[0]: [ 0 6 13 29 22 0 6 13 29 22]\n", | ||
"Targets[0]: [ 0 6 13 29 22 0 6 13 29 22]\n", | ||
"Mask[0]: [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Construct inputs, see one batch\n", | ||
"def copy_task(batch_size, vocab_size, length):\n", | ||
"  \"\"\"Yields endless batches for the copy task: every row is 0w0w.\n", | ||
"\n", | ||
"  Args:\n", | ||
"    batch_size: number of rows in each yielded batch.\n", | ||
"    vocab_size: bound for token ids; w is drawn from [1, vocab_size - 2]\n", | ||
"      because randint's `high` is exclusive.\n", | ||
"    length: total row length; must be even (0 + w + 0 + w).\n", | ||
"\n", | ||
"  Yields:\n", | ||
"    (inputs, targets, loss_weights), each of shape (batch_size, length).\n", | ||
"    inputs and targets are the same array. loss_weights is 0 over the\n", | ||
"    first half (leading 0 and random w) and 1 over the second half\n", | ||
"    (separator 0 and the copy of w).\n", | ||
"  \"\"\"\n", | ||
"  # Checked on the first next(); the argument never changes afterwards.\n", | ||
"  assert length % 2 == 0\n", | ||
"  w_length = (length // 2) - 1\n", | ||
"  while True:\n", | ||
"    w = np.random.randint(low=1, high=vocab_size-1,\n", | ||
"                          size=(batch_size, w_length))\n", | ||
"    zero = np.zeros([batch_size, 1], np.int32)\n", | ||
"    # The first w is random and cannot be predicted, so none of its\n", | ||
"    # positions may carry loss weight; only the second half is scored.\n", | ||
"    loss_weights = np.concatenate([np.zeros((batch_size, w_length+1)),\n", | ||
"                                   np.ones((batch_size, w_length+1))], axis=1)\n", | ||
"    x = np.concatenate([zero, w, zero, w], axis=1)\n", | ||
"    yield (x, x, loss_weights)  # Here inputs and targets are the same.\n", | ||
"# Wrap the generator for the trainer: batch_size=16, vocab_size=32, length=10.\n", | ||
"# The lambda ignores its single argument (presumably a device count -\n", | ||
"# TODO confirm against trax.supervised.Inputs).\n", | ||
"copy_inputs = trax.supervised.Inputs(lambda _: copy_task(16, 32, 10))\n", | ||
"\n", | ||
"# Peek into the inputs: pull one batch and print its first row.\n", | ||
"# NOTE(review): the 1 passed to train_stream is presumably the same\n", | ||
"# device-count argument the lambda above ignores - confirm.\n", | ||
"data_stream = copy_inputs.train_stream(1)\n", | ||
"inputs, targets, mask = next(data_stream)\n", | ||
"print(\"Inputs[0]: %s\" % str(inputs[0]))\n", | ||
"print(\"Targets[0]: %s\" % str(targets[0]))\n", | ||
"print(\"Mask[0]: %s\" % str(mask[0]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"colab": { | ||
"height": 629 | ||
}, | ||
"colab_type": "code", | ||
"executionInfo": { | ||
"elapsed": 28950, | ||
"status": "ok", | ||
"timestamp": 1578963053368, | ||
"user": { | ||
"displayName": "Lukasz Kaiser", | ||
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64", | ||
"userId": "13267693649565518272" | ||
}, | ||
"user_tz": 480 | ||
}, | ||
"id": "kSauPt0NUl_o", | ||
"outputId": "e5cf66a1-5deb-43f3-fb0a-a5fa479c7cbf" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"Step 500: Ran 500 train steps in 16.51 secs\n", | ||
"Step 500: Evaluation\n", | ||
"Step 500: train accuracy | 0.53125000\n", | ||
"Step 500: train loss | 1.83887446\n", | ||
"Step 500: train neg_log_perplexity | -1.83887446\n", | ||
"Step 500: train weights_per_batch_per_core | 80.00000000\n", | ||
"Step 500: eval accuracy | 0.52500004\n", | ||
"Step 500: eval loss | 1.92791247\n", | ||
"Step 500: eval neg_log_perplexity | -1.92791247\n", | ||
"Step 500: eval weights_per_batch_per_core | 80.00000000\n", | ||
"Step 500: Finished evaluation\n", | ||
"\n", | ||
"Step 1000: Ran 500 train steps in 2.54 secs\n", | ||
"Step 1000: Evaluation\n", | ||
"Step 1000: train accuracy | 1.00000000\n", | ||
"Step 1000: train loss | 0.00707983\n", | ||
"Step 1000: train neg_log_perplexity | -0.00707983\n", | ||
"Step 1000: train weights_per_batch_per_core | 80.00000000\n", | ||
"Step 1000: eval accuracy | 1.00000000\n", | ||
"Step 1000: eval loss | 0.01029818\n", | ||
"Step 1000: eval neg_log_perplexity | -0.01029818\n", | ||
"Step 1000: eval weights_per_batch_per_core | 80.00000000\n", | ||
"Step 1000: Finished evaluation\n", | ||
"\n", | ||
"Step 1500: Ran 500 train steps in 2.46 secs\n", | ||
"Step 1500: Evaluation\n", | ||
"Step 1500: train accuracy | 1.00000000\n", | ||
"Step 1500: train loss | 0.00037777\n", | ||
"Step 1500: train neg_log_perplexity | -0.00037777\n", | ||
"Step 1500: train weights_per_batch_per_core | 80.00000000\n", | ||
"Step 1500: eval accuracy | 1.00000000\n", | ||
"Step 1500: eval loss | 0.00037660\n", | ||
"Step 1500: eval neg_log_perplexity | -0.00037660\n", | ||
"Step 1500: eval weights_per_batch_per_core | 80.00000000\n", | ||
"Step 1500: Finished evaluation\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Transformer LM\n", | ||
"def tiny_transformer_lm(mode):\n", | ||
"  \"\"\"Builds the tiny 2-layer TransformerLM used throughout this notebook.\n", | ||
"\n", | ||
"  Args:\n", | ||
"    mode: forwarded unchanged to the model constructor; this notebook\n", | ||
"      passes 'predict' when building the inference model.\n", | ||
"\n", | ||
"  Returns:\n", | ||
"    The constructed trax.models.TransformerLM.\n", | ||
"  \"\"\"\n", | ||
"  # vocab_size matches the 32 passed to copy_task when building the inputs.\n", | ||
"  hparams = dict(d_model=32, d_ff=128, n_layers=2, vocab_size=32)\n", | ||
"  # You can try trax.models.ReformerLM too.\n", | ||
"  return trax.models.TransformerLM(mode=mode, **hparams)\n", | ||
"\n", | ||
"# Train tiny model with Trainer.\n", | ||
"output_dir = os.path.expanduser('~/train_dir/')\n", | ||
"# Remove old model. The path is derived from output_dir so the cleanup\n", | ||
"# always targets the directory that is handed to the Trainer below.\n", | ||
"old_model_file = os.path.join(output_dir, 'model.pkl')\n", | ||
"if os.path.exists(old_model_file):\n", | ||
"  os.remove(old_model_file)\n", | ||
"trainer = trax.supervised.Trainer(\n", | ||
"    model=tiny_transformer_lm,\n", | ||
"    loss_fn=trax.layers.CrossEntropyLossScalar,\n", | ||
"    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.\n", | ||
"    lr_schedule=trax.lr.MultifactorSchedule,  # Change lr schedule here.\n", | ||
"    inputs=copy_inputs,\n", | ||
"    output_dir=output_dir,\n", | ||
"    has_weights=True)  # Because we have loss mask, this API may change.\n", | ||
"\n", | ||
"# Train for 3 epochs each consisting of 500 train batches, eval on 2 batches.\n", | ||
"n_epochs = 3\n", | ||
"train_steps = 500\n", | ||
"eval_steps = 2\n", | ||
"for _ in range(n_epochs):\n", | ||
"  trainer.train_epoch(train_steps, eval_steps)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": { | ||
"colab": { | ||
"height": 34 | ||
}, | ||
"colab_type": "code", | ||
"executionInfo": { | ||
"elapsed": 1190, | ||
"status": "ok", | ||
"timestamp": 1578963686769, | ||
"user": { | ||
"displayName": "Lukasz Kaiser", | ||
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64", | ||
"userId": "13267693649565518272" | ||
}, | ||
"user_tz": 480 | ||
}, | ||
"id": "cqjYoxPEu8PG", | ||
"outputId": "d5a99b97-843b-4761-a711-ec337098e30e" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Initialize model for inference.\n", | ||
"predict_model = tiny_transformer_lm(mode='predict')\n", | ||
"# Shape (1, 1) matches cur_input below, which holds one token id per call.\n", | ||
"predict_signature = trax.shapes.ShapeDtype((1,1), dtype=np.int32)\n", | ||
"predict_model.init(predict_signature)\n", | ||
"predict_model.init_from_file(os.path.join(output_dir, \"model.pkl\"),\n", | ||
"                             weights_only=True)\n", | ||
"# You can also do: predict_model.weights = trainer.model_weights\n", | ||
"\n", | ||
"# Run inference\n", | ||
"prefix = [0, 1, 2, 3, 4, 0]  # Change non-0 digits to see if it's copying\n", | ||
"cur_input = np.array([[0]])\n", | ||
"result = []\n", | ||
"for i in range(10):\n", | ||
"  logits = predict_model(cur_input)\n", | ||
"  next_input = np.argmax(logits[0, 0, :], axis=-1)\n", | ||
"  # Force the known prefix tokens (all but the last one); after that the\n", | ||
"  # model's own greedy (argmax) prediction is fed back in.\n", | ||
"  if i \u003c len(prefix) - 1:\n", | ||
"    next_input = prefix[i]\n", | ||
"  cur_input = np.array([[next_input]])\n", | ||
"  result.append(int(next_input))  # Append to the result\n", | ||
"print(result)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"collapsed_sections": [], | ||
"last_runtime": { | ||
"build_target": "//learning/deepmind/dm_python:dm_notebook3", | ||
"kind": "private" | ||
}, | ||
"name": "Trax Quick Intro", | ||
"provenance": [ | ||
{ | ||
"file_id": "1v1GvTkEFjMH_1c-bdS7JzNS70u9RUEHV", | ||
"timestamp": 1578964243645 | ||
}, | ||
{ | ||
"file_id": "1SplqILjJr_ZqXcIUkNIk0tSbthfhYm07", | ||
"timestamp": 1572044421118 | ||
}, | ||
{ | ||
"file_id": "intro.ipynb", | ||
"timestamp": 1571858674399 | ||
}, | ||
{ | ||
"file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt", | ||
"timestamp": 1569980697572 | ||
}, | ||
{ | ||
"file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5", | ||
"timestamp": 1563927451951 | ||
} | ||
] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |