diff --git a/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb b/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb index 425a9993ad93..5934477ea338 100644 --- a/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb +++ b/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb @@ -10,15 +10,16 @@ "\n", "Pre-trained language representations have been shown to improve many downstream NLP tasks such as question answering, and natural language inference. To apply pre-trained representations to these tasks, there are two strategies:\n", "\n", - "feature-based approach, which uses the pre-trained representations as additional features to the downstream task.\n", - "fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters.\n", - "While feature-based approaches such as ELMo [3] (introduced in the previous tutorial) are effective in improving many downstream tasks, they require task-specific architectures. Devlin, Jacob, et al proposed BERT [1] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters, and obtained state- of-the-art results.\n", + " - **feature-based approach**, which uses the pre-trained representations as additional features to the downstream task.\n", + " - **fine-tuning based approach**, which trains the downstream tasks by fine-tuning pre-trained parameters.\n", + " \n", + "While feature-based approaches such as ELMo [1] are effective in improving many downstream tasks, they require task-specific architectures. Devlin, Jacob, et al proposed BERT [2] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters, and obtained state- of-the-art results.\n", "\n", "In this tutorial, we will focus on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs. Specifically, we will:\n", "\n", - "load the state-of-the-art pre-trained BERT model and attach an additional layer for classification,\n", - "process and transform sentence pair data for the task at hand, and\n", - "fine-tune BERT model for sentence classification.\n", + " 1. load the state-of-the-art pre-trained BERT model and attach an additional layer for classification\n", + " 2. process and transform sentence pair data for the task at hand, and \n", + " 3. fine-tune BERT model for sentence classification.\n", "\n" ] }, @@ -59,6 +60,7 @@ " [org.apache.clojure-mxnet.callback :as callback]\n", " [org.apache.clojure-mxnet.context :as context]\n", " [org.apache.clojure-mxnet.dtype :as dtype]\n", + " [org.apache.clojure-mxnet.infer :as infer]\n", " [org.apache.clojure-mxnet.eval-metric :as eval-metric]\n", " [org.apache.clojure-mxnet.io :as mx-io]\n", " [org.apache.clojure-mxnet.layout :as layout]\n", @@ -89,7 +91,7 @@ "\n", "![bert](https://gluon-nlp.mxnet.io/_images/bert-sentence-pair.png)\n", "\n", - "where the model takes a pair of sequences and pools the representation of the first token in the sequence. Note that the original BERT model was trained for masked language model and next sentence prediction tasks, which includes layers for language model decoding and classification. These layers will not be used for fine-tuning sentence pair classification.\n", + "where the model takes a pair of sequences and *pools* the representation of the first token in the sequence. Note that the original BERT model was trained for masked language model and next sentence prediction tasks, which includes layers for language model decoding and classification. These layers will not be used for fine-tuning sentence pair classification.\n", "\n", "Let's load the pre-trained BERT using the module API in MXNet." ] @@ -114,12 +116,15 @@ ], "source": [ "(def model-path-prefix \"data/static_bert_base_net\")\n", + "\n", ";; the vocabulary used in the model\n", "(def vocab (bert-util/get-vocab))\n", - ";; the input question\n", + "\n", ";; the maximum length of the sequence\n", "(def seq-length 128)\n", "\n", + "(def batch-size 32)\n", + "\n", "(def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))" ] }, @@ -291,7 +296,7 @@ "source": [ "(defn pre-processing\n", " \"Preprocesses the sentences in the format that BERT is expecting\"\n", - " [ctx idx->token token->idx train-item]\n", + " [idx->token token->idx train-item]\n", " (let [[sentence-a sentence-b label] train-item\n", " ;;; pre-processing tokenize sentence\n", " token-1 (bert-util/tokenize (string/lower-case sentence-a))\n", @@ -319,7 +324,7 @@ "(def idx->token (:idx->token vocab))\n", "(def token->idx (:token->idx vocab))\n", "(def dev (context/default-context))\n", - "(def processed-datas (mapv #(pre-processing dev idx->token token->idx %) data-train-raw))\n", + "(def processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw))\n", "(def train-count (count processed-datas))\n", "(println \"Train Count is = \" train-count)\n", "(println \"[PAD] token id = \" (get token->idx \"[PAD]\"))\n", @@ -375,8 +380,6 @@ " (into []))\n", " :train-num (count processed-datas)})\n", "\n", - "(def batch-size 32)\n", - "\n", "(def train-data\n", " (let [{:keys [data0s data1s data2s labels train-num]} prepared-data\n", " data-desc0 (mx-io/data-desc {:name \"data0\"\n", @@ -480,7 +483,7 @@ "(def num-epoch 3)\n", "\n", "(def fine-tune-model (m/module model-sym {:contexts [dev]\n", - " :data-names [\"data0\" \"data1\" \"data2\"]}))\n", + " :data-names [\"data0\" \"data1\" \"data2\"]}))\n", "\n", "(m/fit fine-tune-model {:train-data train-data :num-epoch num-epoch\n", " :fit-params (m/fit-params {:allow-missing true\n", @@ -489,6 +492,122 @@ " :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})\n", " :batch-end-callback (callback/speedometer batch-size 1)})})\n" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Explore results from the fine-tuned model\n", + "\n", + "Now that our model is fitted, we can use it to infer semantic equivalence of arbitrary sentence pairs. Note that for demonstration purpose we skipped the warmup learning rate schedule and validation on dev dataset used in the original implementation. This means that our model's performance will be significantly less than optimal. Please visit [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html) for the complete fine-tuning scripts (using Python and GluonNLP).\n", + "\n", + "To do inference with our model we need a predictor. It must have a batch size of 1 so we can feed the model a single sentence pair." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "#'bert.bert-sentence-classification/fine-tuned-predictor" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(def fine-tuned-prefix \"fine-tune-sentence-bert\")\n", + "\n", + "(m/save-checkpoint fine-tune-model {:prefix fine-tuned-prefix :epoch 3})\n", + "\n", + "(def fine-tuned-predictor\n", + " (infer/create-predictor (infer/model-factory fine-tuned-prefix\n", + " [{:name \"data0\" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}\n", + " {:name \"data1\" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}\n", + " {:name \"data2\" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}])\n", + " {:epoch 3}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can write a function that feeds a sentence pair to the fine-tuned model:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "#'bert.bert-sentence-classification/predict-equivalence" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(defn predict-equivalence\n", + " [predictor sentence1 sentence2]\n", + " (let [vocab (bert.util/get-vocab)\n", + " processed-test-data (mapv #(pre-processing (:idx->token vocab)\n", + " (:token->idx vocab) %)\n", + " [[sentence1 sentence2]])\n", + " prediction (infer/predict-with-ndarray predictor\n", + " [(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length])\n", + " (ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length])\n", + " (ndarray/array (slice-inputs-data processed-test-data 2) [1])])]\n", + " (ndarray/->vec (first prediction))))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.2633881 0.7366119]" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + ";; Modify an existing sentence pair to test:\n", + ";; [\"1\"\n", + ";; \"69773\"\n", + ";; \"69792\"\n", + ";; \"Cisco pared spending to compensate for sluggish sales .\"\n", + ";; \"In response to sluggish sales , Cisco pared spending .\"]\n", + "(predict-equivalence fine-tuned-predictor\n", + " \"The company cut spending to compensate for weak sales .\"\n", + " \"In response to poor sales results, the company cut spending .\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "[1] Peters, Matthew E., et al. “Deep contextualized word representations.” arXiv preprint arXiv:1802.05365 (2018).\n", + "\n", + "[2] Devlin, Jacob, et al. “Bert: Pre-training of deep bidirectional transformers for language understanding.” arXiv preprint arXiv:1810.04805 (2018)." + ] } ], "metadata": { diff --git a/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj index 8c056b719feb..6ec4d586ad17 100644 --- a/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj +++ b/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj @@ -16,12 +16,21 @@ ;; (ns bert.bert-sentence-classification + "Fine-tuning Sentence Pair Classification with BERT + This tutorial focuses on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs. + + Specifically, we will: + 1. load the state-of-the-art pre-trained BERT model + 2. attach an additional layer for classification + 3. process and transform sentence pair data for the task at hand + 4. fine-tune BERT model for sentence classification" (:require [bert.util :as bert-util] [clojure-csv.core :as csv] [clojure.string :as string] [org.apache.clojure-mxnet.callback :as callback] [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.infer :as infer] [org.apache.clojure-mxnet.io :as mx-io] [org.apache.clojure-mxnet.layout :as layout] [org.apache.clojure-mxnet.module :as m] @@ -29,8 +38,25 @@ [org.apache.clojure-mxnet.optimizer :as optimizer] [org.apache.clojure-mxnet.symbol :as sym])) +;; Pre-trained language representations have been shown to improve +;; many downstream NLP tasks such as question answering, and natural +;; language inference. To apply pre-trained representations to these +;; tasks, there are two strategies: + +;; * feature-based approach, which uses the pre-trained representations as additional features to the downstream task. +;; * fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters. + +;; While feature-based approaches such as ELMo are effective in +;; improving many downstream tasks, they require task-specific +;; architectures. Devlin, Jacob, et al proposed BERT (Bidirectional +;; Encoder Representations from Transformers), which fine-tunes deep +;; bidirectional representations on a wide range of tasks with minimal +;; task-specific parameters, and obtained state-of-the-art results. + (def model-path-prefix "data/static_bert_base_net") -;; epoch number of the model + +(def fine-tuned-prefix "fine-tune-sentence-bert") + ;; the maximum length of the sequence (def seq-length 128) @@ -38,20 +64,19 @@ "Preprocesses the sentences in the format that BERT is expecting" [idx->token token->idx train-item] (let [[sentence-a sentence-b label] train-item - ;;; pre-processing tokenize sentence + ;; pre-processing tokenize sentence token-1 (bert-util/tokenize (string/lower-case sentence-a)) token-2 (bert-util/tokenize (string/lower-case sentence-b)) valid-length (+ (count token-1) (count token-2)) - ;;; generate token types [0000...1111...0000] + ;; generate token types [0000...1111...0000] qa-embedded (into (bert-util/pad [] 0 (count token-1)) - (bert-util/pad [] 1 (count token-2))) token-types (bert-util/pad qa-embedded 0 seq-length) - ;;; make BERT pre-processing standard + ;; make BERT pre-processing standard token-2 (conj token-2 "[SEP]") token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2)) tokens (bert-util/pad token-1 "[PAD]" seq-length) - ;;; pre-processing - token to index translation + ;; pre-processing - token to index translation indexes (bert-util/tokens->idxs token->idx tokens)] {:input-batch [indexes token-types @@ -83,19 +108,18 @@ (defn get-raw-data [] (csv/parse-csv (string/replace (slurp "data/dev.tsv") "\"" "") - :delimiter \tab - :strict true)) + :delimiter \tab + :strict true)) (defn prepare-data - "This prepares the senetence pairs into NDArrays for use in NDArrayIterator" - [] - (let [raw-file (get-raw-data) - vocab (bert-util/get-vocab) + "This prepares the sentence pairs into NDArrays for use in NDArrayIterator" + [raw-data] + (let [vocab (bert-util/get-vocab) idx->token (:idx->token vocab) token->idx (:token->idx vocab) - data-train-raw (->> raw-file + data-train-raw (->> raw-data (mapv #(vals (select-keys % [3 4 0]))) - (rest) ;;drop header + (rest) ; drop header (into [])) processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw)] {:data0s (slice-inputs-data processed-datas 0) @@ -111,7 +135,7 @@ [dev num-epoch] (let [bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}) model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1}) - {:keys [data0s data1s data2s labels train-num]} (prepare-data) + {:keys [data0s data1s data2s labels train-num]} (prepare-data (get-raw-data)) batch-size 32 data-desc0 (mx-io/data-desc {:name "data0" :shape [train-num seq-length] @@ -138,14 +162,16 @@ {:label {label-desc (ndarray/array labels [train-num] {:ctx dev})} :data-batch-size batch-size}) - model (m/module model-sym {:contexts [dev] - :data-names ["data0" "data1" "data2"]})] - (m/fit model {:train-data train-data :num-epoch num-epoch - :fit-params (m/fit-params {:allow-missing true - :arg-params (m/arg-params bert-base) - :aux-params (m/aux-params bert-base) - :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9}) - :batch-end-callback (callback/speedometer batch-size 1)})}))) + fitted-model (m/fit (m/module model-sym {:contexts [dev] + :data-names ["data0" "data1" "data2"]}) + {:train-data train-data :num-epoch num-epoch + :fit-params (m/fit-params {:allow-missing true + :arg-params (m/arg-params bert-base) + :aux-params (m/aux-params bert-base) + :optimizer (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9}) + :batch-end-callback (callback/speedometer batch-size 1)})})] + (m/save-checkpoint fitted-model {:prefix fine-tuned-prefix :epoch num-epoch}) + fitted-model)) (defn -main [& args] (let [[dev-arg num-epoch-arg] args @@ -154,7 +180,46 @@ (println "Running example with " dev " and " num-epoch " epochs ") (train dev num-epoch))) +;; For evaluating the model +(defn predict-equivalence + "Get the fine-tuned model's opinion on whether two sentences are equivalent:" + [predictor sentence1 sentence2] + (let [vocab (bert.util/get-vocab) + processed-test-data (mapv #(pre-processing (:idx->token vocab) + (:token->idx vocab) %) + [[sentence1 sentence2]]) + prediction (infer/predict-with-ndarray predictor + [(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length]) + (ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length]) + (ndarray/array (slice-inputs-data processed-test-data 2) [1])])] + (ndarray/->vec (first prediction)))) + (comment (train (context/cpu 0) 3) - (m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 3})) + + (m/save-checkpoint model {:prefix fine-tuned-prefix :epoch 3}) + + + ;;;; Explore results from the fine-tuned model + + ;; We need a predictor with a batch size of 1, so we can feed the + ;; model a single sentence pair. + (def fine-tuned-predictor + (infer/create-predictor (infer/model-factory fine-tuned-prefix + [{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT} + {:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT} + {:name "data2" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}]) + {:epoch 3})) + + ;; Modify an existing sentence pair to test: + ;; ["1" + ;; "69773" + ;; "69792" + ;; "Cisco pared spending to compensate for sluggish sales ." + ;; "In response to sluggish sales , Cisco pared spending ."] + (predict-equivalence fine-tuned-predictor + "The company cut spending to compensate for weak sales ." + "In response to poor sales results, the company cut spending .") + + ) diff --git a/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj b/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj index 355f23ea3cfd..c26301e34fe6 100644 --- a/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj +++ b/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj @@ -26,6 +26,7 @@ [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.dtype :as dtype] [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.infer :as infer] [org.apache.clojure-mxnet.io :as mx-io] [org.apache.clojure-mxnet.layout :as layout] [org.apache.clojure-mxnet.ndarray :as ndarray] @@ -34,6 +35,8 @@ (def model-dir "data/") +(def test-prefix "test-fine-tuning-bert-sentence-pairs") + (when-not (.exists (io/file (str model-dir "static_bert_qa-0002.params"))) (println "Downloading bert qa data") (sh "./get_bert_data.sh")) @@ -47,7 +50,7 @@ num-epoch 1 bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}) model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1}) - {:keys [data0s data1s data2s labels train-num]} (prepare-data) + {:keys [data0s data1s data2s labels train-num]} (prepare-data (get-raw-data)) batch-size 32 data-desc0 (mx-io/data-desc {:name "data0" :shape [train-num seq-length] @@ -82,5 +85,20 @@ :aux-params (m/aux-params bert-base) :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9}) :batch-end-callback (callback/speedometer batch-size 1)})}) - (is (< 0.5 (-> (m/score model {:eval-data train-data :eval-metric (eval-metric/accuracy) }) - (last))))))) + (m/save-checkpoint model {:prefix test-prefix :epoch num-epoch}) + (testing "accuracy" + (is (< 0.5 (last (m/score model {:eval-data train-data :eval-metric (eval-metric/accuracy)}))))) + (testing "prediction" + (let [test-predictor (infer/create-predictor (infer/model-factory test-prefix + [{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT} + {:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT} + {:name "data2" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}]) + {:epoch num-epoch}) + prediction (predict-equivalence test-predictor + "The company cut spending to compensate for weak sales ." + "In response to poor sales results, the company cut spending .")] + ;; We can't say much about how the model will find this prediction, so we test only the prediction's shape. + (is (vector? prediction)) + (is (number? (first prediction))) + (is (number? (second prediction))) + (is (= 2 (count prediction)))))))) diff --git a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj index 05eb0add3138..41a003a86ce0 100644 --- a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj +++ b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj @@ -99,3 +99,9 @@ (:help options) (println summary) (some? errors) (println (join "\n" errors)) :else (run-predictor options)))) + +(comment + (run-predictor {:model-path-prefix "models/resnet-18/resnet-18" + :input-image "images/kitten.jpg"}) + + )