diff --git a/contrib/clojure-package/examples/bert-qa/README.md b/contrib/clojure-package/examples/bert-qa/README.md index 9a21bcdfd66b..55f13e671c00 100644 --- a/contrib/clojure-package/examples/bert-qa/README.md +++ b/contrib/clojure-package/examples/bert-qa/README.md @@ -57,9 +57,8 @@ Some sample questions and answers are provide in the `squad-sample.edn` file. So * `lein install` in the root of the main project directory * cd into this project directory and do `lein run`. This will execute the cpu version. - -`lein run :cpu` - to run with cpu -`lein run :gpu` - to run with gpu + * `lein run` or `lein run :cpu` to run with cpu + * `lein run :gpu` to run with gpu ## Background diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj index 836684e04977..9dcc783ff1ac 100644 --- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj +++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/infer.clj @@ -15,13 +15,10 @@ ;; limitations under the License. ;; - (ns bert-qa.infer (:require [clojure.string :as string] - [clojure.reflect :as r] [cheshire.core :as json] [clojure.java.io :as io] - [clojure.set :as set] [org.apache.clojure-mxnet.dtype :as dtype] [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.layout :as layout] @@ -30,11 +27,7 @@ [clojure.pprint :as pprint])) (def model-path-prefix "model/static_bert_qa") -;; epoch number of the model -(def epoch 2) -;; the vocabulary used in the model -(def model-vocab "model/vocab.json") -;; the input question + ;; the maximum length of the sequence (def seq-length 384) @@ -60,16 +53,13 @@ (into tokens (repeat (- num (count tokens)) pad-item)))) (defn get-vocab [] - (let [vocab (json/parse-stream (clojure.java.io/reader "model/vocab.json"))] + (let [vocab (json/parse-stream (io/reader "model/vocab.json"))] {:idx->token (get vocab "idx_to_token") :token->idx (get vocab "token_to_idx")})) (defn tokens->idxs [token->idx tokens] (let [unk-idx (get token->idx "[UNK]")] - (mapv #(get token->idx % unk-idx) tokens))) - -(defn idxs->tokens [idx->token idxs] - (mapv #(get idx->token %) idxs)) + (mapv #(get token->idx % unk-idx) tokens))) (defn post-processing [result tokens] (let [output1 (ndarray/slice-axis result 2 0 1) @@ -131,22 +121,23 @@ :tokens tokens :qa-map qa-map})) -(defn infer [ctx] - (let [ctx (context/default-context) - predictor (make-predictor ctx) - {:keys [idx->token token->idx]} (get-vocab) +(defn infer + ([] (infer (context/default-context))) + ([ctx] + (let [predictor (make-predictor ctx) + {:keys [idx->token token->idx]} (get-vocab) ;;; samples taken from https://rajpurkar.github.io/SQuAD-explorer/explore/v2.0/dev/ - question-answers (clojure.edn/read-string (slurp "squad-samples.edn"))] - (doseq [qa-map question-answers] - (let [{:keys [input-batch tokens qa-map]} (pre-processing ctx idx->token token->idx qa-map) - result (first (infer/predict-with-ndarray predictor input-batch)) - answer (post-processing result tokens)] - (println "===============================") - (println " Question Answer Data") - (pprint/pprint qa-map) - (println) - (println " Predicted Answer: " answer) - (println "==============================="))))) + question-answers (clojure.edn/read-string (slurp "squad-samples.edn"))] + (doseq [qa-map question-answers] + (let [{:keys [input-batch tokens qa-map]} (pre-processing ctx idx->token token->idx qa-map) + result (first (infer/predict-with-ndarray predictor input-batch)) + answer (post-processing result tokens)] + (println "===============================") + (println " Question Answer Data") + (pprint/pprint qa-map) + (println) + (println " Predicted Answer: " answer) + (println "===============================")))))) (defn -main [& args] (let [[dev] args] @@ -156,4 +147,8 @@ (comment - (infer :cpu)) + (infer) + + (infer (context/gpu)) + + )