Commit

[Clojure] Add Fine Tuning Sentence Pair Classification BERT Example (apache#14769)

* Rework Bert examples to include QA infer and finetuning

* update notebook example and exported markdown

* add integration test for the classification

* fix tests

* add RAT

* add another RAT

* fix all the typos

* Clojure BERT finetuning example: fix CSV parsing

* update readme and gitignore

* add fix from @davliepmann’s notebook parsing

* feedback from @daveliepmann

* fix running of example and don’t show very first batch on callback speedometer

* rerun the notebook and save results

* remove bert stuff from main .gitignore

* re-putting the license back on after regen

* fix integration test
gigasquid authored and haohuw committed Jun 23, 2019
1 parent ad12334 commit 95a6894
Showing 15 changed files with 1,310 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -161,4 +161,4 @@ tests/mxnet_unit_tests
coverage.xml

# Local CMake build config
cmake_options.yml
cmake_options.yml
2 changes: 2 additions & 0 deletions contrib/clojure-package/.gitignore
@@ -40,9 +40,11 @@ src/.DS_Store
src/org/.DS_Store
test/test-ndarray.clj
test/test-ndarray-random.clj
test/test-ndarray-random-api.clj
test/test-ndarray-api.clj
test/test-symbol.clj
test/test-symbol-random.clj
test/test-symbol-random-api.clj
test/test-symbol-api.clj
src/org/apache/clojure_mxnet/gen/*

@@ -1,7 +1,6 @@
/target
/classes
/checkouts
profiles.clj
pom.xml
pom.xml.asc
*.jar
@@ -10,3 +9,10 @@ pom.xml.asc
/.nrepl-port
.hgignore
.hg/
data/*
model/*
*~
*.params
*.states
*.json

@@ -16,7 +16,11 @@
<!--- under the License. -->


# bert-qa
# BERT

There are two examples showcasing the power of BERT. One is BERT-QA, which does inference, and the other is BERT Sentence Pair Classification, which fine-tunes the BERT base model. For more information about BERT, please read [http://jalammar.github.io/illustrated-bert/](http://jalammar.github.io/illustrated-bert/).

## bert-qa

**This example is based on the Java API one. It shows how to do inference with a pre-trained BERT network that was trained for question answering on the [SQuAD Dataset](https://rajpurkar.github.io/SQuAD-explorer/)**

@@ -33,14 +37,16 @@ Example:
:input-question "Along with geothermal and nuclear, what is a notable non-combustion heat source?"
:ground-truth-answers ["solar"
"solar power"
"solar power, nuclear power or geothermal energysolar"]}
"solar power, nuclear power or geothermal energy solar"]}
```

The prediction in this case would be `solar power`.
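The prediction step itself is not shown in this README, but conceptually it boils down to choosing the answer span with the best combined score. The sketch below is a hedged illustration under the assumption that the network emits per-token start and end logits; the function name, the `max_len` cutoff, and the toy logits are hypothetical, not part of the example's actual code.

```python
def best_span(start_logits, end_logits, max_len=30):
    """Pick the (start, end) token pair maximizing start + end score,
    with start <= end and a bounded span length."""
    best = (0, 0)
    best_score = float("-inf")
    for s, s_score in enumerate(start_logits):
        for e in range(s, min(s + max_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best

# Toy logits, purely illustrative: span (1, 2) has the best total score.
start, end = best_span([0.1, 2.0, 0.3], [0.2, 0.1, 1.5])
```

The tokens between `start` and `end` (inclusive) would then be joined back into the answer string, e.g. `solar power`.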

## Setup Guide
### Setup Guide

### Step 1: Download the model
Note: If you have trouble with your REPL and CIDER, please comment out the `lein-jupyter` plugin; it has some conflicts with CIDER.

#### Step 1: Download the model

For this tutorial, you can get the model and vocabulary by running the following bash script. It will use `wget` to download these artifacts from AWS S3.

@@ -53,14 +59,14 @@ From the example directory:
Some sample questions and answers are provided in the `squad-sample.edn` file. Some are taken directly from the SQuAD dataset and one was just made up. Feel free to edit the file and add your own!


## To run
### To run

* `lein install` in the root of the main project directory
* `cd` into this project directory and run `lein run`. This will execute the CPU version.
* `lein run` or `lein run :cpu` to run with the CPU
* `lein run :gpu` to run with the GPU

## Background
### Background

To learn more about how BERT works in MXNet, please follow this [MXNet Gluon tutorial on NLP using BERT](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340).

@@ -73,7 +79,7 @@ The original description can be found in the [MXNet GluonNLP model zoo](https://
python static_finetune_squad.py --optimizer adam --accumulate 2 --batch_size 6 --lr 3e-5 --epochs 2 --gpu 0 --export

```
This script will generate `json` and `param` fles that are the standard MXNet model files.
This script will generate `json` and `param` files that are the standard MXNet model files.
By default, this uses the `bert_12_768_12` model with extra layers for QA jobs.

After that, to be able to use it in Java, we need to export the dictionary from the script so that it can be used to parse the text.
@@ -88,3 +94,63 @@ f.close()
This exports the token vocabulary in JSON format.
Once you have these three files, you will be able to run this example without problems.
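The vocabulary export can be sketched as follows. This is a minimal, hedged illustration rather than the exact GluonNLP script: the tiny `vocab` dict is a hypothetical stand-in for the real BERT vocabulary object, which GluonNLP serializes to JSON.

```python
import json

# Hypothetical stand-in for the GluonNLP BERT vocabulary; the real
# object maps the full WordPiece token set to integer ids.
vocab = {
    "token_to_idx": {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3},
    "unknown_token": "[UNK]",
}

# Write the vocabulary out as JSON so the example can load it to
# turn question/answer text into token ids.
with open("vocab.json", "w") as f:
    json.dump(vocab, f)
```

The example then reads `vocab.json` at inference time to tokenize the input question and paragraph.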

## Fine-tuning Sentence Pair Classification with BERT

This is based on the great Gluon-NLP tutorial [https://gluon-nlp.mxnet.io/examples/sentence_embedding/bert.html](https://gluon-nlp.mxnet.io/examples/sentence_embedding/bert.html).

We use the pre-trained BERT model that was exported from GluonNLP via `scripts/bert/staticbert/static_export_base.py` by running `python static_export_base.py --seq_length 128`. For convenience, the model is downloaded for you when you run the `get_bert_data.sh` script in the root directory of this example.

It fine-tunes the base BERT model for use in a classification task, training for 3 epochs.
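As a rough illustration of what sentence-pair classification consumes, the sketch below packs two tokenized sentences into the three fixed-length inputs a static BERT model expects (token ids, token types, and a valid length), assuming the `--seq_length 128` used at export time. The function and the toy vocabulary are hypothetical, not taken from the example's actual code.

```python
def pack_pair(tokens_a, tokens_b, vocab, seq_length=128):
    """Pack a tokenized sentence pair into fixed-length BERT inputs."""
    # Layout: [CLS] sentence A [SEP] sentence B [SEP]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment ids: 0 for [CLS] + sentence A + first [SEP], 1 for sentence B + [SEP].
    token_types = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens]
    valid_length = len(ids)
    # Pad both sequences out to the fixed length the static model expects.
    ids += [vocab["[PAD]"]] * (seq_length - len(ids))
    token_types += [0] * (seq_length - len(token_types))
    return ids, token_types, valid_length

# Toy vocabulary, purely illustrative.
toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "good": 4, "movie": 5}
ids, types, n = pack_pair(["good", "movie"], ["good"], toy_vocab)
```

The classifier head is then trained on the `[CLS]` position's output for each packed pair.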


### Setup Guide

#### Step 1: Download the model

For this tutorial, you can get the model and vocabulary by running the following bash script. It will use `wget` to download these artifacts from AWS S3.

From the example directory:

```bash
./get_bert_data.sh
```

### To run the notebook walkthrough

There is a Jupyter notebook that uses the `lein-jupyter` plugin to execute Clojure code in the project setting. The first time you run it, you will need to install the kernel with `lein jupyter install-kernel`. After that, you can open the notebook in the project directory with `lein jupyter notebook`.

There is also an exported markdown copy of the walkthrough, `fine-tune-bert.md`.


### To run

* `lein install` in the root of the main project directory
* `cd` into this project directory and run `lein run`. This will execute the CPU version.

`lein run -m bert.bert-sentence-classification :cpu` - to run with cpu
`lein run -m bert.bert-sentence-classification :gpu` - to run with gpu

By default it will run 3 epochs; you can control the number of epochs with:

`lein run -m bert.bert-sentence-classification :cpu 1` to run just 1 epoch


Sample results from a CPU run on OSX:
```
INFO org.apache.mxnet.module.BaseModule: Epoch[1] Train-accuracy=0.65384614
INFO org.apache.mxnet.module.BaseModule: Epoch[1] Time cost=464187
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [1] Speed: 0.91 samples/sec Train-accuracy=0.656250
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [2] Speed: 0.90 samples/sec Train-accuracy=0.656250
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [3] Speed: 0.91 samples/sec Train-accuracy=0.687500
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [4] Speed: 0.90 samples/sec Train-accuracy=0.693750
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [5] Speed: 0.91 samples/sec Train-accuracy=0.703125
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [6] Speed: 0.92 samples/sec Train-accuracy=0.696429
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [7] Speed: 0.91 samples/sec Train-accuracy=0.699219
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [8] Speed: 0.90 samples/sec Train-accuracy=0.701389
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [9] Speed: 0.90 samples/sec Train-accuracy=0.690625
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [10] Speed: 0.89 samples/sec Train-accuracy=0.690341
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [11] Speed: 0.90 samples/sec Train-accuracy=0.695313
INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [12] Speed: 0.91 samples/sec Train-accuracy=0.701923
INFO org.apache.mxnet.module.BaseModule: Epoch[2] Train-accuracy=0.7019231
INFO org.apache.mxnet.module.BaseModule: Epoch[2] Time cost=459809
```