Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
run formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
gigasquid committed Apr 26, 2019
1 parent 7de5a26 commit a3293f3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,24 @@
;; Fixed sequence length: `pre-processing` pads both the token list and the
;; token-type list out to this many entries.
(def seq-length 128)


(defn pre-processing
"Preprocesses the sentences in the format that BERT is expecting"
[ctx idx->token token->idx train-item]
(let [[sentence-a sentence-b label] train-item
(let [[sentence-a sentence-b label] train-item
;;; pre-processing tokenize sentence
token-1 (bert-util/tokenize (string/lower-case sentence-a))
token-2 (bert-util/tokenize (string/lower-case sentence-b))
valid-length (+ (count token-1) (count token-2))
token-1 (bert-util/tokenize (string/lower-case sentence-a))
token-2 (bert-util/tokenize (string/lower-case sentence-b))
valid-length (+ (count token-1) (count token-2))
;;; generate token types [0000...1111...0000]
qa-embedded (into (bert-util/pad [] 0 (count token-1))
(bert-util/pad [] 1 (count token-2)))
token-types (bert-util/pad qa-embedded 0 seq-length)
qa-embedded (into (bert-util/pad [] 0 (count token-1))
(bert-util/pad [] 1 (count token-2)))
token-types (bert-util/pad qa-embedded 0 seq-length)
;;; make BERT pre-processing standard
token-2 (conj token-2 "[SEP]")
token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2))
tokens (bert-util/pad token-1 "[PAD]" seq-length)
token-2 (conj token-2 "[SEP]")
token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2))
tokens (bert-util/pad token-1 "[PAD]" seq-length)
;;; pre-processing - token to index translation
indexes (bert-util/tokens->idxs token->idx tokens)]
indexes (bert-util/tokens->idxs token->idx tokens)]
{:input-batch [indexes
token-types
[valid-length]]
Expand All @@ -76,7 +75,6 @@
(sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes})
(sym/softmax-output "softmax" {:data data})))


(defn slice-inputs-data
"Each sentence pair had to be processed as a row. This breaks all
the rows up into a column for creating a NDArray"
Expand Down Expand Up @@ -106,7 +104,6 @@
(into []))
:train-num (count processed-datas)}))


(defn train
"Trains (fine tunes) the sentence pairs for a classification task on the BERT Base model"
[dev num-epoch]
Expand All @@ -115,30 +112,30 @@
{:keys [data0s data1s data2s labels train-num]} (prepare-data dev)
batch-size 32
data-desc0 (mx-io/data-desc {:name "data0"
:shape [train-num seq-length]
:dtype dtype/FLOAT32
:shape [train-num seq-length]
:dtype dtype/FLOAT32
:layout layout/NT})
data-desc1 (mx-io/data-desc {:name "data1"
:shape [train-num seq-length]
:dtype dtype/FLOAT32
:shape [train-num seq-length]
:dtype dtype/FLOAT32
:layout layout/NT})
data-desc2 (mx-io/data-desc {:name "data2"
:shape [train-num]
:dtype dtype/FLOAT32
:layout layout/N})
label-desc (mx-io/data-desc {:name "softmax_label"
:shape [train-num]
:dtype dtype/FLOAT32
:dtype dtype/FLOAT32
:layout layout/N})
train-data (mx-io/ndarray-iter {data-desc0 (ndarray/array data0s [train-num seq-length]
{:ctx dev})
data-desc1 (ndarray/array data1s [train-num seq-length]
{:ctx dev})
data-desc2 (ndarray/array data2s [train-num]
{:ctx dev})}
{:label {label-desc (ndarray/array labels [train-num]
{:ctx dev})
data-desc1 (ndarray/array data1s [train-num seq-length]
{:ctx dev})
data-desc2 (ndarray/array data2s [train-num]
{:ctx dev})}
{:label {label-desc (ndarray/array labels [train-num]
{:ctx dev})}
:data-batch-size batch-size})
:data-batch-size batch-size})
model (m/module model-sym {:contexts [dev]
:data-names ["data0" "data1" "data2"]})]
(m/fit model {:train-data train-data :num-epoch num-epoch
Expand All @@ -157,6 +154,4 @@
;; REPL scratch area: these forms are not evaluated when the file is loaded.
(comment

  ;; fine-tune on the CPU for 3 epochs
  (train (context/cpu 0) 3)
  ;; NOTE(review): `model` must be bound at the REPL for this call to work;
  ;; confirm where it comes from before running.
  (m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 3}))
2 changes: 1 addition & 1 deletion contrib/clojure-package/examples/bert/src/bert/util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

(defn tokens->idxs
  "Looks up every token of `tokens` in the `token->idx` map and returns the
  results as a vector. A token with no entry in the map yields the value
  stored under \"[UNK]\" instead."
  [token->idx tokens]
  (let [unk-idx (get token->idx "[UNK]")]
    (mapv #(get token->idx % unk-idx) tokens)))

(defn idxs->tokens
  "Looks up every index of `idxs` in `idx->token` and returns the results as
  a vector; an index with no entry yields nil."
  [idx->token idxs]
  (into []
        (map (fn [idx] (get idx->token idx)))
        idxs))

0 comments on commit a3293f3

Please sign in to comment.