From 9454b168b41755231bc61ca9baa62b0b28170968 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 1 Apr 2019 15:18:43 -0700 Subject: [PATCH 1/5] add BertQA major code piece --- scala-package/examples/pom.xml | 5 + .../javaapi/infer/bert/BertQA.java | 134 ++++++++++++++++++ .../javaapi/infer/bert/BertUtil.java | 75 ++++++++++ .../javaapi/infer/bert/README.md | 18 +++ 4 files changed, 232 insertions(+) create mode 100644 scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java create mode 100644 scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java create mode 100644 scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml index 07e430175c6f..d60782ffd06b 100644 --- a/scala-package/examples/pom.xml +++ b/scala-package/examples/pom.xml @@ -145,5 +145,10 @@ slf4j-simple 1.7.5 + + com.google.code.gson + gson + 2.8.5 + diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java new file mode 100644 index 000000000000..aee134af1fb7 --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnetexamples.javaapi.infer.bert; + +import org.apache.mxnet.infer.javaapi.Predictor; +import org.apache.mxnet.javaapi.*; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; + +public class BertQA { + @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") + private String modelPathPrefix = "/model/static_bert_qa"; + @Option(name = "--model-epoch", usage = "Epoch number of the model") + private int epoch = 2; + @Option(name = "--model-vocab", usage = "the vocabulary used in the model") + private String modelVocab = "/model/vocab.json"; + @Option(name = "--input-question", usage = "the input question") + private String inputQ = "When did BBC Japan start broadcasting?"; + @Option(name = "--input-answer", usage = "the input answer") + private String inputA = + "BBC Japan was a general entertainment Channel.\n" + + " Which operated between December 2004 and April 2006.\n" + + "It ceased operations after its Japanese distributor folded."; + @Option(name = "--seq-length", usage = "the maximum length of the sequence") + private int seqLength = 384; + + + final static Logger logger = LoggerFactory.getLogger(BertQA.class); + private static NDArray$ NDArray = NDArray$.MODULE$; + + private static int argmax(float[] prob) { + int maxIdx = 0; + for (int i = 0; i < prob.length; i++) { + if (prob[maxIdx] < prob[i]) maxIdx = i; + } + return maxIdx; + } + + static void postProcessing(NDArray result, List tokens) { + NDArray output = NDArray.split( + NDArray.new splitParam(result, 2).setAxis(2))[0]; + // Get the formatted logits result + NDArray startLogits = output.at(0).reshape(new int[]{0, -3}); + NDArray endLogits = output.at(1).reshape(new int[]{0, -3}); + // Get Probability distribution + float[] startProb = NDArray.softmax( + NDArray.new softmaxParam(startLogits))[0].toArray(); + float[] endProb = NDArray.softmax( + NDArray.new softmaxParam(endLogits))[0].toArray(); + int startIdx = argmax(startProb); + int endIdx = argmax(endProb); + String[] answer = (String[]) tokens.subList(startIdx, endIdx + 1).toArray(); + logger.info("Answer: ", Arrays.toString(answer)); + } + + public static void main(String[] args) throws Exception{ + BertQA inst = new BertQA(); + CmdLineParser parser = new CmdLineParser(inst); + parser.parseArgument(args); + BertUtil util = new BertUtil(); + Context context = Context.cpu(); + logger.info("Question: ", inst.inputQ); + logger.info("Answer paragraph: ", inst.inputA); + // pre-processing - tokenize sentence + List tokenQ = util.tokenizer(inst.inputQ.toLowerCase()); + List tokenA = util.tokenizer(inst.inputA.toLowerCase()); + int validLength = tokenQ.size() + tokenA.size(); + logger.info("Valid length: ", validLength); + // generate token types [0000...1111....0000] + List QAEmbedded = new ArrayList<>(); + util.pad(QAEmbedded, 0f, tokenQ.size()).addAll( + util.pad(new ArrayList(), 1f, tokenA.size()) + ); + List tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength); + // make BERT pre-processing standard + tokenQ.add("[SEP]"); + tokenQ.add(0, "[CLS]"); + tokenA.add("[SEP]"); + tokenQ.addAll(tokenA); + List tokens = util.pad(tokenQ, "[PAD]", inst.seqLength); + logger.info("Pre-processed tokens: ", Arrays.toString(tokens.toArray())); + // pre-processing - token to index translation + util.parseJSON(inst.modelVocab); + List indexes = util.token2idx(tokens); + List indexesFloat = new ArrayList<>(); + for (int 
integer : indexes) { + indexesFloat.add((float) integer); + } + // Preparing the input data + NDArray inputs = new NDArray(indexesFloat, + new Shape(new int[]{1, inst.seqLength}), context); + NDArray tokenTypesND = new NDArray(tokenTypes, + new Shape(new int[]{1, inst.seqLength}), context); + NDArray validLengthND = new NDArray(new float[] {(float) validLength}, + new Shape(new int[]{1}), context); + List inputBatch = new ArrayList<>(); + inputBatch.add(inputs); + inputBatch.add(tokenTypesND); + inputBatch.add(validLengthND); + // Build the model + List inputDescs = new ArrayList<>(); + List contexts = new ArrayList<>(); + contexts.add(context); + inputDescs.add(new DataDesc("data0", + new Shape(new int[]{1, inst.seqLength}), DType.Float32(), "NT")); + inputDescs.add(new DataDesc("data1", + new Shape(new int[]{1, inst.seqLength}), DType.Float32(), "NT")); + inputDescs.add(new DataDesc("data2", + new Shape(new int[]{1}), DType.Float32(), "N")); + Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch); + // Start prediction + NDArray result = bertQA.predictWithNDArray(inputBatch).get(0); + postProcessing(result, tokens); + } +} diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java new file mode 100644 index 000000000000..4196401086e1 --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java @@ -0,0 +1,75 @@ +package org.apache.mxnetexamples.javaapi.infer.bert; + +import java.io.FileReader; +import java.util.*; + +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; + +public class BertUtil { + + private Map token2idx; + private List idx2token; + + void parseJSON(String jsonFile) throws Exception { + Gson gson = new Gson(); + token2idx = new HashMap<>(); + idx2token = new LinkedList<>(); + JsonObject jsonObject = gson.fromJson(new FileReader(jsonFile), JsonObject.class); + JsonArray arr = jsonObject.getAsJsonArray("token_to_idx"); + for (JsonElement element : arr) { + idx2token.add(element.getAsString()); + } + JsonObject preMap = jsonObject.getAsJsonObject("idx_to_token"); + for (String key : preMap.keySet()) { + token2idx.put(key, jsonObject.get(key).getAsInt()); + } + } + + List tokenizer(String input) { + String[] step1 = input.split("[\n\r\t ]+"); + List finalResult = new LinkedList<>(); + for (String item : step1) { + if (item.length() != 0) { + if (item.split("[.,?!]+").length > 1) { + finalResult.add(item.substring(0, item.length() - 1)); + finalResult.add(item.substring(item.length() -1, item.length())); + } else { + finalResult.add(item); + } + } + } + return finalResult; + } + + List pad(List tokens, E padItem, int num) { + if (tokens.size() >= num) return tokens; + List padded = new LinkedList<>(tokens); + for (int i = 0; i < num - tokens.size(); i++) { + tokens.add(padItem); + } + return padded; + } + + List token2idx(List tokens) { + List indexes = new ArrayList<>(); + for (String token : tokens) { + if (token2idx.containsKey(token)) { + indexes.add(token2idx.get(token)); + } else { + indexes.add(token2idx.get("[UNK]")); + } + } + return indexes; + } + + List idx2token(List indexes) { + List tokens = new ArrayList<>(); + for (E index : indexes) { + tokens.add(idx2token.get((int) index)); + } + return tokens; + } +} diff --git 
a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md new file mode 100644 index 000000000000..4fe4dfc9224d --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + +# BERT QA model using Java Inference API From 826db2986769062da76e8fce9de640bfe5d95bac Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 1 Apr 2019 21:26:36 -0700 Subject: [PATCH 2/5] add scripts and bug fixes --- .../scripts/infer/bert/get_bert_data.sh | 34 ++++++++ .../scripts/infer/bert/run_bert_qa_example.sh | 27 ++++++ .../javaapi/infer/bert/BertQA.java | 40 ++++++--- .../javaapi/infer/bert/BertUtil.java | 64 ++++++++++++-- .../javaapi/infer/bert/README.md | 87 ++++++++++++++++++- 5 files changed, 228 insertions(+), 24 deletions(-) create mode 100755 scala-package/examples/scripts/infer/bert/get_bert_data.sh create mode 100755 scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh diff --git a/scala-package/examples/scripts/infer/bert/get_bert_data.sh b/scala-package/examples/scripts/infer/bert/get_bert_data.sh new file mode 100755 index 000000000000..14cc78c4dce7 --- /dev/null +++ b/scala-package/examples/scripts/infer/bert/get_bert_data.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd) + +data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa/ + +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +if [ ! -f "$data_path" ]; then + curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json + curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params + curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json +fi diff --git a/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh b/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh new file mode 100755 index 000000000000..d8ba092c5c1b --- /dev/null +++ b/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +MXNET_ROOT=$(cd "$(dirname $0)/../../../../.."; pwd) + +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/* + +java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp $CLASS_PATH \ + org.apache.mxnetexamples.javaapi.infer.bert.BertQA $@ diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java index aee134af1fb7..a6c84865e1b5 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -43,8 +43,7 @@ public class BertQA { @Option(name = "--seq-length", usage = "the maximum length of the sequence") private int seqLength = 384; - - final static Logger logger = LoggerFactory.getLogger(BertQA.class); + private final static Logger logger = LoggerFactory.getLogger(BertQA.class); private static NDArray$ NDArray = NDArray$.MODULE$; private static int argmax(float[] prob) { @@ -55,12 +54,21 @@ private static int argmax(float[] prob) { return maxIdx; } - static void postProcessing(NDArray result, List tokens) { - NDArray output = NDArray.split( - NDArray.new splitParam(result, 2).setAxis(2))[0]; + /** + * Do the post processing on the output, apply softmax to get the probabilities + * reshape and get the most probable index + * @param result prediction result + * @param tokens word tokens + * @return Answers clipped from the original paragraph + */ + static List postProcessing(NDArray result, List tokens) { + NDArray[] output = NDArray.split( + NDArray.new splitParam(result, 2).setAxis(2)); // Get the formatted logits result - NDArray startLogits = output.at(0).reshape(new int[]{0, -3}); - NDArray endLogits = output.at(1).reshape(new int[]{0, -3}); + NDArray startLogits = NDArray.reshape( + NDArray.new reshapeParam(output[0]).setShape(new Shape(new int[]{0, -3})))[0]; + NDArray endLogits = NDArray.reshape( + NDArray.new reshapeParam(output[1]).setShape(new Shape(new int[]{0, -3})))[0]; // Get Probability distribution float[] startProb = NDArray.softmax( NDArray.new softmaxParam(startLogits))[0].toArray(); @@ -68,8 +76,7 @@ static void postProcessing(NDArray result, List tokens) { NDArray.new softmaxParam(endLogits))[0].toArray(); int startIdx = argmax(startProb); int endIdx = argmax(endProb); - String[] answer = (String[]) tokens.subList(startIdx, endIdx + 1).toArray(); - logger.info("Answer: ", Arrays.toString(answer)); + return tokens.subList(startIdx, endIdx + 1); } public static void main(String[] args) throws Exception{ @@ -78,13 +85,15 @@ public static void main(String[] args) throws Exception{ parser.parseArgument(args); BertUtil util = new BertUtil(); Context context = Context.cpu(); - logger.info("Question: ", inst.inputQ); - logger.info("Answer paragraph: ", inst.inputA); + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { + 
context = Context.gpu(); + } // pre-processing - tokenize sentence List tokenQ = util.tokenizer(inst.inputQ.toLowerCase()); List tokenA = util.tokenizer(inst.inputA.toLowerCase()); int validLength = tokenQ.size() + tokenA.size(); - logger.info("Valid length: ", validLength); + logger.info("Valid length: " + validLength); // generate token types [0000...1111....0000] List QAEmbedded = new ArrayList<>(); util.pad(QAEmbedded, 0f, tokenQ.size()).addAll( @@ -97,7 +106,7 @@ public static void main(String[] args) throws Exception{ tokenA.add("[SEP]"); tokenQ.addAll(tokenA); List tokens = util.pad(tokenQ, "[PAD]", inst.seqLength); - logger.info("Pre-processed tokens: ", Arrays.toString(tokens.toArray())); + logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray())); // pre-processing - token to index translation util.parseJSON(inst.modelVocab); List indexes = util.token2idx(tokens); @@ -129,6 +138,9 @@ public static void main(String[] args) throws Exception{ Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch); // Start prediction NDArray result = bertQA.predictWithNDArray(inputBatch).get(0); - postProcessing(result, tokens); + List answer = postProcessing(result, tokens); + logger.info("Question: " + inst.inputQ); + logger.info("Answer paragraph: " + inst.inputA); + logger.info("Answer: " + Arrays.toString(answer.toArray())); } } diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java index 4196401086e1..820d14b03cb7 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.mxnetexamples.javaapi.infer.bert; import java.io.FileReader; @@ -13,29 +30,41 @@ public class BertUtil { private Map token2idx; private List idx2token; + /** + * Parse the Vocabulary to JSON files + * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved token + * @param jsonFile the filePath of the vocab.json + * @throws Exception + */ void parseJSON(String jsonFile) throws Exception { Gson gson = new Gson(); token2idx = new HashMap<>(); idx2token = new LinkedList<>(); JsonObject jsonObject = gson.fromJson(new FileReader(jsonFile), JsonObject.class); - JsonArray arr = jsonObject.getAsJsonArray("token_to_idx"); + JsonArray arr = jsonObject.getAsJsonArray("idx_to_token"); for (JsonElement element : arr) { idx2token.add(element.getAsString()); } - JsonObject preMap = jsonObject.getAsJsonObject("idx_to_token"); + JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx"); for (String key : preMap.keySet()) { - token2idx.put(key, jsonObject.get(key).getAsInt()); + token2idx.put(key, preMap.get(key).getAsInt()); } } + /** + * Tokenize the input, split all kinds of spaces and + * saparate the end of sentence symbol: . , ? ! + * @param input The input String + * @return List of tokens + */ List tokenizer(String input) { String[] step1 = input.split("[\n\r\t ]+"); List finalResult = new LinkedList<>(); for (String item : step1) { if (item.length() != 0) { - if (item.split("[.,?!]+").length > 1) { + if ((item + "a").split("[.,?!]+").length > 1) { finalResult.add(item.substring(0, item.length() - 1)); - finalResult.add(item.substring(item.length() -1, item.length())); + finalResult.add(item.substring(item.length() -1)); } else { finalResult.add(item); } @@ -44,15 +73,27 @@ List tokenizer(String input) { return finalResult; } + /** + * Pad the tokens to the required length + * @param tokens input tokens + * @param padItem things to pad at the end + * @param num total length after padding + * @return List of padded tokens + */ List pad(List tokens, E padItem, int num) { if (tokens.size() >= num) return tokens; List padded = new LinkedList<>(tokens); for (int i = 0; i < num - tokens.size(); i++) { - tokens.add(padItem); + padded.add(padItem); } return padded; } + /** + * Convert tokens to indexes + * @param tokens input tokens + * @return List of indexes + */ List token2idx(List tokens) { List indexes = new ArrayList<>(); for (String token : tokens) { @@ -65,10 +106,15 @@ List token2idx(List tokens) { return indexes; } - List idx2token(List indexes) { + /** + * Convert indexes to tokens + * @param indexes List of indexes + * @return List of tokens + */ + List idx2token(List indexes) { List tokens = new ArrayList<>(); - for (E index : indexes) { - tokens.add(idx2token.get((int) index)); + for (int index : indexes) { + tokens.add(idx2token.get(index)); } return tokens; } diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md index 4fe4dfc9224d..49e1aa6d3b7d 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md @@ -15,4 +15,89 @@ -# BERT QA model using Java Inference API +# Run BERT QA model using Java Inference API + +In this tutorial, we will walk through the BERT QA model trained by MXNet. 
+You will be able to run inference with general Q & A task:
+
+```text
+Q: When did BBC Japan start broadcasting?
+```
+
+The model are expected to find the right answer in the corresponding text:
+```text
+BBC Japan was a general entertainment Channel. Which operated between December 2004 and April 2006.
+It ceased operations after its Japanese distributor folded.
+```
+And it picked up the right one:
+```text
+A: December 2004
+```
+
+## Setup Guide
+
+### Step 1: Download the model
+
+For this tutorial, you can get the model and vocabulary by running following bash file. This script will use `wget` to download these artifacts from AWS S3.
+
+From the `scala-package/examples/scripts/infer/bert/` folder run:
+
+```bash
+./get_bert_data.sh
+```
+
+**Note**: You may need to run `chmod +x get_bert_data.sh` before running this script.
+
+### Step 2: Setup data path of the model
+
+### Setup Datapath and Parameters
+
+The available arguments are as follows:
+
+| Argument | Comments |
+| ----------------------------- | ---------------------------------------- |
+| `--model-path-prefix` | Folder path with prefix to the model (including json, params). |
+| `--model-vocab` | Vocabulary path |
+| `--model-epoch` | Epoch number of the model |
+| `--input-question` | Question that asked to the model |
+| `--input-answer` | Paragraph that contains the answer |
+| `--seq-length` | Sequence Length of the model (384 by default) |
+
+### Step 3: Run Inference
+After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Infer API.
+
+From the `scala-package/examples/scripts/infer/bert/` folder run:
+
+```bash
+./run_bert_qa_example.sh --model-path-prefix ../models/static-bert-qa/static_bert_qa \
+    --model-vocab ../models/static-bert-qa/vocab.json \
+    --model-epoch 2
+```
+
+## Background
+
+To learn more about how BERT works in MXNet, please follow this [tutorial](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340).
+
+The model was extracted from the GluonNLP with static length settings.
+
+[Download link for the scrtipt](https://gluon-nlp.mxnet.io/_downloads/bert.zip)
+
+The original description can be found in [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1).
+```bash
+python static_finetune_squad.py --optimizer adam --accumulate 2 --batch_size 6 --lr 3e-5 --epochs 2 --gpu 0 --export
+
+```
+This script would generate a `json` and `param` which are the standard MXNet model files.
+By default, this model are using `bert_12_768_12` model with extra layers for QA jobs.
+
+After that, to be able to use it in Java, we need to export the dictionary from the script to parse the text
+to actual indexes. Please add the following lines after [this line](https://github.com/dmlc/gluon-nlp/blob/master/scripts/bert/staticbert/static_finetune_squad.py#L262).
+```python
+import json
+json_str = vocab.to_json()
+f = open("vocab.json", "w")
+f.write(json_str)
+f.close()
+```
+This would export a json file for you to deal with the vocabulary.
+Once you have these three files, you will be able to run this example without problems.
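To see how the pre-processing helpers introduced in this patch fit together before moving on to the integration test, here is a minimal usage sketch. It is illustrative only and not part of the patch series: it assumes a hypothetical driver class placed in the same `org.apache.mxnetexamples.javaapi.infer.bert` package (the `BertUtil` methods are package-private; the class is renamed to `BertDataParser` later in this series), and the vocabulary path and paragraph text are placeholders.

```java
package org.apache.mxnetexamples.javaapi.infer.bert;

import java.util.List;

// Hypothetical driver class, for illustration only.
public class BertPreprocessSketch {
    public static void main(String[] args) throws Exception {
        BertUtil util = new BertUtil();
        // Vocabulary exported from the GluonNLP fine-tuning script (placeholder path)
        util.parseJSON("/model/vocab.json");

        // Tokenize the lower-cased question and answer paragraph
        List<String> tokenQ = util.tokenizer("when did bbc japan start broadcasting?");
        List<String> tokenA = util.tokenizer("bbc japan operated between december 2004 and april 2006.");
        int validLength = tokenQ.size() + tokenA.size();

        // BERT QA input layout: [CLS] question [SEP] paragraph [SEP] [PAD] ... [PAD]
        tokenQ.add(0, "[CLS]");
        tokenQ.add("[SEP]");
        tokenA.add("[SEP]");
        tokenQ.addAll(tokenA);
        List<String> tokens = util.pad(tokenQ, "[PAD]", 384);

        // Translate tokens to vocabulary indices; words missing from the vocab map to [UNK]
        List<Integer> indexes = util.token2idx(tokens);
        System.out.println("valid length: " + validLength);
        System.out.println("first token ids: " + indexes.subList(0, 16));
    }
}
```

The token-type vector fed to the model follows the same layout: zeros over the question span, ones over the paragraph span, and zeros over the padding, which is what the `[0000...1111....0000]` comment in `BertQA` refers to.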
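The post-processing in `BertQA` reduces to a softmax over the start and end logits, an argmax on each, and clipping the answer from the token list between the two indices. Below is a self-contained sketch of that span selection using made-up logits and no MXNet types, mirroring the example's independent choice of start and end:

```java
import java.util.Arrays;
import java.util.List;

// Stand-alone illustration of the span selection; the logits below are invented.
public class SpanSelectionSketch {

    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float l : logits) max = Math.max(max, l);
        double sum = 0.0;
        double[] exp = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            exp[i] = Math.exp(logits[i] - max);  // subtract max for numerical stability
            sum += exp[i];
        }
        float[] prob = new float[logits.length];
        for (int i = 0; i < logits.length; i++) prob[i] = (float) (exp[i] / sum);
        return prob;
    }

    static int argmax(float[] prob) {
        int maxIdx = 0;
        for (int i = 1; i < prob.length; i++) {
            if (prob[i] > prob[maxIdx]) maxIdx = i;
        }
        return maxIdx;
    }

    public static void main(String[] args) {
        List<String> tokens = Arrays.asList("[CLS]", "when", "did", "bbc", "japan", "start",
                "broadcasting", "?", "[SEP]", "bbc", "japan", "operated", "between",
                "december", "2004", "and", "april", "2006", "[SEP]");
        float[] startLogits = {-3f, -2f, -1f, 0f, -2f, -1f, -2f, -3f, -4f, -1f, 0f, 1f, 2f, 6f, 3f, 1f, 0f, -1f, -5f};
        float[] endLogits   = {-4f, -3f, -2f, -1f, -2f, -2f, -3f, -3f, -4f, -2f, -1f, 0f, 1f, 3f, 6f, 1f, 2f, 1f, -5f};
        int start = argmax(softmax(startLogits));
        int end = argmax(softmax(endLogits));  // start and end are picked independently, as in the example
        System.out.println("Answer: " + tokens.subList(start, end + 1));
    }
}
```

With the invented numbers above this prints `Answer: [december, 2004]`, the same answer the README expects from the real model.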
From 6e2559258e87e1a14093857dbbabcd8941738560 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 1 Apr 2019 22:07:52 -0700 Subject: [PATCH 3/5] add integration test --- .../infer/predictor/BertExampleTest.java | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java new file mode 100644 index 000000000000..cb46ea9cc0fd --- /dev/null +++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnetexamples.javaapi.infer.predictor; + +import org.apache.mxnetexamples.Util; +import org.apache.mxnetexamples.javaapi.infer.bert.BertQA; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; + +public class BertExampleTest { + final static Logger logger = LoggerFactory.getLogger(BertExampleTest.class); + private static String modelPathPrefix = ""; + private static String vocabPath = ""; + + @BeforeClass + public static void downloadFile() { + logger.info("Downloading Bert QA Model"); + String tempDirPath = System.getProperty("java.io.tmpdir"); + logger.info("tempDirPath: %s".format(tempDirPath)); + + String baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA"; + Util.downloadUrl(baseUrl + "/static_bert_qa-symbol.json", + tempDirPath + "/static_bert_qa/static_bert_qa-symbol.json", 3); + Util.downloadUrl(baseUrl + "/static_bert_qa-0002.params", + tempDirPath + "/static_bert_qa/static_bert_qa-0002.params", 3); + Util.downloadUrl(baseUrl + "/vocab.json", + tempDirPath + "/static_bert_qa/vocab.json", 3); + modelPathPrefix = tempDirPath + File.separator + "static_bert_qa/static_bert_qa"; + vocabPath = tempDirPath + File.separator + "static_bert_qa/vocab.json"; + } + + @Test + public void testBertQA() throws Exception{ + BertQA bert = new BertQA(); + String Q = "When did BBC Japan start broadcasting?"; + String A = "BBC Japan was a general entertainment Channel.\n" + + " Which operated between December 2004 and April 2006.\n" + + "It ceased operations after its Japanese distributor folded."; + String[] args = new String[] { + "--model-path-prefix", modelPathPrefix, + "--model-vocab", vocabPath, + "--model-epoch", "2", + "--input-question", Q, + "--input-answer", A, + "--seq-length", "384" + }; + bert.main(args); + } +} From 
d053102a6548a9f834f4403881b166430c31c779 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 3 Apr 2019 11:27:38 -0700 Subject: [PATCH 4/5] address comments --- .../org/apache/mxnet/javaapi/Layout.scala | 34 +++++++++++++++ .../scripts/infer/bert/get_bert_data.sh | 3 -- .../{BertUtil.java => BertDataParser.java} | 2 +- .../javaapi/infer/bert/BertQA.java | 41 +++++++++---------- .../javaapi/infer/bert/README.md | 4 +- 5 files changed, 55 insertions(+), 29 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala rename scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/{BertUtil.java => BertDataParser.java} (99%) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala new file mode 100644 index 000000000000..cfe290c1aff7 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.javaapi + +/** + * Layout definition of DataDesc + * N Batch size + * C channels + * H Height + * W Weight + * T sequence length + * __undefined__ default value of Layout + */ +object Layout { + val UNDEFINED: String = org.apache.mxnet.Layout.UNDEFINED + val NCHW: String = org.apache.mxnet.Layout.NCHW + val NTC: String = org.apache.mxnet.Layout.NTC + val NT: String = org.apache.mxnet.Layout.NT + val N: String = org.apache.mxnet.Layout.N +} diff --git a/scala-package/examples/scripts/infer/bert/get_bert_data.sh b/scala-package/examples/scripts/infer/bert/get_bert_data.sh index 14cc78c4dce7..609aae27cc66 100755 --- a/scala-package/examples/scripts/infer/bert/get_bert_data.sh +++ b/scala-package/examples/scripts/infer/bert/get_bert_data.sh @@ -25,9 +25,6 @@ data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa/ if [ ! -d "$data_path" ]; then mkdir -p "$data_path" -fi - -if [ ! 
-f "$data_path" ]; then curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java similarity index 99% rename from scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java rename to scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java index 820d14b03cb7..86354477fcbb 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java @@ -25,7 +25,7 @@ import com.google.gson.JsonElement; import com.google.gson.JsonObject; -public class BertUtil { +public class BertDataParser { private Map token2idx; private List idx2token; diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java index a6c84865e1b5..3254faeb08d1 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -65,10 +65,8 @@ static List postProcessing(NDArray result, List tokens) { NDArray[] output = NDArray.split( NDArray.new splitParam(result, 2).setAxis(2)); // Get the formatted logits result - NDArray startLogits = NDArray.reshape( - NDArray.new reshapeParam(output[0]).setShape(new Shape(new int[]{0, -3})))[0]; - NDArray endLogits = NDArray.reshape( - NDArray.new reshapeParam(output[1]).setShape(new Shape(new int[]{0, -3})))[0]; + NDArray startLogits = output[0].reshape(new int[]{0, -3}); + NDArray endLogits = output[1].reshape(new int[]{0, -3}); // Get Probability distribution float[] startProb = NDArray.softmax( NDArray.new softmaxParam(startLogits))[0].toArray(); @@ -83,7 +81,7 @@ public static void main(String[] args) throws Exception{ BertQA inst = new BertQA(); CmdLineParser parser = new CmdLineParser(inst); parser.parseArgument(args); - BertUtil util = new BertUtil(); + BertDataParser util = new BertDataParser(); Context context = Context.cpu(); if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { @@ -115,26 +113,25 @@ public static void main(String[] args) throws Exception{ indexesFloat.add((float) integer); } // Preparing the input data - NDArray inputs = new NDArray(indexesFloat, - new Shape(new int[]{1, inst.seqLength}), context); - NDArray tokenTypesND = new NDArray(tokenTypes, - new Shape(new int[]{1, inst.seqLength}), context); - NDArray validLengthND = new NDArray(new float[] {(float) validLength}, - new Shape(new int[]{1}), context); - List inputBatch = new ArrayList<>(); - inputBatch.add(inputs); - inputBatch.add(tokenTypesND); - inputBatch.add(validLengthND); + List inputBatch = Arrays.asList( + new NDArray(indexesFloat, + new Shape(new int[]{1, inst.seqLength}), context), + new 
NDArray(tokenTypes, + new Shape(new int[]{1, inst.seqLength}), context), + new NDArray(new float[] { validLength }, + new Shape(new int[]{1}), context) + ); // Build the model - List inputDescs = new ArrayList<>(); List contexts = new ArrayList<>(); contexts.add(context); - inputDescs.add(new DataDesc("data0", - new Shape(new int[]{1, inst.seqLength}), DType.Float32(), "NT")); - inputDescs.add(new DataDesc("data1", - new Shape(new int[]{1, inst.seqLength}), DType.Float32(), "NT")); - inputDescs.add(new DataDesc("data2", - new Shape(new int[]{1}), DType.Float32(), "N")); + List inputDescs = Arrays.asList( + new DataDesc("data0", + new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), + new DataDesc("data1", + new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), + new DataDesc("data2", + new Shape(new int[]{1}), DType.Float32(), Layout.N()) + ); Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch); // Start prediction NDArray result = bertQA.predictWithNDArray(inputBatch).get(0); diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md index 49e1aa6d3b7d..4e65406e8541 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md @@ -46,8 +46,6 @@ From the `scala-package/examples/scripts/infer/bert/` folder run: ./get_bert_data.sh ``` -**Note**: You may need to run `chmod +x get_bert_data.sh` before running this script. - ### Step 2: Setup data path of the model ### Setup Datapath and Parameters @@ -80,7 +78,7 @@ To learn more about how BERT works in MXNet, please follow this [tutorial](https The model was extracted from the GluonNLP with static length settings. -[Download link for the scrtipt](https://gluon-nlp.mxnet.io/_downloads/bert.zip) +[Download link for the script](https://gluon-nlp.mxnet.io/_downloads/bert.zip) The original description can be found in [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1). 
```bash From 23797c1fe62c1d54ddcfc59bdc093dd9f8c8e019 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 3 Apr 2019 14:24:42 -0700 Subject: [PATCH 5/5] address doc comments --- .../javaapi/infer/bert/BertDataParser.java | 15 +++++++++----- .../javaapi/infer/bert/BertQA.java | 5 +++++ .../javaapi/infer/bert/README.md | 20 ++++++++++--------- .../infer/predictor/BertExampleTest.java | 3 +++ 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java index 86354477fcbb..440670afc098 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java @@ -25,6 +25,11 @@ import com.google.gson.JsonElement; import com.google.gson.JsonObject; +/** + * This is the Utility for pre-processing the data for Bert Model + * You can use this utility to parse Vocabulary JSON into Java Array and Dictionary, + * clean and tokenize sentences and pad the text + */ public class BertDataParser { private Map token2idx; @@ -32,7 +37,7 @@ public class BertDataParser { /** * Parse the Vocabulary to JSON files - * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved token + * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens * @param jsonFile the filePath of the vocab.json * @throws Exception */ @@ -52,13 +57,13 @@ void parseJSON(String jsonFile) throws Exception { } /** - * Tokenize the input, split all kinds of spaces and - * saparate the end of sentence symbol: . , ? ! - * @param input The input String + * Tokenize the input, split all kinds of whitespace and + * Separate the end of sentence symbol: . , ? ! 
+ * @param input The input string * @return List of tokens */ List tokenizer(String input) { - String[] step1 = input.split("[\n\r\t ]+"); + String[] step1 = input.split("\\s+"); List finalResult = new LinkedList<>(); for (String item : step1) { if (item.length() != 0) { diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java index 3254faeb08d1..b40a4e94afbd 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -26,6 +26,11 @@ import java.util.*; +/** + * This is an example of using BERT to do the general Question and Answer inference jobs + * Users can provide a question with a paragraph contains answer to the model and + * the model will be able to find the best answer from the answer paragraph + */ public class BertQA { @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") private String modelPathPrefix = "/model/static_bert_qa"; diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md index 4e65406e8541..7925a259f48f 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md @@ -17,16 +17,18 @@ # Run BERT QA model using Java Inference API -In this tutorial, we will walk through the BERT QA model trained by MXNet. -You will be able to run inference with general Q & A task: +In this tutorial, we will walk through the BERT QA model trained by MXNet. +Users can provide a question with a paragraph contains answer to the model and +the model will be able to find the best answer from the answer paragraph. +Example: ```text Q: When did BBC Japan start broadcasting? ``` -The model are expected to find the right answer in the corresponding text: +Answer paragraph ```text -BBC Japan was a general entertainment Channel. Which operated between December 2004 and April 2006. +BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006. It ceased operations after its Japanese distributor folded. ``` And it picked up the right one: @@ -74,18 +76,18 @@ From the `scala-package/examples/scripts/infer/bert/` folder run: ## Background -To learn more about how BERT works in MXNet, please follow this [tutorial](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340). +To learn more about how BERT works in MXNet, please follow this [MXNet Gluon tutorial on NLP using BERT](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340). -The model was extracted from the GluonNLP with static length settings. +The model was extracted from MXNet GluonNLP with static length settings. [Download link for the script](https://gluon-nlp.mxnet.io/_downloads/bert.zip) -The original description can be found in [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1). +The original description can be found in the [MXNet GluonNLP model zoo](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1). 
 ```bash
 python static_finetune_squad.py --optimizer adam --accumulate 2 --batch_size 6 --lr 3e-5 --epochs 2 --gpu 0 --export
 
 ```
-This script would generate a `json` and `param` which are the standard MXNet model files.
+This script will generate `json` and `param` files that are the standard MXNet model files.
 By default, this model are using `bert_12_768_12` model with extra layers for QA jobs.
 
 After that, to be able to use it in Java, we need to export the dictionary from the script to parse the text
 to actual indexes. Please add the following lines after [this line](https://github.com/dmlc/gluon-nlp/blob/master/scripts/bert/staticbert/static_finetune_squad.py#L262).
 ```python
 import json
 json_str = vocab.to_json()
 f = open("vocab.json", "w")
 f.write(json_str)
 f.close()
 ```
-This would export a json file for you to deal with the vocabulary.
+This would export the token vocabulary in json format.
 Once you have these three files, you will be able to run this example without problems.
diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java
index cb46ea9cc0fd..0518254c297d 100644
--- a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java
+++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java
@@ -26,6 +26,9 @@
 
 import java.io.File;
 
+/**
+ * Test on BERT QA model
+ */
 public class BertExampleTest {
     final static Logger logger = LoggerFactory.getLogger(BertExampleTest.class);
     private static String modelPathPrefix = "";