-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Extend Clojure BERT example #15023
Extend Clojure BERT example #15023
Changes from 8 commits
82cdb08
bf0ce38
146114d
06e8e82
3714452
3597e8a
0efa7b5
6aaa7cc
96775fb
a51264e
dde8b6f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,42 +16,67 @@ | |
;; | ||
|
||
(ns bert.bert-sentence-classification | ||
"Fine-tuning Sentence Pair Classification with BERT | ||
daveliepmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This tutorial focuses on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs. | ||
|
||
Specifically, we will: | ||
1. load the state-of-the-art pre-trained BERT model | ||
2. attach an additional layer for classification | ||
3. process and transform sentence pair data for the task at hand | ||
4. fine-tune BERT model for sentence classification" | ||
(:require [bert.util :as bert-util] | ||
[clojure-csv.core :as csv] | ||
[clojure.string :as string] | ||
[org.apache.clojure-mxnet.callback :as callback] | ||
[org.apache.clojure-mxnet.context :as context] | ||
[org.apache.clojure-mxnet.dtype :as dtype] | ||
[org.apache.clojure-mxnet.infer :as infer] | ||
[org.apache.clojure-mxnet.io :as mx-io] | ||
[org.apache.clojure-mxnet.layout :as layout] | ||
[org.apache.clojure-mxnet.module :as m] | ||
[org.apache.clojure-mxnet.ndarray :as ndarray] | ||
[org.apache.clojure-mxnet.optimizer :as optimizer] | ||
[org.apache.clojure-mxnet.symbol :as sym])) | ||
|
||
;; Pre-trained language representations have been shown to improve | ||
;; many downstream NLP tasks such as question answering, and natural | ||
;; language inference. To apply pre-trained representations to these | ||
;; tasks, there are two strategies: | ||
|
||
;; * feature-based approach, which uses the pre-trained representations as additional features to the downstream task. | ||
;; * fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters. | ||
|
||
;; While feature-based approaches such as ELMo are effective in | ||
;; improving many downstream tasks, they require task-specific | ||
;; architectures. Devlin, Jacob, et al proposed BERT (Bidirectional | ||
;; Encoder Representations from Transformers), which fine-tunes deep | ||
;; bidirectional representations on a wide range of tasks with minimal | ||
;; task-specific parameters, and obtained state-of-the-art results. | ||
|
||
(def model-path-prefix "data/static_bert_base_net") | ||
;; epoch number of the model | ||
|
||
(def fine-tuned-prefix "fine-tune-sentence-bert") | ||
|
||
;; the maximum length of the sequence | ||
(def seq-length 128) | ||
|
||
(defn pre-processing | ||
"Preprocesses the sentences in the format that BERT is expecting" | ||
[idx->token token->idx train-item] | ||
(let [[sentence-a sentence-b label] train-item | ||
;;; pre-processing tokenize sentence | ||
;; pre-processing tokenize sentence | ||
token-1 (bert-util/tokenize (string/lower-case sentence-a)) | ||
token-2 (bert-util/tokenize (string/lower-case sentence-b)) | ||
valid-length (+ (count token-1) (count token-2)) | ||
;;; generate token types [0000...1111...0000] | ||
;; generate token types [0000...1111...0000] | ||
qa-embedded (into (bert-util/pad [] 0 (count token-1)) | ||
|
||
(bert-util/pad [] 1 (count token-2))) | ||
token-types (bert-util/pad qa-embedded 0 seq-length) | ||
;;; make BERT pre-processing standard | ||
;; make BERT pre-processing standard | ||
token-2 (conj token-2 "[SEP]") | ||
token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2)) | ||
tokens (bert-util/pad token-1 "[PAD]" seq-length) | ||
;;; pre-processing - token to index translation | ||
;; pre-processing - token to index translation | ||
indexes (bert-util/tokens->idxs token->idx tokens)] | ||
{:input-batch [indexes | ||
token-types | ||
|
@@ -83,19 +108,18 @@ | |
|
||
(defn get-raw-data [] | ||
(csv/parse-csv (string/replace (slurp "data/dev.tsv") "\"" "") | ||
:delimiter \tab | ||
:strict true)) | ||
:delimiter \tab | ||
:strict true)) | ||
|
||
(defn prepare-data | ||
"This prepares the senetence pairs into NDArrays for use in NDArrayIterator" | ||
[] | ||
(let [raw-file (get-raw-data) | ||
vocab (bert-util/get-vocab) | ||
"This prepares the sentence pairs into NDArrays for use in NDArrayIterator" | ||
[raw-data] | ||
(let [vocab (bert-util/get-vocab) | ||
idx->token (:idx->token vocab) | ||
token->idx (:token->idx vocab) | ||
data-train-raw (->> raw-file | ||
data-train-raw (->> raw-data | ||
(mapv #(vals (select-keys % [3 4 0]))) | ||
(rest) ;;drop header | ||
(rest) ; drop header | ||
(into [])) | ||
processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw)] | ||
{:data0s (slice-inputs-data processed-datas 0) | ||
|
@@ -111,7 +135,7 @@ | |
[dev num-epoch] | ||
(let [bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}) | ||
model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1}) | ||
{:keys [data0s data1s data2s labels train-num]} (prepare-data) | ||
{:keys [data0s data1s data2s labels train-num]} (prepare-data (get-raw-data)) | ||
batch-size 32 | ||
data-desc0 (mx-io/data-desc {:name "data0" | ||
:shape [train-num seq-length] | ||
|
@@ -138,14 +162,16 @@ | |
{:label {label-desc (ndarray/array labels [train-num] | ||
{:ctx dev})} | ||
:data-batch-size batch-size}) | ||
model (m/module model-sym {:contexts [dev] | ||
:data-names ["data0" "data1" "data2"]})] | ||
(m/fit model {:train-data train-data :num-epoch num-epoch | ||
:fit-params (m/fit-params {:allow-missing true | ||
:arg-params (m/arg-params bert-base) | ||
:aux-params (m/aux-params bert-base) | ||
:optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9}) | ||
:batch-end-callback (callback/speedometer batch-size 1)})}))) | ||
fitted-model (m/fit (m/module model-sym {:contexts [dev] | ||
:data-names ["data0" "data1" "data2"]}) | ||
{:train-data train-data :num-epoch num-epoch | ||
:fit-params (m/fit-params {:allow-missing true | ||
:arg-params (m/arg-params bert-base) | ||
:aux-params (m/aux-params bert-base) | ||
:optimizer (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9}) | ||
daveliepmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
:batch-end-callback (callback/speedometer batch-size 1)})})] | ||
(m/save-checkpoint fitted-model {:prefix fine-tuned-prefix :epoch num-epoch}) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we save the model to disk now? Could we pass in a parameter to the function to do it? This function seems to do too many things? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We have to save the model to disk because there's no other way but saving to disk and loading it back in to get a prediction out of it. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With the |
||
fitted-model)) | ||
|
||
(defn -main [& args] | ||
(let [[dev-arg num-epoch-arg] args | ||
|
@@ -157,4 +183,42 @@ | |
(comment | ||
|
||
(train (context/cpu 0) 3) | ||
(m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 3})) | ||
|
||
(m/save-checkpoint model {:prefix fine-tuned-prefix :epoch 3}) | ||
|
||
|
||
;;;; Explore results from the fine-tuned model | ||
|
||
;; We need a predictor with a batch size of 1, so we can feed the | ||
;; model a single sentence pair. | ||
(def fine-tuned-predictor | ||
daveliepmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
(infer/create-predictor (infer/model-factory fine-tuned-prefix | ||
[{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT} | ||
{:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT} | ||
{:name "data2" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}]) | ||
{:epoch 3})) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need this hardcoded epoch number here? Can't we just use There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As I recall, we have to hard-code the epoch because otherwise we don't know which saved model to load from disk. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Responding to edit: Another reason not to |
||
|
||
;; Get the fine-tuned model's opinion on whether two sentences are equivalent: | ||
(defn predict-equivalence | ||
[predictor sentence1 sentence2] | ||
(let [vocab (bert.util/get-vocab) | ||
processed-test-data (mapv #(pre-processing (:idx->token vocab) | ||
(:token->idx vocab) %) | ||
[[sentence1 sentence2]]) | ||
prediction (infer/predict-with-ndarray predictor | ||
[(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length]) | ||
(ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length]) | ||
(ndarray/array (slice-inputs-data processed-test-data 2) [1])])] | ||
(ndarray/->vec (first prediction)))) | ||
|
||
;; Modify an existing sentence pair to test: | ||
;; ["1" | ||
;; "69773" | ||
;; "69792" | ||
;; "Cisco pared spending to compensate for sluggish sales ." | ||
;; "In response to sluggish sales , Cisco pared spending ."] | ||
(predict-equivalence fine-tuned-predictor | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did you want to add a test for this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just pushed one. Thanks for the idea—my PR broke the existing test, but I guess tests for examples aren't part of the CI checks. |
||
"The company cut spending to compensate for weak sales ." | ||
"In response to poor sales results, the company cut spending .") | ||
|
||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's this the output of?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[0.2633881 0.7366119]
is the output of our sample sentence pair equivalence prediction: I'm not sure why the result appears before its expression in the .ipynb file, but on my machine it displays this pair correctly as "In [22]" followed by "Out [22]".