This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Extend Clojure BERT example #15023

Merged
11 commits merged on Jun 22, 2019
145 changes: 132 additions & 13 deletions contrib/clojure-package/examples/bert/fine-tune-bert.ipynb
@@ -10,15 +10,16 @@
"\n",
"Pre-trained language representations have been shown to improve many downstream NLP tasks such as question answering, and natural language inference. To apply pre-trained representations to these tasks, there are two strategies:\n",
"\n",
"feature-based approach, which uses the pre-trained representations as additional features to the downstream task.\n",
"fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters.\n",
"While feature-based approaches such as ELMo [3] (introduced in the previous tutorial) are effective in improving many downstream tasks, they require task-specific architectures. Devlin, Jacob, et al proposed BERT [1] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters, and obtained state- of-the-art results.\n",
" - **feature-based approach**, which uses the pre-trained representations as additional features to the downstream task.\n",
" - **fine-tuning based approach**, which trains the downstream tasks by fine-tuning pre-trained parameters.\n",
" \n",
"While feature-based approaches such as ELMo [1] are effective in improving many downstream tasks, they require task-specific architectures. Devlin, Jacob, et al proposed BERT [2] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters, and obtained state- of-the-art results.\n",
"\n",
"In this tutorial, we will focus on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs. Specifically, we will:\n",
"\n",
"load the state-of-the-art pre-trained BERT model and attach an additional layer for classification,\n",
"process and transform sentence pair data for the task at hand, and\n",
"fine-tune BERT model for sentence classification.\n",
" 1. load the state-of-the-art pre-trained BERT model and attach an additional layer for classification\n",
" 2. process and transform sentence pair data for the task at hand, and \n",
" 3. fine-tune BERT model for sentence classification.\n",
"\n"
]
},
@@ -59,6 +60,7 @@
" [org.apache.clojure-mxnet.callback :as callback]\n",
" [org.apache.clojure-mxnet.context :as context]\n",
" [org.apache.clojure-mxnet.dtype :as dtype]\n",
" [org.apache.clojure-mxnet.infer :as infer]\n",
" [org.apache.clojure-mxnet.eval-metric :as eval-metric]\n",
" [org.apache.clojure-mxnet.io :as mx-io]\n",
" [org.apache.clojure-mxnet.layout :as layout]\n",
@@ -89,7 +91,7 @@
"\n",
"![bert](https://gluon-nlp.mxnet.io/_images/bert-sentence-pair.png)\n",
"\n",
"where the model takes a pair of sequences and pools the representation of the first token in the sequence. Note that the original BERT model was trained for masked language model and next sentence prediction tasks, which includes layers for language model decoding and classification. These layers will not be used for fine-tuning sentence pair classification.\n",
"where the model takes a pair of sequences and *pools* the representation of the first token in the sequence. Note that the original BERT model was trained for masked language model and next sentence prediction tasks, which includes layers for language model decoding and classification. These layers will not be used for fine-tuning sentence pair classification.\n",
"\n",
"Let's load the pre-trained BERT using the module API in MXNet."
]
@@ -114,12 +116,15 @@
],
"source": [
"(def model-path-prefix \"data/static_bert_base_net\")\n",
"\n",
";; the vocabulary used in the model\n",
"(def vocab (bert-util/get-vocab))\n",
";; the input question\n",
"\n",
";; the maximum length of the sequence\n",
"(def seq-length 128)\n",
"\n",
"(def batch-size 32)\n",
"\n",
"(def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))"
]
},
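The next step of the tutorial attaches a classification head to the loaded bert-base symbol. The fine-tune-model helper that does this sits in a collapsed part of this diff; a minimal sketch of what such a helper might look like, assuming org.apache.clojure-mxnet.symbol is required as sym (as it is in the accompanying namespace):

;; Sketch only: attach a classification head to the pre-trained symbol.
;; Dropout regularizes the pooled representation, a fully-connected layer
;; maps it to one output per class, and softmax-output yields class
;; probabilities for training.
(defn fine-tune-model
  [msymbol {:keys [num-classes dropout]}]
  (as-> msymbol data
    (sym/dropout {:data data :p dropout})
    (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes})
    (sym/softmax-output "softmax" {:data data})))

;; e.g. (def model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1}))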
@@ -291,7 +296,7 @@
"source": [
"(defn pre-processing\n",
" \"Preprocesses the sentences in the format that BERT is expecting\"\n",
" [ctx idx->token token->idx train-item]\n",
" [idx->token token->idx train-item]\n",
" (let [[sentence-a sentence-b label] train-item\n",
" ;;; pre-processing tokenize sentence\n",
" token-1 (bert-util/tokenize (string/lower-case sentence-a))\n",
@@ -319,7 +324,7 @@
"(def idx->token (:idx->token vocab))\n",
"(def token->idx (:token->idx vocab))\n",
"(def dev (context/default-context))\n",
"(def processed-datas (mapv #(pre-processing dev idx->token token->idx %) data-train-raw))\n",
"(def processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw))\n",
"(def train-count (count processed-datas))\n",
"(println \"Train Count is = \" train-count)\n",
"(println \"[PAD] token id = \" (get token->idx \"[PAD]\"))\n",
@@ -375,8 +380,6 @@
" (into []))\n",
" :train-num (count processed-datas)})\n",
"\n",
"(def batch-size 32)\n",
"\n",
"(def train-data\n",
" (let [{:keys [data0s data1s data2s labels train-num]} prepared-data\n",
" data-desc0 (mx-io/data-desc {:name \"data0\"\n",
@@ -480,7 +483,7 @@
"(def num-epoch 3)\n",
"\n",
"(def fine-tune-model (m/module model-sym {:contexts [dev]\n",
" :data-names [\"data0\" \"data1\" \"data2\"]}))\n",
" :data-names [\"data0\" \"data1\" \"data2\"]}))\n",
"\n",
"(m/fit fine-tune-model {:train-data train-data :num-epoch num-epoch\n",
" :fit-params (m/fit-params {:allow-missing true\n",
@@ -489,6 +492,122 @@
" :optimizer (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9})\n",
" :batch-end-callback (callback/speedometer batch-size 1)})})\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Explore results from the fine-tuned model\n",
"\n",
"Now that our model is fitted, we can use it to infer semantic equivalence of arbitrary sentence pairs. Note that for demonstration purpose we skipped the warmup learning rate schedule and validation on dev dataset used in the original implementation. This means that our model's performance will be significantly less than optimal. Please visit [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html) for the complete fine-tuning scripts (using Python and GluonNLP).\n",
"\n",
"To do inference with our model we need a predictor. It must have a batch size of 1 so we can feed the model a single sentence pair."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"#'bert.bert-sentence-classification/fine-tuned-predictor"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(def fine-tuned-prefix \"fine-tune-sentence-bert\")\n",
"\n",
"(m/save-checkpoint fine-tune-model {:prefix fine-tuned-prefix :epoch 3})\n",
"\n",
"(def fine-tuned-predictor\n",
" (infer/create-predictor (infer/model-factory fine-tuned-prefix\n",
" [{:name \"data0\" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}\n",
" {:name \"data1\" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}\n",
" {:name \"data2\" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}])\n",
" {:epoch 3}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can write a function that feeds a sentence pair to the fine-tuned model:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"#'bert.bert-sentence-classification/predict-equivalence"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(defn predict-equivalence\n",
" [predictor sentence1 sentence2]\n",
" (let [vocab (bert.util/get-vocab)\n",
" processed-test-data (mapv #(pre-processing (:idx->token vocab)\n",
" (:token->idx vocab) %)\n",
" [[sentence1 sentence2]])\n",
" prediction (infer/predict-with-ndarray predictor\n",
" [(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length])\n",
" (ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length])\n",
" (ndarray/array (slice-inputs-data processed-test-data 2) [1])])]\n",
" (ndarray/->vec (first prediction))))"
]
},
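For reference, predict-equivalence relies on slice-inputs-data, which is defined in an earlier, collapsed cell of this notebook. A sketch of such a helper, which collects the n-th :input-batch component (token indexes, token types, or valid length) of every processed row into a flat vector:

;; Sketch only: pull the nth input component out of each processed row.
(defn slice-inputs-data
  [processed-datas n]
  (->> processed-datas
       (mapv #(nth (:input-batch %) n))
       (flatten)
       (into [])))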
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.2633881 0.7366119]"
Contributor: what's this the output of?

Contributor Author (@daveliepmann, Jun 12, 2019): [0.2633881 0.7366119] is the output of our sample sentence pair equivalence prediction:

(predict-equivalence fine-tuned-predictor
                     "The company cut spending to compensate for weak sales ."
                     "In response to poor sales results, the company cut spending .")

I'm not sure why the result appears before its expression in the .ipynb file, but on my machine it displays this pair correctly as "In [22]" followed by "Out [22]".

]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
";; Modify an existing sentence pair to test:\n",
";; [\"1\"\n",
";; \"69773\"\n",
";; \"69792\"\n",
";; \"Cisco pared spending to compensate for sluggish sales .\"\n",
";; \"In response to sluggish sales , Cisco pared spending .\"]\n",
"(predict-equivalence fine-tuned-predictor\n",
" \"The company cut spending to compensate for weak sales .\"\n",
" \"In response to poor sales results, the company cut spending .\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"\n",
"[1] Peters, Matthew E., et al. “Deep contextualized word representations.” arXiv preprint arXiv:1802.05365 (2018).\n",
"\n",
"[2] Devlin, Jacob, et al. “Bert: Pre-training of deep bidirectional transformers for language understanding.” arXiv preprint arXiv:1810.04805 (2018)."
]
}
],
"metadata": {
@@ -16,42 +16,67 @@
;;

(ns bert.bert-sentence-classification
"Fine-tuning Sentence Pair Classification with BERT
This tutorial focuses on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs.

Specifically, we will:
1. load the state-of-the-art pre-trained BERT model
2. attach an additional layer for classification
3. process and transform sentence pair data for the task at hand
4. fine-tune BERT model for sentence classification"
(:require [bert.util :as bert-util]
[clojure-csv.core :as csv]
[clojure.string :as string]
[org.apache.clojure-mxnet.callback :as callback]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.layout :as layout]
[org.apache.clojure-mxnet.module :as m]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.optimizer :as optimizer]
[org.apache.clojure-mxnet.symbol :as sym]))

;; Pre-trained language representations have been shown to improve
;; many downstream NLP tasks such as question answering, and natural
;; language inference. To apply pre-trained representations to these
;; tasks, there are two strategies:

;; * feature-based approach, which uses the pre-trained representations as additional features to the downstream task.
;; * fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters.

;; While feature-based approaches such as ELMo are effective in
;; improving many downstream tasks, they require task-specific
;; architectures. Devlin, Jacob, et al proposed BERT (Bidirectional
;; Encoder Representations from Transformers), which fine-tunes deep
;; bidirectional representations on a wide range of tasks with minimal
;; task-specific parameters, and obtained state-of-the-art results.

(def model-path-prefix "data/static_bert_base_net")
;; epoch number of the model

(def fine-tuned-prefix "fine-tune-sentence-bert")

;; the maximum length of the sequence
(def seq-length 128)

(defn pre-processing
"Preprocesses the sentences in the format that BERT is expecting"
[idx->token token->idx train-item]
(let [[sentence-a sentence-b label] train-item
;;; pre-processing tokenize sentence
;; pre-processing tokenize sentence
token-1 (bert-util/tokenize (string/lower-case sentence-a))
token-2 (bert-util/tokenize (string/lower-case sentence-b))
valid-length (+ (count token-1) (count token-2))
;;; generate token types [0000...1111...0000]
;; generate token types [0000...1111...0000]
qa-embedded (into (bert-util/pad [] 0 (count token-1))

(bert-util/pad [] 1 (count token-2)))
token-types (bert-util/pad qa-embedded 0 seq-length)
;;; make BERT pre-processing standard
;; make BERT pre-processing standard
token-2 (conj token-2 "[SEP]")
token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2))
tokens (bert-util/pad token-1 "[PAD]" seq-length)
;;; pre-processing - token to index translation
;; pre-processing - token to index translation
indexes (bert-util/tokens->idxs token->idx tokens)]
{:input-batch [indexes
token-types
@@ -83,19 +108,18 @@

(defn get-raw-data []
(csv/parse-csv (string/replace (slurp "data/dev.tsv") "\"" "")
:delimiter \tab
:strict true))
:delimiter \tab
:strict true))

(defn prepare-data
"This prepares the senetence pairs into NDArrays for use in NDArrayIterator"
[]
(let [raw-file (get-raw-data)
vocab (bert-util/get-vocab)
"This prepares the sentence pairs into NDArrays for use in NDArrayIterator"
[raw-data]
(let [vocab (bert-util/get-vocab)
idx->token (:idx->token vocab)
token->idx (:token->idx vocab)
data-train-raw (->> raw-file
data-train-raw (->> raw-data
(mapv #(vals (select-keys % [3 4 0])))
(rest) ;;drop header
(rest) ; drop header
(into []))
processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw)]
{:data0s (slice-inputs-data processed-datas 0)
Expand All @@ -111,7 +135,7 @@
[dev num-epoch]
(let [bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0})
model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1})
{:keys [data0s data1s data2s labels train-num]} (prepare-data)
{:keys [data0s data1s data2s labels train-num]} (prepare-data (get-raw-data))
batch-size 32
data-desc0 (mx-io/data-desc {:name "data0"
:shape [train-num seq-length]
@@ -138,14 +162,16 @@
{:label {label-desc (ndarray/array labels [train-num]
{:ctx dev})}
:data-batch-size batch-size})
model (m/module model-sym {:contexts [dev]
:data-names ["data0" "data1" "data2"]})]
(m/fit model {:train-data train-data :num-epoch num-epoch
:fit-params (m/fit-params {:allow-missing true
:arg-params (m/arg-params bert-base)
:aux-params (m/aux-params bert-base)
:optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})
:batch-end-callback (callback/speedometer batch-size 1)})})))
fitted-model (m/fit (m/module model-sym {:contexts [dev]
:data-names ["data0" "data1" "data2"]})
{:train-data train-data :num-epoch num-epoch
:fit-params (m/fit-params {:allow-missing true
:arg-params (m/arg-params bert-base)
:aux-params (m/aux-params bert-base)
:optimizer (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9})
:batch-end-callback (callback/speedometer batch-size 1)})})]
(m/save-checkpoint fitted-model {:prefix fine-tuned-prefix :epoch num-epoch})
Contributor: Why do we save the model to disk now? Could we pass in a parameter to the function to do it? This function seems to do too many things?

Contributor Author: We have to save the model to disk because there's no other way but saving to disk and loading it back in to get a prediction out of it.

Contributor: With the infer API I suppose? Maybe we could change this at some point @kedarbellare?

fitted-model))

(defn -main [& args]
(let [[dev-arg num-epoch-arg] args
@@ -157,4 +183,42 @@
(comment

(train (context/cpu 0) 3)
(m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 3}))

(m/save-checkpoint model {:prefix fine-tuned-prefix :epoch 3})


;;;; Explore results from the fine-tuned model

;; We need a predictor with a batch size of 1, so we can feed the
;; model a single sentence pair.
(def fine-tuned-predictor
(infer/create-predictor (infer/model-factory fine-tuned-prefix
[{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
{:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
{:name "data2" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}])
{:epoch 3}))
Contributor (@Chouffe, Jun 7, 2019): Why do we need this hardcoded epoch number here? Can't we just use num-epoch?

Contributor Author: As I recall, we have to hard-code the epoch because otherwise we don't know which saved model to load from disk.

Contributor Author: Responding to edit: num-epoch isn't in scope in the comment. I decided against defining it globally in order to parameterize a short REPL exploration.

Another reason not to def the 3 here is that num-epoch is a value meant to be passed in from the command line, and the rich comment code is a parallel of invoking from the command line with that argument. So at a minimum we would need a new name.


;; Get the fine-tuned model's opinion on whether two sentences are equivalent:
(defn predict-equivalence
[predictor sentence1 sentence2]
(let [vocab (bert.util/get-vocab)
processed-test-data (mapv #(pre-processing (:idx->token vocab)
(:token->idx vocab) %)
[[sentence1 sentence2]])
prediction (infer/predict-with-ndarray predictor
[(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length])
(ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length])
(ndarray/array (slice-inputs-data processed-test-data 2) [1])])]
(ndarray/->vec (first prediction))))

;; Modify an existing sentence pair to test:
;; ["1"
;; "69773"
;; "69792"
;; "Cisco pared spending to compensate for sluggish sales ."
;; "In response to sluggish sales , Cisco pared spending ."]
(predict-equivalence fine-tuned-predictor
Contributor: did you want to add a test for this?

Contributor Author: Just pushed one. Thanks for the idea; my PR broke the existing test, but I guess tests for examples aren't part of the CI checks.

"The company cut spending to compensate for weak sales ."
"In response to poor sales results, the company cut spending .")

)
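The test mentioned in the review thread above is not shown in this diff. A minimal sketch of what such a test could look like, with a hypothetical namespace and assertions, assuming clojure.test and the data/dev.tsv file read by get-raw-data:

(ns bert.bert-sentence-classification-test
  (:require [clojure.test :refer [deftest is]]
            [bert.bert-sentence-classification :as bert-sc]))

(deftest prepare-data-shapes
  ;; Token-index and token-type columns have seq-length entries per
  ;; sentence pair; valid lengths and labels have one entry per pair.
  (let [{:keys [data0s data1s data2s labels train-num]}
        (bert-sc/prepare-data (bert-sc/get-raw-data))]
    (is (pos? train-num))
    (is (= (* train-num bert-sc/seq-length) (count data0s)))
    (is (= (* train-num bert-sc/seq-length) (count data1s)))
    (is (= train-num (count data2s)))
    (is (= train-num (count labels)))))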