Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix Clojure BERT example's context argument #14843

Merged
merged 3 commits into from
Apr 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions contrib/clojure-package/examples/bert-qa/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ Some sample questions and answers are provide in the `squad-sample.edn` file. So

* `lein install` in the root of the main project directory
* cd into this project directory and do `lein run`. This will execute the cpu version.

`lein run :cpu` - to run with cpu
`lein run :gpu` - to run with gpu
* `lein run` or `lein run :cpu` to run with cpu
* `lein run :gpu` to run with gpu

## Background

Expand Down
53 changes: 24 additions & 29 deletions contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
;; limitations under the License.
;;


(ns bert-qa.infer
(:require [clojure.string :as string]
[clojure.reflect :as r]
[cheshire.core :as json]
[clojure.java.io :as io]
[clojure.set :as set]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.layout :as layout]
Expand All @@ -30,11 +27,7 @@
[clojure.pprint :as pprint]))

(def model-path-prefix "model/static_bert_qa")
;; epoch number of the model
(def epoch 2)
;; the vocabulary used in the model
(def model-vocab "model/vocab.json")
;; the input question

;; the maximum length of the sequence
(def seq-length 384)

Expand All @@ -60,16 +53,13 @@
(into tokens (repeat (- num (count tokens)) pad-item))))

(defn get-vocab []
(let [vocab (json/parse-stream (clojure.java.io/reader "model/vocab.json"))]
(let [vocab (json/parse-stream (io/reader "model/vocab.json"))]
{:idx->token (get vocab "idx_to_token")
:token->idx (get vocab "token_to_idx")}))

(defn tokens->idxs [token->idx tokens]
(let [unk-idx (get token->idx "[UNK]")]
(mapv #(get token->idx % unk-idx) tokens)))

(defn idxs->tokens [idx->token idxs]
(mapv #(get idx->token %) idxs))
(mapv #(get token->idx % unk-idx) tokens)))

(defn post-processing [result tokens]
(let [output1 (ndarray/slice-axis result 2 0 1)
Expand Down Expand Up @@ -131,22 +121,23 @@
:tokens tokens
:qa-map qa-map}))

(defn infer [ctx]
(let [ctx (context/default-context)
predictor (make-predictor ctx)
{:keys [idx->token token->idx]} (get-vocab)
(defn infer
([] (infer (context/default-context)))
([ctx]
(let [predictor (make-predictor ctx)
{:keys [idx->token token->idx]} (get-vocab)
;;; samples taken from https://rajpurkar.github.io/SQuAD-explorer/explore/v2.0/dev/
question-answers (clojure.edn/read-string (slurp "squad-samples.edn"))]
(doseq [qa-map question-answers]
(let [{:keys [input-batch tokens qa-map]} (pre-processing ctx idx->token token->idx qa-map)
result (first (infer/predict-with-ndarray predictor input-batch))
answer (post-processing result tokens)]
(println "===============================")
(println " Question Answer Data")
(pprint/pprint qa-map)
(println)
(println " Predicted Answer: " answer)
(println "===============================")))))
question-answers (clojure.edn/read-string (slurp "squad-samples.edn"))]
(doseq [qa-map question-answers]
(let [{:keys [input-batch tokens qa-map]} (pre-processing ctx idx->token token->idx qa-map)
result (first (infer/predict-with-ndarray predictor input-batch))
answer (post-processing result tokens)]
(println "===============================")
(println " Question Answer Data")
(pprint/pprint qa-map)
(println)
(println " Predicted Answer: " answer)
(println "==============================="))))))

(defn -main [& args]
(let [[dev] args]
Expand All @@ -156,4 +147,8 @@

(comment

(infer :cpu))
(infer)

(infer (context/gpu))

)