From 8356d9a6e15d7b50adf874d4dff9a708190af139 Mon Sep 17 00:00:00 2001 From: Carin Meier Date: Sun, 5 May 2019 21:34:01 -0400 Subject: [PATCH] [Clojure] Add Fine Tuning Sentence Pair Classification BERT Example (#14769) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Rework Bert examples to include QA infer and finetuning * update notebook example and exported markdown * add integration test for the classification * fix tests * add RAT * add another RAT * fix all the typos * Clojure BERT finetuning example: fix CSV parsing * update readme and gitignore * add fix from @davliepmann’s notebook parsing * feedback from @daveliepmann * fix running of example and don’t show very first batch on callback speedometer * rerun the notebook and save results * remove bert stuff from main .gitignore * re-putting the license back on after regen * fix integration test --- .gitignore | 2 +- contrib/clojure-package/.gitignore | 2 + .../examples/{bert-qa => bert}/.gitignore | 8 +- .../examples/{bert-qa => bert}/README.md | 80 ++- .../examples/bert/fine-tune-bert.ipynb | 510 ++++++++++++++++++ .../examples/bert/fine-tune-bert.md | 371 +++++++++++++ .../{bert-qa => bert}/get_bert_data.sh | 5 +- .../examples/{bert-qa => bert}/project.clj | 16 +- .../{bert-qa => bert}/squad-samples.edn | 0 .../src/bert/bert_sentence_classification.clj | 160 ++++++ .../src/bert_qa => bert/src/bert}/infer.clj | 67 +-- .../examples/bert/src/bert/util.clj | 52 ++ .../bert_sentence_classification_test.clj | 86 +++ .../bert_qa => bert/test/bert}/infer_test.clj | 9 +- .../src/org/apache/clojure_mxnet/callback.clj | 9 +- 15 files changed, 1310 insertions(+), 67 deletions(-) rename contrib/clojure-package/examples/{bert-qa => bert}/.gitignore (69%) rename contrib/clojure-package/examples/{bert-qa => bert}/README.md (50%) create mode 100644 contrib/clojure-package/examples/bert/fine-tune-bert.ipynb create mode 100644 contrib/clojure-package/examples/bert/fine-tune-bert.md rename contrib/clojure-package/examples/{bert-qa => bert}/get_bert_data.sh (73%) rename contrib/clojure-package/examples/{bert-qa => bert}/project.clj (63%) rename contrib/clojure-package/examples/{bert-qa => bert}/squad-samples.edn (100%) create mode 100644 contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj rename contrib/clojure-package/examples/{bert-qa/src/bert_qa => bert/src/bert}/infer.clj (73%) create mode 100644 contrib/clojure-package/examples/bert/src/bert/util.clj create mode 100644 contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj rename contrib/clojure-package/examples/{bert-qa/test/bert_qa => bert/test/bert}/infer_test.clj (91%) diff --git a/.gitignore b/.gitignore index 705ef92da0e8..59ca0d434e57 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,4 @@ tests/mxnet_unit_tests coverage.xml # Local CMake build config -cmake_options.yml +cmake_options.yml \ No newline at end of file diff --git a/contrib/clojure-package/.gitignore b/contrib/clojure-package/.gitignore index 8efd090255a5..884834e3843e 100644 --- a/contrib/clojure-package/.gitignore +++ b/contrib/clojure-package/.gitignore @@ -40,9 +40,11 @@ src/.DS_Store src/org/.DS_Store test/test-ndarray.clj test/test-ndarray-random.clj +test/test-ndarray-random-api.clj test/test-ndarray-api.clj test/test-symbol.clj test/test-symbol-random.clj +test/test-symbol-random-api.clj test/test-symbol-api.clj src/org/apache/clojure_mxnet/gen/* diff --git a/contrib/clojure-package/examples/bert-qa/.gitignore b/contrib/clojure-package/examples/bert/.gitignore similarity index 69% rename from contrib/clojure-package/examples/bert-qa/.gitignore rename to contrib/clojure-package/examples/bert/.gitignore index d18f225992a9..70c55267e7ab 100644 --- a/contrib/clojure-package/examples/bert-qa/.gitignore +++ b/contrib/clojure-package/examples/bert/.gitignore @@ -1,7 +1,6 @@ /target /classes /checkouts -profiles.clj pom.xml pom.xml.asc *.jar @@ -10,3 +9,10 @@ pom.xml.asc /.nrepl-port .hgignore .hg/ +data/* +model/* +*~ +*.params +*.states +*.json + diff --git a/contrib/clojure-package/examples/bert-qa/README.md b/contrib/clojure-package/examples/bert/README.md similarity index 50% rename from contrib/clojure-package/examples/bert-qa/README.md rename to contrib/clojure-package/examples/bert/README.md index 55f13e671c00..7285773e9706 100644 --- a/contrib/clojure-package/examples/bert-qa/README.md +++ b/contrib/clojure-package/examples/bert/README.md @@ -16,7 +16,11 @@ -# bert-qa +# BERT + +There are two examples showcasing the power of BERT. One is BERT-QA for inference and the other is BERT Sentence Pair Classification which uses fine tuning of the BERT base model. For more information about BERT please read [http://jalammar.github.io/illustrated-bert/](http://jalammar.github.io/illustrated-bert/). + +## bert-qa **This example was based off of the Java API one. It shows how to do inference with a pre-trained BERT network that is trained on Questions and Answers using the [SQuAD Dataset](https://rajpurkar.github.io/SQuAD-explorer/)** @@ -33,14 +37,16 @@ Example: :input-question "Along with geothermal and nuclear, what is a notable non-combustion heat source?" :ground-truth-answers ["solar" "solar power" - "solar power, nuclear power or geothermal energysolar"]} + "solar power, nuclear power or geothermal energy solar"]} ``` The prediction in this case would be `solar power` -## Setup Guide +### Setup Guide -### Step 1: Download the model +Note: If you have trouble with your REPL and cider, please comment out the `lein-jupyter` plugin. There are some conflicts with cider. + +#### Step 1: Download the model For this tutorial, you can get the model and vocabulary by running following bash file. This script will use `wget` to download these artifacts from AWS S3. @@ -53,14 +59,14 @@ From the example directory: Some sample questions and answers are provide in the `squad-sample.edn` file. Some are taken directly from the SQuAD dataset and one was just made up. Feel free to edit the file and add your own! -## To run +### To run * `lein install` in the root of the main project directory * cd into this project directory and do `lein run`. This will execute the cpu version. * `lein run` or `lein run :cpu` to run with cpu * `lein run :gpu` to run with gpu -## Background +### Background To learn more about how BERT works in MXNet, please follow this [MXNet Gluon tutorial on NLP using BERT](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340). @@ -73,7 +79,7 @@ The original description can be found in the [MXNet GluonNLP model zoo](https:// python static_finetune_squad.py --optimizer adam --accumulate 2 --batch_size 6 --lr 3e-5 --epochs 2 --gpu 0 --export ``` -This script will generate `json` and `param` fles that are the standard MXNet model files. +This script will generate `json` and `param` files that are the standard MXNet model files. By default, this model are using `bert_12_768_12` model with extra layers for QA jobs. After that, to be able to use it in Java, we need to export the dictionary from the script to parse the text @@ -88,3 +94,63 @@ f.close() This would export the token vocabulary in json format. Once you have these three files, you will be able to run this example without problems. +## Fine-tuning Sentence Pair Classification with BERT + +This was based off of the great tutorial for in Gluon-NLP [https://gluon-nlp.mxnet.io/examples/sentence_embedding/bert.html](https://gluon-nlp.mxnet.io/examples/sentence_embedding/bert.html). + +We use the pre-trained BERT model that was exported from GluonNLP via the `scripts/bert/staticbert/static_export_base.py` running `python static_export_base.py --seq_length 128`. For convenience, the model has been downloaded for you by running the get_bert_data.sh file in the root directory of this example. + +It will fine tune the base bert model for use in a classification task for 3 epochs. + + +### Setup Guide + +#### Step 1: Download the model + +For this tutorial, you can get the model and vocabulary by running following bash file. This script will use `wget` to download these artifacts from AWS S3. + +From the example directory: + +```bash +./get_bert_data.sh +``` + +### To run the notebook walkthrough + +There is a Jupyter notebook that uses the `lein jupyter` plugin to be able to execute Clojure code in project setting. The first time that you run it you will need to install the kernel with `lein jupyter install-kernel`. After that you can open the notebook in the project directory with `lein jupyter notebook`. + +There is also an exported copy of the walkthrough to markdown `fine-tune-bert.md`. + + +### To run + +* `lein install` in the root of the main project directory +* cd into this project directory and do `lein run`. This will execute the cpu version. + +`lein run -m bert.bert-sentence-classification :cpu` - to run with cpu +`lein run -m bert.bert-sentence-classification :gpu` - to run with gpu + +By default it will run 3 epochs, you can control the number of epochs with: + +`lein run -m bert.bert-sentence-classification :cpu 1` to run just 1 epoch + + +Sample results from cpu run on OSX +``` +INFO org.apache.mxnet.module.BaseModule: Epoch[1] Train-accuracy=0.65384614 +INFO org.apache.mxnet.module.BaseModule: Epoch[1] Time cost=464187 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [1] Speed: 0.91 samples/sec Train-accuracy=0.656250 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [2] Speed: 0.90 samples/sec Train-accuracy=0.656250 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [3] Speed: 0.91 samples/sec Train-accuracy=0.687500 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [4] Speed: 0.90 samples/sec Train-accuracy=0.693750 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [5] Speed: 0.91 samples/sec Train-accuracy=0.703125 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [6] Speed: 0.92 samples/sec Train-accuracy=0.696429 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [7] Speed: 0.91 samples/sec Train-accuracy=0.699219 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [8] Speed: 0.90 samples/sec Train-accuracy=0.701389 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [9] Speed: 0.90 samples/sec Train-accuracy=0.690625 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [10] Speed: 0.89 samples/sec Train-accuracy=0.690341 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [11] Speed: 0.90 samples/sec Train-accuracy=0.695313 +INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [12] Speed: 0.91 samples/sec Train-accuracy=0.701923 +INFO org.apache.mxnet.module.BaseModule: Epoch[2] Train-accuracy=0.7019231 +INFO org.apache.mxnet.module.BaseModule: Epoch[2] Time cost=459809 +```` diff --git a/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb b/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb new file mode 100644 index 000000000000..425a9993ad93 --- /dev/null +++ b/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb @@ -0,0 +1,510 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine-tuning Sentence Pair Classification with BERT\n", + "\n", + "**This tutorial is based off of the Gluon NLP one here https://gluon-nlp.mxnet.io/examples/sentence_embedding/bert.html**\n", + "\n", + "Pre-trained language representations have been shown to improve many downstream NLP tasks such as question answering, and natural language inference. To apply pre-trained representations to these tasks, there are two strategies:\n", + "\n", + "feature-based approach, which uses the pre-trained representations as additional features to the downstream task.\n", + "fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters.\n", + "While feature-based approaches such as ELMo [3] (introduced in the previous tutorial) are effective in improving many downstream tasks, they require task-specific architectures. Devlin, Jacob, et al proposed BERT [1] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters, and obtained state- of-the-art results.\n", + "\n", + "In this tutorial, we will focus on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs. Specifically, we will:\n", + "\n", + "load the state-of-the-art pre-trained BERT model and attach an additional layer for classification,\n", + "process and transform sentence pair data for the task at hand, and\n", + "fine-tune BERT model for sentence classification.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preparation\n", + "\n", + "To run this tutorial locally, in the example directory:\n", + "\n", + "1. Get the model and supporting data by running `get_bert_data.sh`. \n", + "2. This Jupyter Notebook uses the lein-jupyter plugin to be able to execute Clojure code in project setting. The first time that you run it you will need to install the kernel with`lein jupyter install-kernel`. After that you can open the notebook in the project directory with `lein jupyter notebook`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load requirements\n", + "\n", + "We need to load up all the namespace requires" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "(ns bert.bert-sentence-classification\n", + " (:require [bert.util :as bert-util]\n", + " [clojure-csv.core :as csv]\n", + " [clojure.java.shell :refer [sh]]\n", + " [clojure.string :as string]\n", + " [org.apache.clojure-mxnet.callback :as callback]\n", + " [org.apache.clojure-mxnet.context :as context]\n", + " [org.apache.clojure-mxnet.dtype :as dtype]\n", + " [org.apache.clojure-mxnet.eval-metric :as eval-metric]\n", + " [org.apache.clojure-mxnet.io :as mx-io]\n", + " [org.apache.clojure-mxnet.layout :as layout]\n", + " [org.apache.clojure-mxnet.module :as m]\n", + " [org.apache.clojure-mxnet.ndarray :as ndarray]\n", + " [org.apache.clojure-mxnet.optimizer :as optimizer]\n", + " [org.apache.clojure-mxnet.symbol :as sym]))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "# Use the Pre-trained BERT Model\n", + "\n", + "In this tutorial we will use the pre-trained BERT model that was exported from GluonNLP via the `scripts/bert/staticbert/static_export_base.py`. For convenience, the model has been downloaded for you by running the `get_bert_data.sh` file in the root directory of this example." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get BERT\n", + "\n", + "Let’s first take a look at the BERT model architecture for sentence pair classification below:\n", + "\n", + "![bert](https://gluon-nlp.mxnet.io/_images/bert-sentence-pair.png)\n", + "\n", + "where the model takes a pair of sequences and pools the representation of the first token in the sequence. Note that the original BERT model was trained for masked language model and next sentence prediction tasks, which includes layers for language model decoding and classification. These layers will not be used for fine-tuning sentence pair classification.\n", + "\n", + "Let's load the pre-trained BERT using the module API in MXNet." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "#'bert.bert-sentence-classification/bert-base" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(def model-path-prefix \"data/static_bert_base_net\")\n", + ";; the vocabulary used in the model\n", + "(def vocab (bert-util/get-vocab))\n", + ";; the input question\n", + ";; the maximum length of the sequence\n", + "(def seq-length 128)\n", + "\n", + "(def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Definition for Sentence Pair Classification\n", + "\n", + "Now that we have loaded the BERT model, we only need to attach an additional layer for classification. We can do this by defining a fine tune model from the symbol of the base BERT model." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "#'bert.bert-sentence-classification/model-sym" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(defn fine-tune-model\n", + " \"msymbol: the pretrained network symbol\n", + " num-classes: the number of classes for the fine-tune datasets\n", + " dropout: the dropout rate\"\n", + " [msymbol {:keys [num-classes dropout]}]\n", + " (as-> msymbol data\n", + " (sym/dropout {:data data :p dropout})\n", + " (sym/fully-connected \"fc-finetune\" {:data data :num-hidden num-classes})\n", + " (sym/softmax-output \"softmax\" {:data data})))\n", + "\n", + "(def model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Preprocessing for BERT\n", + "\n", + "## Dataset\n", + "\n", + "For demonstration purpose, we use the dev set of the Microsoft Research Paraphrase Corpus dataset. The file is named ‘dev.tsv’ and was downloaded as part of the data script. Let’s take a look at the raw dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n", + "1\t1355540\t1355592\tHe said the foodservice pie business doesn 't fit the company 's long-term growth strategy .\t\" The foodservice pie business does not fit our long-term growth strategy .\n", + "0\t2029631\t2029565\tMagnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war .\tHis wife said he was \" 100 percent behind George Bush \" and looked forward to using his years of training in the war .\n", + "0\t487993\t487952\tThe dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat .\tThe dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent .\n", + "1\t1989515\t1989458\tThe AFL-CIO is waiting until October to decide if it will endorse a candidate .\tThe AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries .\n", + "\n" + ] + } + ], + "source": [ + "(-> (sh \"head\" \"-n\" \"5\" \"data/dev.tsv\") \n", + " :out\n", + " println)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The file contains 5 columns, separated by tabs (i.e. ‘\n", + "\n", + "\\t\n", + "‘). The first line of the file explains each of these columns: 0. the label indicating whether the two sentences are semantically equivalent 1. the id of the first sentence in this sample 2. the id of the second sentence in this sample 3. the content of the first sentence 4. the content of the second sentence\n", + "\n", + "For our task, we are interested in the 0th, 3rd and 4th columns. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy .\n", + " The foodservice pie business does not fit our long-term growth strategy .\n", + "1\n" + ] + } + ], + "source": [ + "(def raw-file \n", + " (csv/parse-csv (string/replace (slurp \"data/dev.tsv\") \"\\\"\" \"\")\n", + " :delimiter \\tab\n", + " :strict true))\n", + "\n", + "(def data-train-raw (->> raw-file\n", + " (mapv #(vals (select-keys % [3 4 0])))\n", + " (rest) ; drop header\n", + " (into [])))\n", + "\n", + "(def sample (first data-train-raw))\n", + "(println (nth sample 0)) ;;;sentence a\n", + "(println (nth sample 1)) ;; sentence b\n", + "(println (nth sample 2)) ;; 1 means equivalent, 0 means not equivalent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use the pre-trained BERT model, we need to preprocess the data in the same way it was trained. The following figure shows the input representation in BERT:\n", + "\n", + "![bert-input](https://gluon-nlp.mxnet.io/_images/bert-embed.png)\n", + "\n", + "We will do pre-processing on the inputs to get them in the right format and to perform the following transformations:\n", + "- tokenize the input sequences\n", + "- insert [CLS] at the beginning\n", + "- insert [SEP] between sentence one and sentence two, and at the end - generate segment ids to indicate whether a token belongs to the first sequence or the second sequence.\n", + "- generate valid length" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Count is = 408\n", + "[PAD] token id = 1\n", + "[CLS] token id = 2\n", + "[SEP] token id = 3\n", + "token ids = \n", + " [2 2002 2056 1996 0 11345 2449 2987 0 4906 1996 2194 0 0 3930 5656 0 1012 3 0 1996 0 11345 2449 2515 2025 4906 2256 0 3930 5656 0 1012 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n", + "segment ids = \n", + " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", + "valid length = \n", + " [31]\n", + "label = \n", + " [0]\n" + ] + } + ], + "source": [ + "(defn pre-processing\n", + " \"Preprocesses the sentences in the format that BERT is expecting\"\n", + " [ctx idx->token token->idx train-item]\n", + " (let [[sentence-a sentence-b label] train-item\n", + " ;;; pre-processing tokenize sentence\n", + " token-1 (bert-util/tokenize (string/lower-case sentence-a))\n", + " token-2 (bert-util/tokenize (string/lower-case sentence-b))\n", + " valid-length (+ (count token-1) (count token-2))\n", + " ;;; generate token types [0000...1111...0000]\n", + " qa-embedded (into (bert-util/pad [] 0 (count token-1))\n", + " (bert-util/pad [] 1 (count token-2)))\n", + " token-types (bert-util/pad qa-embedded 0 seq-length)\n", + " ;;; make BERT pre-processing standard\n", + " token-2 (conj token-2 \"[SEP]\")\n", + " token-1 (into [] (concat [\"[CLS]\"] token-1 [\"[SEP]\"] token-2))\n", + " tokens (bert-util/pad token-1 \"[PAD]\" seq-length)\n", + " ;;; pre-processing - token to index translation\n", + " indexes (bert-util/tokens->idxs token->idx tokens)]\n", + " {:input-batch [indexes\n", + " token-types\n", + " [valid-length]]\n", + " :label (if (= \"0\" label)\n", + " [0]\n", + " [1])\n", + " :tokens tokens\n", + " :train-item train-item}))\n", + "\n", + "(def idx->token (:idx->token vocab))\n", + "(def token->idx (:token->idx vocab))\n", + "(def dev (context/default-context))\n", + "(def processed-datas (mapv #(pre-processing dev idx->token token->idx %) data-train-raw))\n", + "(def train-count (count processed-datas))\n", + "(println \"Train Count is = \" train-count)\n", + "(println \"[PAD] token id = \" (get token->idx \"[PAD]\"))\n", + "(println \"[CLS] token id = \" (get token->idx \"[CLS]\"))\n", + "(println \"[SEP] token id = \" (get token->idx \"[SEP]\"))\n", + "(println \"token ids = \\n\"(-> (first processed-datas) :input-batch first)) \n", + "(println \"segment ids = \\n\"(-> (first processed-datas) :input-batch second)) \n", + "(println \"valid length = \\n\" (-> (first processed-datas) :input-batch last)) \n", + "(println \"label = \\n\" (-> (second processed-datas) :label)) \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have all the input-batches for each row, we are going to slice them up column-wise and create NDArray Iterators that we can use in training" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "#object[org.apache.mxnet.io.NDArrayIter 0x2583097d \"non-empty iterator\"]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(defn slice-inputs-data\n", + " \"Each sentence pair had to be processed as a row. This breaks all\n", + " the rows up into a column for creating a NDArray\"\n", + " [processed-datas n]\n", + " (->> processed-datas\n", + " (mapv #(nth (:input-batch %) n))\n", + " (flatten)\n", + " (into [])))\n", + "\n", + "(def prepared-data {:data0s (slice-inputs-data processed-datas 0)\n", + " :data1s (slice-inputs-data processed-datas 1)\n", + " :data2s (slice-inputs-data processed-datas 2)\n", + " :labels (->> (mapv :label processed-datas)\n", + " (flatten)\n", + " (into []))\n", + " :train-num (count processed-datas)})\n", + "\n", + "(def batch-size 32)\n", + "\n", + "(def train-data\n", + " (let [{:keys [data0s data1s data2s labels train-num]} prepared-data\n", + " data-desc0 (mx-io/data-desc {:name \"data0\"\n", + " :shape [train-num seq-length]\n", + " :dtype dtype/FLOAT32\n", + " :layout layout/NT})\n", + " data-desc1 (mx-io/data-desc {:name \"data1\"\n", + " :shape [train-num seq-length]\n", + " :dtype dtype/FLOAT32\n", + " :layout layout/NT})\n", + " data-desc2 (mx-io/data-desc {:name \"data2\"\n", + " :shape [train-num]\n", + " :dtype dtype/FLOAT32\n", + " :layout layout/N})\n", + " label-desc (mx-io/data-desc {:name \"softmax_label\"\n", + " :shape [train-num]\n", + " :dtype dtype/FLOAT32\n", + " :layout layout/N})]\n", + " (mx-io/ndarray-iter {data-desc0 (ndarray/array data0s [train-num seq-length]\n", + " {:ctx dev})\n", + " data-desc1 (ndarray/array data1s [train-num seq-length]\n", + " {:ctx dev})\n", + " data-desc2 (ndarray/array data2s [train-num]\n", + " {:ctx dev})}\n", + " {:label {label-desc (ndarray/array labels [train-num]\n", + " {:ctx dev})}\n", + " :data-batch-size batch-size})))\n", + "train-data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine-tune BERT Model\n", + "\n", + "Putting everything together, now we can fine-tune the model with a few epochs. For demonstration, we use a fixed learning rate and skip validation steps." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Speedometer: epoch 0 count 1 metric [accuracy 0.609375]\n", + "Speedometer: epoch 0 count 2 metric [accuracy 0.6041667]\n", + "Speedometer: epoch 0 count 3 metric [accuracy 0.5703125]\n", + "Speedometer: epoch 0 count 4 metric [accuracy 0.55625]\n", + "Speedometer: epoch 0 count 5 metric [accuracy 0.5625]\n", + "Speedometer: epoch 0 count 6 metric [accuracy 0.55803573]\n", + "Speedometer: epoch 0 count 7 metric [accuracy 0.5625]\n", + "Speedometer: epoch 0 count 8 metric [accuracy 0.5798611]\n", + "Speedometer: epoch 0 count 9 metric [accuracy 0.584375]\n", + "Speedometer: epoch 0 count 10 metric [accuracy 0.57670456]\n", + "Speedometer: epoch 0 count 11 metric [accuracy 0.5807292]\n", + "Speedometer: epoch 0 count 12 metric [accuracy 0.5793269]\n", + "Speedometer: epoch 1 count 1 metric [accuracy 0.5625]\n", + "Speedometer: epoch 1 count 2 metric [accuracy 0.5520833]\n", + "Speedometer: epoch 1 count 3 metric [accuracy 0.5859375]\n", + "Speedometer: epoch 1 count 4 metric [accuracy 0.59375]\n", + "Speedometer: epoch 1 count 5 metric [accuracy 0.6145833]\n", + "Speedometer: epoch 1 count 6 metric [accuracy 0.625]\n", + "Speedometer: epoch 1 count 7 metric [accuracy 0.640625]\n", + "Speedometer: epoch 1 count 8 metric [accuracy 0.6527778]\n", + "Speedometer: epoch 1 count 9 metric [accuracy 0.653125]\n", + "Speedometer: epoch 1 count 10 metric [accuracy 0.6448864]\n", + "Speedometer: epoch 1 count 11 metric [accuracy 0.640625]\n", + "Speedometer: epoch 1 count 12 metric [accuracy 0.6418269]\n", + "Speedometer: epoch 2 count 1 metric [accuracy 0.671875]\n", + "Speedometer: epoch 2 count 2 metric [accuracy 0.7083333]\n", + "Speedometer: epoch 2 count 3 metric [accuracy 0.7109375]\n", + "Speedometer: epoch 2 count 4 metric [accuracy 0.725]\n", + "Speedometer: epoch 2 count 5 metric [accuracy 0.7239583]\n", + "Speedometer: epoch 2 count 6 metric [accuracy 0.71875]\n", + "Speedometer: epoch 2 count 7 metric [accuracy 0.734375]\n", + "Speedometer: epoch 2 count 8 metric [accuracy 0.7361111]\n", + "Speedometer: epoch 2 count 9 metric [accuracy 0.721875]\n", + "Speedometer: epoch 2 count 10 metric [accuracy 0.71022725]\n", + "Speedometer: epoch 2 count 11 metric [accuracy 0.6979167]\n", + "Speedometer: epoch 2 count 12 metric [accuracy 0.7019231]\n" + ] + }, + { + "data": { + "text/plain": [ + "#object[org.apache.mxnet.module.Module 0x73c42ae5 \"org.apache.mxnet.module.Module@73c42ae5\"]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(def num-epoch 3)\n", + "\n", + "(def fine-tune-model (m/module model-sym {:contexts [dev]\n", + " :data-names [\"data0\" \"data1\" \"data2\"]}))\n", + "\n", + "(m/fit fine-tune-model {:train-data train-data :num-epoch num-epoch\n", + " :fit-params (m/fit-params {:allow-missing true\n", + " :arg-params (m/arg-params bert-base)\n", + " :aux-params (m/aux-params bert-base)\n", + " :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})\n", + " :batch-end-callback (callback/speedometer batch-size 1)})})\n" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Lein-Clojure", + "language": "clojure", + "name": "lein-clojure" + }, + "language_info": { + "file_extension": ".clj", + "mimetype": "text/x-clojure", + "name": "clojure", + "version": "1.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/contrib/clojure-package/examples/bert/fine-tune-bert.md b/contrib/clojure-package/examples/bert/fine-tune-bert.md new file mode 100644 index 000000000000..4e6681e7aade --- /dev/null +++ b/contrib/clojure-package/examples/bert/fine-tune-bert.md @@ -0,0 +1,371 @@ + + + + + + + + + + + + + + + + + + +# Fine-tuning Sentence Pair Classification with BERT + +**This tutorial is based off of the Gluon NLP one here https://gluon-nlp.mxnet.io/examples/sentence_embedding/bert.html** + +Pre-trained language representations have been shown to improve many downstream NLP tasks such as question answering, and natural language inference. To apply pre-trained representations to these tasks, there are two strategies: + +feature-based approach, which uses the pre-trained representations as additional features to the downstream task. +fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters. +While feature-based approaches such as ELMo [3] (introduced in the previous tutorial) are effective in improving many downstream tasks, they require task-specific architectures. Devlin, Jacob, et al proposed BERT [1] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters, and obtained state- of-the-art results. + +In this tutorial, we will focus on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs. Specifically, we will: + +load the state-of-the-art pre-trained BERT model and attach an additional layer for classification, +process and transform sentence pair data for the task at hand, and +fine-tune BERT model for sentence classification. + + + +## Preparation + +To run this tutorial locally, in the example directory: + +1. Get the model and supporting data by running `get_bert_data.sh`. +2. This Jupyter Notebook uses the lein-jupyter plugin to be able to execute Clojure code in project setting. The first time that you run it you will need to install the kernel with`lein jupyter install-kernel`. After that you can open the notebook in the project directory with `lein jupyter notebook`. + +## Load requirements + +We need to load up all the namespace requires + + +```clojure +(ns bert.bert-sentence-classification + (:require [bert.util :as bert-util] + [clojure-csv.core :as csv] + [clojure.java.shell :refer [sh]] + [clojure.string :as string] + [org.apache.clojure-mxnet.callback :as callback] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.layout :as layout] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym])) + +``` + +# Use the Pre-trained BERT Model + +In this tutorial we will use the pre-trained BERT model that was exported from GluonNLP via the `scripts/bert/staticbert/static_export_base.py`. For convenience, the model has been downloaded for you by running the `get_bert_data.sh` file in the root directory of this example. + +## Get BERT + +Let’s first take a look at the BERT model architecture for sentence pair classification below: + +![bert](https://gluon-nlp.mxnet.io/_images/bert-sentence-pair.png) + +where the model takes a pair of sequences and pools the representation of the first token in the sequence. Note that the original BERT model was trained for masked language model and next sentence prediction tasks, which includes layers for language model decoding and classification. These layers will not be used for fine-tuning sentence pair classification. + +Let's load the pre-trained BERT using the module API in MXNet. + + +```clojure +(def model-path-prefix "data/static_bert_base_net") +;; the vocabulary used in the model +(def vocab (bert-util/get-vocab)) +;; the input question +;; the maximum length of the sequence +(def seq-length 128) + +(def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0})) +``` + + + + + #'bert.bert-sentence-classification/bert-base + + + +## Model Definition for Sentence Pair Classification + +Now that we have loaded the BERT model, we only need to attach an additional layer for classification. We can do this by defining a fine tune model from the symbol of the base BERT model. + + +```clojure +(defn fine-tune-model + "msymbol: the pretrained network symbol + num-classes: the number of classes for the fine-tune datasets + dropout: the dropout rate" + [msymbol {:keys [num-classes dropout]}] + (as-> msymbol data + (sym/dropout {:data data :p dropout}) + (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes}) + (sym/softmax-output "softmax" {:data data}))) + +(def model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1})) +``` + + + + + #'bert.bert-sentence-classification/model-sym + + + +# Data Preprocessing for BERT + +## Dataset + +For demonstration purpose, we use the dev set of the Microsoft Research Paraphrase Corpus dataset. The file is named ‘dev.tsv’ and was downloaded as part of the data script. Let’s take a look at the raw dataset. + + +```clojure +(-> (sh "head" "-n" "5" "data/dev.tsv") + :out + println) +``` + + Quality #1 ID #2 ID #1 String #2 String + 1 1355540 1355592 He said the foodservice pie business doesn 't fit the company 's long-term growth strategy . " The foodservice pie business does not fit our long-term growth strategy . + 0 2029631 2029565 Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war . His wife said he was " 100 percent behind George Bush " and looked forward to using his years of training in the war . + 0 487993 487952 The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat . The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent . + 1 1989515 1989458 The AFL-CIO is waiting until October to decide if it will endorse a candidate . The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries . + + + +The file contains 5 columns, separated by tabs (i.e. ‘ + +\t +‘). The first line of the file explains each of these columns: 0. the label indicating whether the two sentences are semantically equivalent 1. the id of the first sentence in this sample 2. the id of the second sentence in this sample 3. the content of the first sentence 4. the content of the second sentence + +For our task, we are interested in the 0th, 3rd and 4th columns. + + +```clojure +(def raw-file + (csv/parse-csv (string/replace (slurp "data/dev.tsv") "\"" "") + :delimiter \tab + :strict true)) + +(def data-train-raw (->> raw-file + (mapv #(vals (select-keys % [3 4 0]))) + (rest) ; drop header + (into []))) + +(def sample (first data-train-raw)) +(println (nth sample 0)) ;;;sentence a +(println (nth sample 1)) ;; sentence b +(println (nth sample 2)) ;; 1 means equivalent, 0 means not equivalent +``` + + He said the foodservice pie business doesn 't fit the company 's long-term growth strategy . + The foodservice pie business does not fit our long-term growth strategy . + 1 + + +To use the pre-trained BERT model, we need to preprocess the data in the same way it was trained. The following figure shows the input representation in BERT: + +![bert-input](https://gluon-nlp.mxnet.io/_images/bert-embed.png) + +We will do pre-processing on the inputs to get them in the right format and to perform the following transformations: +- tokenize the input sequences +- insert [CLS] at the beginning +- insert [SEP] between sentence one and sentence two, and at the end - generate segment ids to indicate whether a token belongs to the first sequence or the second sequence. +- generate valid length + + +```clojure +(defn pre-processing + "Preprocesses the sentences in the format that BERT is expecting" + [ctx idx->token token->idx train-item] + (let [[sentence-a sentence-b label] train-item + ;;; pre-processing tokenize sentence + token-1 (bert-util/tokenize (string/lower-case sentence-a)) + token-2 (bert-util/tokenize (string/lower-case sentence-b)) + valid-length (+ (count token-1) (count token-2)) + ;;; generate token types [0000...1111...0000] + qa-embedded (into (bert-util/pad [] 0 (count token-1)) + (bert-util/pad [] 1 (count token-2))) + token-types (bert-util/pad qa-embedded 0 seq-length) + ;;; make BERT pre-processing standard + token-2 (conj token-2 "[SEP]") + token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2)) + tokens (bert-util/pad token-1 "[PAD]" seq-length) + ;;; pre-processing - token to index translation + indexes (bert-util/tokens->idxs token->idx tokens)] + {:input-batch [indexes + token-types + [valid-length]] + :label (if (= "0" label) + [0] + [1]) + :tokens tokens + :train-item train-item})) + +(def idx->token (:idx->token vocab)) +(def token->idx (:token->idx vocab)) +(def dev (context/default-context)) +(def processed-datas (mapv #(pre-processing dev idx->token token->idx %) data-train-raw)) +(def train-count (count processed-datas)) +(println "Train Count is = " train-count) +(println "[PAD] token id = " (get token->idx "[PAD]")) +(println "[CLS] token id = " (get token->idx "[CLS]")) +(println "[SEP] token id = " (get token->idx "[SEP]")) +(println "token ids = \n"(-> (first processed-datas) :input-batch first)) +(println "segment ids = \n"(-> (first processed-datas) :input-batch second)) +(println "valid length = \n" (-> (first processed-datas) :input-batch last)) +(println "label = \n" (-> (second processed-datas) :label)) + + +``` + + Train Count is = 408 + [PAD] token id = 1 + [CLS] token id = 2 + [SEP] token id = 3 + token ids = + [2 2002 2056 1996 0 11345 2449 2987 0 4906 1996 2194 0 0 3930 5656 0 1012 3 0 1996 0 11345 2449 2515 2025 4906 2256 0 3930 5656 0 1012 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] + segment ids = + [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] + valid length = + [31] + label = + [0] + + +Now that we have all the input-batches for each row, we are going to slice them up column-wise and create NDArray Iterators that we can use in training + + +```clojure +(defn slice-inputs-data + "Each sentence pair had to be processed as a row. This breaks all + the rows up into a column for creating a NDArray" + [processed-datas n] + (->> processed-datas + (mapv #(nth (:input-batch %) n)) + (flatten) + (into []))) + +(def prepared-data {:data0s (slice-inputs-data processed-datas 0) + :data1s (slice-inputs-data processed-datas 1) + :data2s (slice-inputs-data processed-datas 2) + :labels (->> (mapv :label processed-datas) + (flatten) + (into [])) + :train-num (count processed-datas)}) + +(def batch-size 32) + +(def train-data + (let [{:keys [data0s data1s data2s labels train-num]} prepared-data + data-desc0 (mx-io/data-desc {:name "data0" + :shape [train-num seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT}) + data-desc1 (mx-io/data-desc {:name "data1" + :shape [train-num seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT}) + data-desc2 (mx-io/data-desc {:name "data2" + :shape [train-num] + :dtype dtype/FLOAT32 + :layout layout/N}) + label-desc (mx-io/data-desc {:name "softmax_label" + :shape [train-num] + :dtype dtype/FLOAT32 + :layout layout/N})] + (mx-io/ndarray-iter {data-desc0 (ndarray/array data0s [train-num seq-length] + {:ctx dev}) + data-desc1 (ndarray/array data1s [train-num seq-length] + {:ctx dev}) + data-desc2 (ndarray/array data2s [train-num] + {:ctx dev})} + {:label {label-desc (ndarray/array labels [train-num] + {:ctx dev})} + :data-batch-size batch-size}))) +train-data +``` + + + + + #object[org.apache.mxnet.io.NDArrayIter 0x2583097d "non-empty iterator"] + + + +# Fine-tune BERT Model + +Putting everything together, now we can fine-tune the model with a few epochs. For demonstration, we use a fixed learning rate and skip validation steps. + + +```clojure +(def num-epoch 3) + +(def fine-tune-model (m/module model-sym {:contexts [dev] + :data-names ["data0" "data1" "data2"]})) + +(m/fit fine-tune-model {:train-data train-data :num-epoch num-epoch + :fit-params (m/fit-params {:allow-missing true + :arg-params (m/arg-params bert-base) + :aux-params (m/aux-params bert-base) + :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9}) + :batch-end-callback (callback/speedometer batch-size 1)})}) + +``` + + Speedometer: epoch 0 count 1 metric [accuracy 0.609375] + Speedometer: epoch 0 count 2 metric [accuracy 0.6041667] + Speedometer: epoch 0 count 3 metric [accuracy 0.5703125] + Speedometer: epoch 0 count 4 metric [accuracy 0.55625] + Speedometer: epoch 0 count 5 metric [accuracy 0.5625] + Speedometer: epoch 0 count 6 metric [accuracy 0.55803573] + Speedometer: epoch 0 count 7 metric [accuracy 0.5625] + Speedometer: epoch 0 count 8 metric [accuracy 0.5798611] + Speedometer: epoch 0 count 9 metric [accuracy 0.584375] + Speedometer: epoch 0 count 10 metric [accuracy 0.57670456] + Speedometer: epoch 0 count 11 metric [accuracy 0.5807292] + Speedometer: epoch 0 count 12 metric [accuracy 0.5793269] + Speedometer: epoch 1 count 1 metric [accuracy 0.5625] + Speedometer: epoch 1 count 2 metric [accuracy 0.5520833] + Speedometer: epoch 1 count 3 metric [accuracy 0.5859375] + Speedometer: epoch 1 count 4 metric [accuracy 0.59375] + Speedometer: epoch 1 count 5 metric [accuracy 0.6145833] + Speedometer: epoch 1 count 6 metric [accuracy 0.625] + Speedometer: epoch 1 count 7 metric [accuracy 0.640625] + Speedometer: epoch 1 count 8 metric [accuracy 0.6527778] + Speedometer: epoch 1 count 9 metric [accuracy 0.653125] + Speedometer: epoch 1 count 10 metric [accuracy 0.6448864] + Speedometer: epoch 1 count 11 metric [accuracy 0.640625] + Speedometer: epoch 1 count 12 metric [accuracy 0.6418269] + Speedometer: epoch 2 count 1 metric [accuracy 0.671875] + Speedometer: epoch 2 count 2 metric [accuracy 0.7083333] + Speedometer: epoch 2 count 3 metric [accuracy 0.7109375] + Speedometer: epoch 2 count 4 metric [accuracy 0.725] + Speedometer: epoch 2 count 5 metric [accuracy 0.7239583] + Speedometer: epoch 2 count 6 metric [accuracy 0.71875] + Speedometer: epoch 2 count 7 metric [accuracy 0.734375] + Speedometer: epoch 2 count 8 metric [accuracy 0.7361111] + Speedometer: epoch 2 count 9 metric [accuracy 0.721875] + Speedometer: epoch 2 count 10 metric [accuracy 0.71022725] + Speedometer: epoch 2 count 11 metric [accuracy 0.6979167] + Speedometer: epoch 2 count 12 metric [accuracy 0.7019231] + + + + + + #object[org.apache.mxnet.module.Module 0x73c42ae5 "org.apache.mxnet.module.Module@73c42ae5"] + + diff --git a/contrib/clojure-package/examples/bert-qa/get_bert_data.sh b/contrib/clojure-package/examples/bert/get_bert_data.sh similarity index 73% rename from contrib/clojure-package/examples/bert-qa/get_bert_data.sh rename to contrib/clojure-package/examples/bert/get_bert_data.sh index 603194a03c05..10ed8e9a1f8e 100755 --- a/contrib/clojure-package/examples/bert-qa/get_bert_data.sh +++ b/contrib/clojure-package/examples/bert/get_bert_data.sh @@ -19,11 +19,14 @@ set -e -data_path=model +data_path=data if [ ! -d "$data_path" ]; then mkdir -p "$data_path" curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json + curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_base_net-symbol.json -o $data_path/static_bert_base_net-symbol.json + curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_base_net-0000.params -o $data_path/static_bert_base_net-0000.params + curl https://raw.githubusercontent.com/dmlc/gluon-nlp/master/docs/examples/sentence_embedding/dev.tsv -o $data_path/dev.tsv fi diff --git a/contrib/clojure-package/examples/bert-qa/project.clj b/contrib/clojure-package/examples/bert/project.clj similarity index 63% rename from contrib/clojure-package/examples/bert-qa/project.clj rename to contrib/clojure-package/examples/bert/project.clj index d256d44d0798..05061b476241 100644 --- a/contrib/clojure-package/examples/bert-qa/project.clj +++ b/contrib/clojure-package/examples/bert/project.clj @@ -16,13 +16,17 @@ ;; -(defproject bert-qa "0.1.0-SNAPSHOT" - :description "BERT QA Example" - :plugins [[lein-cljfmt "0.5.7"]] +(defproject bert "0.1.0-SNAPSHOT" + :description "BERT Examples" + :plugins [[lein-cljfmt "0.5.7"] + ;;; lein-jupyter seems to have some incompatibilities with dependencies with cider + ;;; so if you run into trouble please delete the `lein-juptyter` plugin + [lein-jupyter "0.1.16" :exclusions [org.clojure/tools.nrepl org.clojure/clojure org.codehaus.plexus/plexus-utils org.clojure/tools.reader]]] :dependencies [[org.clojure/clojure "1.9.0"] [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"] - [cheshire "5.8.1"]] + [cheshire "5.8.1"] + [clojure-csv/clojure-csv "2.0.1"]] :pedantic? :skip :java-source-paths ["src/java"] - :main bert-qa.infer - :repl-options {:init-ns bert-qa.infer}) + :main bert.infer + :repl-options {:init-ns bert.infer}) diff --git a/contrib/clojure-package/examples/bert-qa/squad-samples.edn b/contrib/clojure-package/examples/bert/squad-samples.edn similarity index 100% rename from contrib/clojure-package/examples/bert-qa/squad-samples.edn rename to contrib/clojure-package/examples/bert/squad-samples.edn diff --git a/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj new file mode 100644 index 000000000000..8c056b719feb --- /dev/null +++ b/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj @@ -0,0 +1,160 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns bert.bert-sentence-classification + (:require [bert.util :as bert-util] + [clojure-csv.core :as csv] + [clojure.string :as string] + [org.apache.clojure-mxnet.callback :as callback] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.layout :as layout] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym])) + +(def model-path-prefix "data/static_bert_base_net") +;; epoch number of the model +;; the maximum length of the sequence +(def seq-length 128) + +(defn pre-processing + "Preprocesses the sentences in the format that BERT is expecting" + [idx->token token->idx train-item] + (let [[sentence-a sentence-b label] train-item + ;;; pre-processing tokenize sentence + token-1 (bert-util/tokenize (string/lower-case sentence-a)) + token-2 (bert-util/tokenize (string/lower-case sentence-b)) + valid-length (+ (count token-1) (count token-2)) + ;;; generate token types [0000...1111...0000] + qa-embedded (into (bert-util/pad [] 0 (count token-1)) + + (bert-util/pad [] 1 (count token-2))) + token-types (bert-util/pad qa-embedded 0 seq-length) + ;;; make BERT pre-processing standard + token-2 (conj token-2 "[SEP]") + token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2)) + tokens (bert-util/pad token-1 "[PAD]" seq-length) + ;;; pre-processing - token to index translation + indexes (bert-util/tokens->idxs token->idx tokens)] + {:input-batch [indexes + token-types + [valid-length]] + :label (if (= "0" label) + [0] + [1]) + :tokens tokens + :train-item train-item})) + +(defn fine-tune-model + "msymbol: the pretrained network symbol + num-classes: the number of classes for the fine-tune datasets + dropout: The dropout rate amount" + [msymbol {:keys [num-classes dropout]}] + (as-> msymbol data + (sym/dropout {:data data :p dropout}) + (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes}) + (sym/softmax-output "softmax" {:data data}))) + +(defn slice-inputs-data + "Each sentence pair had to be processed as a row. This breaks all + the rows up into a column for creating a NDArray" + [processed-datas n] + (->> processed-datas + (mapv #(nth (:input-batch %) n)) + (flatten) + (into []))) + +(defn get-raw-data [] + (csv/parse-csv (string/replace (slurp "data/dev.tsv") "\"" "") + :delimiter \tab + :strict true)) + +(defn prepare-data + "This prepares the senetence pairs into NDArrays for use in NDArrayIterator" + [] + (let [raw-file (get-raw-data) + vocab (bert-util/get-vocab) + idx->token (:idx->token vocab) + token->idx (:token->idx vocab) + data-train-raw (->> raw-file + (mapv #(vals (select-keys % [3 4 0]))) + (rest) ;;drop header + (into [])) + processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw)] + {:data0s (slice-inputs-data processed-datas 0) + :data1s (slice-inputs-data processed-datas 1) + :data2s (slice-inputs-data processed-datas 2) + :labels (->> (mapv :label processed-datas) + (flatten) + (into [])) + :train-num (count processed-datas)})) + +(defn train + "Trains (fine tunes) the sentence pairs for a classification task on the BERT Base model" + [dev num-epoch] + (let [bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}) + model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1}) + {:keys [data0s data1s data2s labels train-num]} (prepare-data) + batch-size 32 + data-desc0 (mx-io/data-desc {:name "data0" + :shape [train-num seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT}) + data-desc1 (mx-io/data-desc {:name "data1" + :shape [train-num seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT}) + data-desc2 (mx-io/data-desc {:name "data2" + :shape [train-num] + :dtype dtype/FLOAT32 + :layout layout/N}) + label-desc (mx-io/data-desc {:name "softmax_label" + :shape [train-num] + :dtype dtype/FLOAT32 + :layout layout/N}) + train-data (mx-io/ndarray-iter {data-desc0 (ndarray/array data0s [train-num seq-length] + {:ctx dev}) + data-desc1 (ndarray/array data1s [train-num seq-length] + {:ctx dev}) + data-desc2 (ndarray/array data2s [train-num] + {:ctx dev})} + {:label {label-desc (ndarray/array labels [train-num] + {:ctx dev})} + :data-batch-size batch-size}) + model (m/module model-sym {:contexts [dev] + :data-names ["data0" "data1" "data2"]})] + (m/fit model {:train-data train-data :num-epoch num-epoch + :fit-params (m/fit-params {:allow-missing true + :arg-params (m/arg-params bert-base) + :aux-params (m/aux-params bert-base) + :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9}) + :batch-end-callback (callback/speedometer batch-size 1)})}))) + +(defn -main [& args] + (let [[dev-arg num-epoch-arg] args + dev (if (= dev-arg ":gpu") (context/gpu) (context/cpu)) + num-epoch (if num-epoch-arg (Integer/parseInt num-epoch-arg) 3)] + (println "Running example with " dev " and " num-epoch " epochs ") + (train dev num-epoch))) + +(comment + + (train (context/cpu 0) 3) + (m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 3})) diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj b/contrib/clojure-package/examples/bert/src/bert/infer.clj similarity index 73% rename from contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj rename to contrib/clojure-package/examples/bert/src/bert/infer.clj index 9dcc783ff1ac..2a08dab36f85 100644 --- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj +++ b/contrib/clojure-package/examples/bert/src/bert/infer.clj @@ -15,52 +15,27 @@ ;; limitations under the License. ;; -(ns bert-qa.infer - (:require [clojure.string :as string] - [cheshire.core :as json] - [clojure.java.io :as io] - [org.apache.clojure-mxnet.dtype :as dtype] + + +(ns bert.infer + (:require [bert.util :as bert-util] + [clojure.pprint :as pprint] + [clojure.string :as string] [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.infer :as infer] [org.apache.clojure-mxnet.layout :as layout] [org.apache.clojure-mxnet.ndarray :as ndarray] - [org.apache.clojure-mxnet.infer :as infer] - [clojure.pprint :as pprint])) - -(def model-path-prefix "model/static_bert_qa") + [org.apache.clojure-mxnet.util :as util])) +(def model-path-prefix "data/static_bert_qa") +;; epoch number of the model +(def epoch 2) ;; the maximum length of the sequence (def seq-length 384) ;;; data helpers -(defn break-out-punctuation [s str-match] - (->> (string/split (str s "") (re-pattern (str "\\" str-match))) - (map #(string/replace % "" str-match)))) - -(defn break-out-punctuations [s] - (if-let [target-char (first (re-seq #"[.,?!]" s))] - (break-out-punctuation s target-char) - [s])) - -(defn tokenize [s] - (->> (string/split s #"\s+") - (mapcat break-out-punctuations) - (into []))) - -(defn pad [tokens pad-item num] - (if (>= (count tokens) num) - tokens - (into tokens (repeat (- num (count tokens)) pad-item)))) - -(defn get-vocab [] - (let [vocab (json/parse-stream (io/reader "model/vocab.json"))] - {:idx->token (get vocab "idx_to_token") - :token->idx (get vocab "token_to_idx")})) - -(defn tokens->idxs [token->idx tokens] - (let [unk-idx (get token->idx "[UNK]")] - (mapv #(get token->idx % unk-idx) tokens))) - (defn post-processing [result tokens] (let [output1 (ndarray/slice-axis result 2 0 1) output2 (ndarray/slice-axis result 2 1 2) @@ -96,25 +71,25 @@ (infer/create-predictor factory {:contexts [ctx] - :epoch 2}))) + :epoch epoch}))) (defn pre-processing [ctx idx->token token->idx qa-map] (let [{:keys [input-question input-answer ground-truth-answers]} qa-map ;;; pre-processing tokenize sentence - token-q (tokenize (string/lower-case input-question)) - token-a (tokenize (string/lower-case input-answer)) + token-q (bert-util/tokenize (string/lower-case input-question)) + token-a (bert-util/tokenize (string/lower-case input-answer)) valid-length (+ (count token-q) (count token-a)) ;;; generate token types [0000...1111...0000] - qa-embedded (into (pad [] 0 (count token-q)) - (pad [] 1 (count token-a))) - token-types (pad qa-embedded 0 seq-length) + qa-embedded (into (bert-util/pad [] 0 (count token-q)) + (bert-util/pad [] 1 (count token-a))) + token-types (bert-util/pad qa-embedded 0 seq-length) ;;; make BERT pre-processing standard token-a (conj token-a "[SEP]") token-q (into [] (concat ["[CLS]"] token-q ["[SEP]"] token-a)) - tokens (pad token-q "[PAD]" seq-length) + tokens (bert-util/pad token-q "[PAD]" seq-length) ;;; pre-processing - token to index translation - indexes (tokens->idxs token->idx tokens)] + indexes (bert-util/tokens->idxs token->idx tokens)] {:input-batch [(ndarray/array indexes [1 seq-length] {:context ctx}) (ndarray/array token-types [1 seq-length] {:context ctx}) (ndarray/array [valid-length] [1] {:context ctx})] @@ -125,7 +100,7 @@ ([] (infer (context/default-context))) ([ctx] (let [predictor (make-predictor ctx) - {:keys [idx->token token->idx]} (get-vocab) + {:keys [idx->token token->idx]} (bert-util/get-vocab) ;;; samples taken from https://rajpurkar.github.io/SQuAD-explorer/explore/v2.0/dev/ question-answers (clojure.edn/read-string (slurp "squad-samples.edn"))] (doseq [qa-map question-answers] diff --git a/contrib/clojure-package/examples/bert/src/bert/util.clj b/contrib/clojure-package/examples/bert/src/bert/util.clj new file mode 100644 index 000000000000..061e12b4e8de --- /dev/null +++ b/contrib/clojure-package/examples/bert/src/bert/util.clj @@ -0,0 +1,52 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns bert.util + (:require [clojure.java.io :as io] + [clojure.string :as string] + [cheshire.core :as json])) + +(defn break-out-punctuation [s str-match] + (->> (string/split (str s "") (re-pattern (str "\\" str-match))) + (map #(string/replace % "" str-match)))) + +(defn break-out-punctuations [s] + (if-let [target-char (first (re-seq #"[.,?!]" s))] + (break-out-punctuation s target-char) + [s])) + +(defn tokenize [s] + (->> (string/split s #"\s+") + (mapcat break-out-punctuations) + (into []))) + +(defn pad [tokens pad-item num] + (if (>= (count tokens) num) + tokens + (into tokens (repeat (- num (count tokens)) pad-item)))) + +(defn get-vocab [] + (let [vocab (json/parse-stream (io/reader "data/vocab.json"))] + {:idx->token (get vocab "idx_to_token") + :token->idx (get vocab "token_to_idx")})) + +(defn tokens->idxs [token->idx tokens] + (let [unk-idx (get token->idx "[UNK]")] + (mapv #(get token->idx % unk-idx) tokens))) + +(defn idxs->tokens [idx->token idxs] + (mapv #(get idx->token %) idxs)) diff --git a/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj b/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj new file mode 100644 index 000000000000..355f23ea3cfd --- /dev/null +++ b/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj @@ -0,0 +1,86 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + + +(ns bert.bert-sentence-classification-test + (:require [bert.bert-sentence-classification :refer :all] + [clojure-csv.core :as csv] + [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [clojure.test :refer :all] + [org.apache.clojure-mxnet.callback :as callback] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.layout :as layout] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.module :as m])) + +(def model-dir "data/") + +(when-not (.exists (io/file (str model-dir "static_bert_qa-0002.params"))) + (println "Downloading bert qa data") + (sh "./get_bert_data.sh")) + +(defn get-slim-raw-data [] + (take 32 (csv/parse-csv (slurp "data/dev.tsv") :delimiter \tab))) + +(deftest train-test + (with-redefs [get-raw-data get-slim-raw-data] + (let [dev (context/default-context) + num-epoch 1 + bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}) + model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1}) + {:keys [data0s data1s data2s labels train-num]} (prepare-data) + batch-size 32 + data-desc0 (mx-io/data-desc {:name "data0" + :shape [train-num seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT}) + data-desc1 (mx-io/data-desc {:name "data1" + :shape [train-num seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT}) + data-desc2 (mx-io/data-desc {:name "data2" + :shape [train-num] + :dtype dtype/FLOAT32 + :layout layout/N}) + label-desc (mx-io/data-desc {:name "softmax_label" + :shape [train-num] + :dtype dtype/FLOAT32 + :layout layout/N}) + train-data (mx-io/ndarray-iter {data-desc0 (ndarray/array data0s [train-num seq-length] + {:ctx dev}) + data-desc1 (ndarray/array data1s [train-num seq-length] + {:ctx dev}) + data-desc2 (ndarray/array data2s [train-num] + {:ctx dev})} + {:label {label-desc (ndarray/array labels [train-num] + {:ctx dev})} + :data-batch-size batch-size}) + model (m/module model-sym {:contexts [dev] + :data-names ["data0" "data1" "data2"]})] + (m/fit model {:train-data train-data :num-epoch num-epoch + :fit-params (m/fit-params {:allow-missing true + :arg-params (m/arg-params bert-base) + :aux-params (m/aux-params bert-base) + :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9}) + :batch-end-callback (callback/speedometer batch-size 1)})}) + (is (< 0.5 (-> (m/score model {:eval-data train-data :eval-metric (eval-metric/accuracy) }) + (last))))))) diff --git a/contrib/clojure-package/examples/bert-qa/test/bert_qa/infer_test.clj b/contrib/clojure-package/examples/bert/test/bert/infer_test.clj similarity index 91% rename from contrib/clojure-package/examples/bert-qa/test/bert_qa/infer_test.clj rename to contrib/clojure-package/examples/bert/test/bert/infer_test.clj index 767fb089f284..48ee3a89b177 100644 --- a/contrib/clojure-package/examples/bert-qa/test/bert_qa/infer_test.clj +++ b/contrib/clojure-package/examples/bert/test/bert/infer_test.clj @@ -16,15 +16,16 @@ ;; -(ns bert-qa.infer-test - (:require [bert-qa.infer :refer :all] +(ns bert.infer-test + (:require [bert.infer :refer :all] + [bert.util :as util] [clojure.java.io :as io] [clojure.java.shell :refer [sh]] [clojure.test :refer :all] [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.infer :as infer])) -(def model-dir "model/") +(def model-dir "data/") (when-not (.exists (io/file (str model-dir "static_bert_qa-0002.params"))) (println "Downloading bert qa data") @@ -33,7 +34,7 @@ (deftest infer-test (let [ctx (context/default-context) predictor (make-predictor ctx) - {:keys [idx->token token->idx]} (get-vocab) + {:keys [idx->token token->idx]} (util/get-vocab) ;;; samples taken from https://rajpurkar.github.io/SQuAD-explorer/explore/v2.0/dev/ question-answers (clojure.edn/read-string (slurp "squad-samples.edn"))] (let [qa-map (last question-answers) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/callback.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/callback.clj index c0077208978b..3809c73c5fd1 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/callback.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/callback.clj @@ -16,13 +16,20 @@ ;; (ns org.apache.clojure-mxnet.callback + (:require [org.apache.clojure-mxnet.eval-metric :as em]) (:import (org.apache.mxnet Callback$Speedometer))) ;;; used to track status during epoch (defn speedometer ([batch-size frequent] - (new Callback$Speedometer (int batch-size) (int frequent))) + (proxy [Callback$Speedometer] [(int batch-size) (int frequent)] + (invoke [epoch batch-count eval-metric] + (proxy-super invoke epoch batch-count eval-metric) + ;;; so that it prints to repl as well + (when (and (zero? (mod batch-count frequent)) + (pos? batch-count)) + (println "Speedometer: epoch " epoch " count " batch-count " metric " (em/get eval-metric )))))) ([batch-size] (speedometer batch-size 50)))