From 6b800e4e2aab3ec0af43fc7f8362d7795f903ff2 Mon Sep 17 00:00:00 2001 From: Lanking Date: Fri, 5 Apr 2019 09:26:25 -0700 Subject: [PATCH] Add BERT QA Scala/Java example (#14592) * add BertQA major code piece * add scripts and bug fixes * add integration test * address comments * address doc comments --- .../org/apache/mxnet/javaapi/Layout.scala | 34 ++++ scala-package/examples/pom.xml | 5 + .../scripts/infer/bert/get_bert_data.sh | 31 ++++ .../scripts/infer/bert/run_bert_qa_example.sh | 27 ++++ .../javaapi/infer/bert/BertDataParser.java | 126 +++++++++++++++ .../javaapi/infer/bert/BertQA.java | 148 ++++++++++++++++++ .../javaapi/infer/bert/README.md | 103 ++++++++++++ .../infer/predictor/BertExampleTest.java | 71 +++++++++ 8 files changed, 545 insertions(+) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala create mode 100755 scala-package/examples/scripts/infer/bert/get_bert_data.sh create mode 100755 scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh create mode 100644 scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java create mode 100644 scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java create mode 100644 scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md create mode 100644 scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala new file mode 100644 index 000000000000..cfe290c1aff7 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
package org.apache.mxnet.javaapi

/**
  * Layout strings used when describing data with a DataDesc.
  *
  * Every value below forwards the constant of the same name defined on
  * org.apache.mxnet.Layout.
  *
  * Axis letters that appear in the layout names:
  *   N - batch size
  *   C - channels
  *   H - height
  *   W - width
  *   T - sequence length
  *
  * UNDEFINED is the default value of a layout.
  */
object Layout {
  // Local alias for the core object; this object has the same simple name,
  // so the alias avoids spelling out the full path for every constant.
  import org.apache.mxnet.{Layout => CoreLayout}

  val UNDEFINED: String = CoreLayout.UNDEFINED
  val NCHW: String = CoreLayout.NCHW
  val NTC: String = CoreLayout.NTC
  val NT: String = CoreLayout.NT
  val N: String = CoreLayout.N
}
#!/bin/bash

# Downloads the static BERT QA model (symbol and params files) and its
# vocabulary into scripts/infer/models/static-bert-qa under the directory
# three levels above this script.

set -e

# Three levels up from this script's directory, as an absolute path.
# "$0" is quoted so a checkout path containing spaces still resolves.
MXNET_ROOT=$(cd "$(dirname "$0")/../../.."; pwd)

data_path="$MXNET_ROOT/scripts/infer/models/static-bert-qa"
base_url="https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA"

mkdir -p "$data_path"

# fetch NAME
# Downloads $base_url/NAME into $data_path unless the file already exists.
# The check is per file (not per directory) so a run that was interrupted
# half-way is completed on the next invocation.
fetch() {
  local name="$1"
  local target="$data_path/$name"
  if [ ! -f "$target" ]; then
    # --fail makes curl exit non-zero on an HTTP error instead of saving the
    # error page as if it were the model file; with `set -e` the script stops.
    curl --fail "$base_url/$name" -o "$target.part"
    # Only a fully downloaded file gets the final name.
    mv "$target.part" "$target"
  fi
}

fetch vocab.json
fetch static_bert_qa-0002.params
fetch static_bert_qa-symbol.json
#!/bin/bash

# Runs the BertQA Java example. Every argument given to this script is
# forwarded unchanged to org.apache.mxnetexamples.javaapi.infer.bert.BertQA.

set -e

# Five levels up from this script's directory, as an absolute path.
# "$0" is quoted so a checkout path containing spaces still resolves.
MXNET_ROOT=$(cd "$(dirname "$0")/../../../../.."; pwd)

# The trailing /* entries are classpath wildcards meant for the JVM, so the
# value is always used quoted to keep the shell from expanding or splitting it.
CLASS_PATH="$MXNET_ROOT/scala-package/assembly/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*"

# "$@" (quoted) keeps each argument intact; unquoted $@ would split a value
# such as --input-question "When did BBC Japan start broadcasting?" into
# several words.
java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp "$CLASS_PATH" \
	org.apache.mxnetexamples.javaapi.infer.bert.BertQA "$@"
+ */ + +package org.apache.mxnetexamples.javaapi.infer.bert; + +import java.io.FileReader; +import java.util.*; + +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; + +/** + * This is the Utility for pre-processing the data for Bert Model + * You can use this utility to parse Vocabulary JSON into Java Array and Dictionary, + * clean and tokenize sentences and pad the text + */ +public class BertDataParser { + + private Map token2idx; + private List idx2token; + + /** + * Parse the Vocabulary to JSON files + * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens + * @param jsonFile the filePath of the vocab.json + * @throws Exception + */ + void parseJSON(String jsonFile) throws Exception { + Gson gson = new Gson(); + token2idx = new HashMap<>(); + idx2token = new LinkedList<>(); + JsonObject jsonObject = gson.fromJson(new FileReader(jsonFile), JsonObject.class); + JsonArray arr = jsonObject.getAsJsonArray("idx_to_token"); + for (JsonElement element : arr) { + idx2token.add(element.getAsString()); + } + JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx"); + for (String key : preMap.keySet()) { + token2idx.put(key, preMap.get(key).getAsInt()); + } + } + + /** + * Tokenize the input, split all kinds of whitespace and + * Separate the end of sentence symbol: . , ? ! 
+ * @param input The input string + * @return List of tokens + */ + List tokenizer(String input) { + String[] step1 = input.split("\\s+"); + List finalResult = new LinkedList<>(); + for (String item : step1) { + if (item.length() != 0) { + if ((item + "a").split("[.,?!]+").length > 1) { + finalResult.add(item.substring(0, item.length() - 1)); + finalResult.add(item.substring(item.length() -1)); + } else { + finalResult.add(item); + } + } + } + return finalResult; + } + + /** + * Pad the tokens to the required length + * @param tokens input tokens + * @param padItem things to pad at the end + * @param num total length after padding + * @return List of padded tokens + */ + List pad(List tokens, E padItem, int num) { + if (tokens.size() >= num) return tokens; + List padded = new LinkedList<>(tokens); + for (int i = 0; i < num - tokens.size(); i++) { + padded.add(padItem); + } + return padded; + } + + /** + * Convert tokens to indexes + * @param tokens input tokens + * @return List of indexes + */ + List token2idx(List tokens) { + List indexes = new ArrayList<>(); + for (String token : tokens) { + if (token2idx.containsKey(token)) { + indexes.add(token2idx.get(token)); + } else { + indexes.add(token2idx.get("[UNK]")); + } + } + return indexes; + } + + /** + * Convert indexes to tokens + * @param indexes List of indexes + * @return List of tokens + */ + List idx2token(List indexes) { + List tokens = new ArrayList<>(); + for (int index : indexes) { + tokens.add(idx2token.get(index)); + } + return tokens; + } +} diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java new file mode 100644 index 000000000000..b40a4e94afbd --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnetexamples.javaapi.infer.bert; + +import org.apache.mxnet.infer.javaapi.Predictor; +import org.apache.mxnet.javaapi.*; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; + +/** + * This is an example of using BERT to do the general Question and Answer inference jobs + * Users can provide a question with a paragraph contains answer to the model and + * the model will be able to find the best answer from the answer paragraph + */ +public class BertQA { + @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") + private String modelPathPrefix = "/model/static_bert_qa"; + @Option(name = "--model-epoch", usage = "Epoch number of the model") + private int epoch = 2; + @Option(name = "--model-vocab", usage = "the vocabulary used in the model") + private String modelVocab = "/model/vocab.json"; + @Option(name = "--input-question", usage = "the input question") + private String inputQ = "When did BBC Japan start broadcasting?"; + @Option(name = "--input-answer", usage = "the input answer") + private String inputA = + "BBC Japan was a general entertainment Channel.\n" + + " Which 
operated between December 2004 and April 2006.\n" + + "It ceased operations after its Japanese distributor folded."; + @Option(name = "--seq-length", usage = "the maximum length of the sequence") + private int seqLength = 384; + + private final static Logger logger = LoggerFactory.getLogger(BertQA.class); + private static NDArray$ NDArray = NDArray$.MODULE$; + + private static int argmax(float[] prob) { + int maxIdx = 0; + for (int i = 0; i < prob.length; i++) { + if (prob[maxIdx] < prob[i]) maxIdx = i; + } + return maxIdx; + } + + /** + * Do the post processing on the output, apply softmax to get the probabilities + * reshape and get the most probable index + * @param result prediction result + * @param tokens word tokens + * @return Answers clipped from the original paragraph + */ + static List postProcessing(NDArray result, List tokens) { + NDArray[] output = NDArray.split( + NDArray.new splitParam(result, 2).setAxis(2)); + // Get the formatted logits result + NDArray startLogits = output[0].reshape(new int[]{0, -3}); + NDArray endLogits = output[1].reshape(new int[]{0, -3}); + // Get Probability distribution + float[] startProb = NDArray.softmax( + NDArray.new softmaxParam(startLogits))[0].toArray(); + float[] endProb = NDArray.softmax( + NDArray.new softmaxParam(endLogits))[0].toArray(); + int startIdx = argmax(startProb); + int endIdx = argmax(endProb); + return tokens.subList(startIdx, endIdx + 1); + } + + public static void main(String[] args) throws Exception{ + BertQA inst = new BertQA(); + CmdLineParser parser = new CmdLineParser(inst); + parser.parseArgument(args); + BertDataParser util = new BertDataParser(); + Context context = Context.cpu(); + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { + context = Context.gpu(); + } + // pre-processing - tokenize sentence + List tokenQ = util.tokenizer(inst.inputQ.toLowerCase()); + List tokenA = util.tokenizer(inst.inputA.toLowerCase()); + int 
validLength = tokenQ.size() + tokenA.size(); + logger.info("Valid length: " + validLength); + // generate token types [0000...1111....0000] + List QAEmbedded = new ArrayList<>(); + util.pad(QAEmbedded, 0f, tokenQ.size()).addAll( + util.pad(new ArrayList(), 1f, tokenA.size()) + ); + List tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength); + // make BERT pre-processing standard + tokenQ.add("[SEP]"); + tokenQ.add(0, "[CLS]"); + tokenA.add("[SEP]"); + tokenQ.addAll(tokenA); + List tokens = util.pad(tokenQ, "[PAD]", inst.seqLength); + logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray())); + // pre-processing - token to index translation + util.parseJSON(inst.modelVocab); + List indexes = util.token2idx(tokens); + List indexesFloat = new ArrayList<>(); + for (int integer : indexes) { + indexesFloat.add((float) integer); + } + // Preparing the input data + List inputBatch = Arrays.asList( + new NDArray(indexesFloat, + new Shape(new int[]{1, inst.seqLength}), context), + new NDArray(tokenTypes, + new Shape(new int[]{1, inst.seqLength}), context), + new NDArray(new float[] { validLength }, + new Shape(new int[]{1}), context) + ); + // Build the model + List contexts = new ArrayList<>(); + contexts.add(context); + List inputDescs = Arrays.asList( + new DataDesc("data0", + new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), + new DataDesc("data1", + new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), + new DataDesc("data2", + new Shape(new int[]{1}), DType.Float32(), Layout.N()) + ); + Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch); + // Start prediction + NDArray result = bertQA.predictWithNDArray(inputBatch).get(0); + List answer = postProcessing(result, tokens); + logger.info("Question: " + inst.inputQ); + logger.info("Answer paragraph: " + inst.inputA); + logger.info("Answer: " + Arrays.toString(answer.toArray())); + } +} diff --git 
a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md new file mode 100644 index 000000000000..7925a259f48f --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md @@ -0,0 +1,103 @@ + + + + + + + + + + + + + + + + + +# Run BERT QA model using Java Inference API + +In this tutorial, we will walk through the BERT QA model trained by MXNet. +Users can provide the model with a question and a paragraph that contains the answer, and +the model will be able to find the best answer from the answer paragraph. + +Example: +```text +Q: When did BBC Japan start broadcasting? +``` + +Answer paragraph +```text +BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006. +It ceased operations after its Japanese distributor folded. +``` +And it picked up the right one: +```text +A: December 2004 +``` + +## Setup Guide + +### Step 1: Download the model + +For this tutorial, you can get the model and vocabulary by running the following bash script. This script uses `curl` to download these artifacts from AWS S3. + +From the `scala-package/examples/scripts/infer/bert/` folder run: + +```bash +./get_bert_data.sh +``` + +### Step 2: Setup data path of the model + +### Setup Datapath and Parameters + +The available arguments are as follows: + +| Argument | Comments | +| ----------------------------- | ---------------------------------------- | +| `--model-path-prefix` | Folder path with prefix to the model (including json, params). 
| +| `--model-vocab` | Vocabulary path | +| `--model-epoch` | Epoch number of the model | +| `--input-question` | Question to ask the model | +| `--input-answer` | Paragraph that contains the answer | +| `--seq-length` | Sequence Length of the model (384 by default) | + +### Step 3: Run Inference +After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Infer API. + +From the `scala-package/examples/scripts/infer/bert/` folder run: + +```bash +./run_bert_qa_example.sh --model-path-prefix ../models/static-bert-qa/static_bert_qa \ + --model-vocab ../models/static-bert-qa/vocab.json \ + --model-epoch 2 +``` + +## Background + +To learn more about how BERT works in MXNet, please follow this [MXNet Gluon tutorial on NLP using BERT](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340). + +The model was extracted from MXNet GluonNLP with static length settings. + +[Download link for the script](https://gluon-nlp.mxnet.io/_downloads/bert.zip) + +The original description can be found in the [MXNet GluonNLP model zoo](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1). +```bash +python static_finetune_squad.py --optimizer adam --accumulate 2 --batch_size 6 --lr 3e-5 --epochs 2 --gpu 0 --export + +``` +This script will generate the `json` and `params` files that are the standard MXNet model files. +By default, this model uses the `bert_12_768_12` model with extra layers for QA jobs. + +After that, to be able to use it in Java, we need to export the dictionary from the script to parse the text +to actual indexes. Please add the following lines after [this line](https://github.com/dmlc/gluon-nlp/blob/master/scripts/bert/staticbert/static_finetune_squad.py#L262). +```python +import json +json_str = vocab.to_json() +f = open("vocab.json", "w") +f.write(json_str) +f.close() +``` +This would export the token vocabulary in json format. 
+Once you have these three files, you will be able to run this example without problems. diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java new file mode 100644 index 000000000000..0518254c297d --- /dev/null +++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnetexamples.javaapi.infer.predictor; + +import org.apache.mxnetexamples.Util; +import org.apache.mxnetexamples.javaapi.infer.bert.BertQA; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; + +/** + * Test on BERT QA model + */ +public class BertExampleTest { + final static Logger logger = LoggerFactory.getLogger(BertExampleTest.class); + private static String modelPathPrefix = ""; + private static String vocabPath = ""; + + @BeforeClass + public static void downloadFile() { + logger.info("Downloading Bert QA Model"); + String tempDirPath = System.getProperty("java.io.tmpdir"); + logger.info("tempDirPath: %s".format(tempDirPath)); + + String baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA"; + Util.downloadUrl(baseUrl + "/static_bert_qa-symbol.json", + tempDirPath + "/static_bert_qa/static_bert_qa-symbol.json", 3); + Util.downloadUrl(baseUrl + "/static_bert_qa-0002.params", + tempDirPath + "/static_bert_qa/static_bert_qa-0002.params", 3); + Util.downloadUrl(baseUrl + "/vocab.json", + tempDirPath + "/static_bert_qa/vocab.json", 3); + modelPathPrefix = tempDirPath + File.separator + "static_bert_qa/static_bert_qa"; + vocabPath = tempDirPath + File.separator + "static_bert_qa/vocab.json"; + } + + @Test + public void testBertQA() throws Exception{ + BertQA bert = new BertQA(); + String Q = "When did BBC Japan start broadcasting?"; + String A = "BBC Japan was a general entertainment Channel.\n" + + " Which operated between December 2004 and April 2006.\n" + + "It ceased operations after its Japanese distributor folded."; + String[] args = new String[] { + "--model-path-prefix", modelPathPrefix, + "--model-vocab", vocabPath, + "--model-epoch", "2", + "--input-question", Q, + "--input-answer", A, + "--seq-length", "384" + }; + bert.main(args); + } +}