Skip to content

Commit

Permalink
Fix Clojure BERT example's context argument (apache#14843)
Browse files Browse the repository at this point in the history
* Clojure BERT example: minor code cleanup

* Remove unused requires
* Remove unused vars & function
* Use `io` alias

* Clojure BERT example: whitespace fix

* Clojure BERT example: allow running with GPU

The `infer` function accepts a CPU/GPU context, which the command line
version of this example exposes as a `:cpu`/`:gpu`
keyword. Previously, these options were ignored and the context was
overridden to the default context (CPU). This commit allows
users (both REPL and shell) to pass in a GPU context.
  • Loading branch information
daveliepmann authored and haohuw committed Jun 23, 2019
1 parent 81c2425 commit 45d2127
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 32 deletions.
5 changes: 2 additions & 3 deletions contrib/clojure-package/examples/bert-qa/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ Some sample questions and answers are provide in the `squad-sample.edn` file. So

* `lein install` in the root of the main project directory
* cd into this project directory and do `lein run`. This will execute the cpu version.

`lein run :cpu` - to run with cpu
`lein run :gpu` - to run with gpu
* `lein run` or `lein run :cpu` to run with cpu
* `lein run :gpu` to run with gpu

## Background

Expand Down
53 changes: 24 additions & 29 deletions contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
;; limitations under the License.
;;


(ns bert-qa.infer
(:require [clojure.string :as string]
[clojure.reflect :as r]
[cheshire.core :as json]
[clojure.java.io :as io]
[clojure.set :as set]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.layout :as layout]
Expand All @@ -30,11 +27,7 @@
[clojure.pprint :as pprint]))

(def model-path-prefix "model/static_bert_qa")
;; epoch number of the model
(def epoch 2)
;; the vocabulary used in the model
(def model-vocab "model/vocab.json")
;; the input question

;; the maximum length of the sequence
(def seq-length 384)

Expand All @@ -60,16 +53,13 @@
(into tokens (repeat (- num (count tokens)) pad-item))))

(defn get-vocab []
(let [vocab (json/parse-stream (clojure.java.io/reader "model/vocab.json"))]
(let [vocab (json/parse-stream (io/reader "model/vocab.json"))]
{:idx->token (get vocab "idx_to_token")
:token->idx (get vocab "token_to_idx")}))

(defn tokens->idxs [token->idx tokens]
(let [unk-idx (get token->idx "[UNK]")]
(mapv #(get token->idx % unk-idx) tokens)))

(defn idxs->tokens [idx->token idxs]
(mapv #(get idx->token %) idxs))
(mapv #(get token->idx % unk-idx) tokens)))

(defn post-processing [result tokens]
(let [output1 (ndarray/slice-axis result 2 0 1)
Expand Down Expand Up @@ -131,22 +121,23 @@
:tokens tokens
:qa-map qa-map}))

(defn infer [ctx]
(let [ctx (context/default-context)
predictor (make-predictor ctx)
{:keys [idx->token token->idx]} (get-vocab)
(defn infer
([] (infer (context/default-context)))
([ctx]
(let [predictor (make-predictor ctx)
{:keys [idx->token token->idx]} (get-vocab)
;;; samples taken from https://rajpurkar.github.io/SQuAD-explorer/explore/v2.0/dev/
question-answers (clojure.edn/read-string (slurp "squad-samples.edn"))]
(doseq [qa-map question-answers]
(let [{:keys [input-batch tokens qa-map]} (pre-processing ctx idx->token token->idx qa-map)
result (first (infer/predict-with-ndarray predictor input-batch))
answer (post-processing result tokens)]
(println "===============================")
(println " Question Answer Data")
(pprint/pprint qa-map)
(println)
(println " Predicted Answer: " answer)
(println "===============================")))))
question-answers (clojure.edn/read-string (slurp "squad-samples.edn"))]
(doseq [qa-map question-answers]
(let [{:keys [input-batch tokens qa-map]} (pre-processing ctx idx->token token->idx qa-map)
result (first (infer/predict-with-ndarray predictor input-batch))
answer (post-processing result tokens)]
(println "===============================")
(println " Question Answer Data")
(pprint/pprint qa-map)
(println)
(println " Predicted Answer: " answer)
(println "==============================="))))))

(defn -main [& args]
(let [[dev] args]
Expand All @@ -156,4 +147,8 @@

(comment

(infer :cpu))
(infer)

(infer (context/gpu))

)

0 comments on commit 45d2127

Please sign in to comment.