Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add BERT QA Scala/Java example #14592

Merged
merged 5 commits into from
Apr 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mxnet.javaapi

/**
* Layout definition of DataDesc, exposed for the Java API.
* Each letter in a layout string names one axis:
* N Batch size
* C Channels
* H Height
* W Width
* T Sequence length
* __undefined__ default value of Layout
*/
object Layout {
// Every constant below re-exports the same-named constant of org.apache.mxnet.Layout,
// so code written against the javaapi package does not have to reference the Scala object.
val UNDEFINED: String = org.apache.mxnet.Layout.UNDEFINED
val NCHW: String = org.apache.mxnet.Layout.NCHW
val NTC: String = org.apache.mxnet.Layout.NTC
val NT: String = org.apache.mxnet.Layout.NT
val N: String = org.apache.mxnet.Layout.N
}
5 changes: 5 additions & 0 deletions scala-package/examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -145,5 +145,10 @@
<artifactId>slf4j-simple</artifactId>
<version>1.7.5</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.5</version>
</dependency>
</dependencies>
</project>
31 changes: 31 additions & 0 deletions scala-package/examples/scripts/infer/bert/get_bert_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -e

# Directory three levels above the one containing this script.
MXNET_ROOT=$(cd "$(dirname "$0")/../../.."; pwd)

# Destination directory for the BERT QA vocabulary and model files.
data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa/
base_url=https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA

mkdir -p "$data_path"

# Check every file individually instead of only checking the directory, so a
# run that was interrupted half-way is completed on the next invocation.
for file in vocab.json static_bert_qa-0002.params static_bert_qa-symbol.json; do
    if [ ! -f "${data_path}${file}" ]; then
        # --fail makes curl exit non-zero on HTTP errors (and, with set -e,
        # abort the script) instead of saving the error page as model data.
        # Download to a temporary name so a partial file is never mistaken
        # for a complete one.
        curl --fail "$base_url/$file" -o "${data_path}${file}.part"
        mv "${data_path}${file}.part" "${data_path}${file}"
    fi
done
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -e

# Directory five levels above the one containing this script.
MXNET_ROOT=$(cd "$(dirname "$0")/../../../../.."; pwd)

# The '*' entries are left unexpanded on purpose: the JVM itself expands a
# trailing '/*' class path entry to every jar in that directory.
CLASS_PATH="$MXNET_ROOT/scala-package/assembly/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*"

# "$@" (quoted) forwards each argument unchanged, so values containing spaces
# such as --input-question "When did ...?" reach the program as one argument.
java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp "$CLASS_PATH" \
    org.apache.mxnetexamples.javaapi.infer.bert.BertQA "$@"
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

lanking520 marked this conversation as resolved.
Show resolved Hide resolved
package org.apache.mxnetexamples.javaapi.infer.bert;

import java.io.FileReader;
import java.util.*;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

/**
 * Utility for pre-processing the data for the BERT model.
 * It parses the vocabulary JSON into a token-to-index map and an index-to-token list,
 * tokenizes sentences, pads token lists and converts between tokens and indexes.
 */
public class BertDataParser {

    /** Sentence punctuation that the tokenizer separates from the end of a word. */
    private static final String SENTENCE_PUNCTUATION = ".,?!";

    private Map<String, Integer> token2idx;
    private List<String> idx2token;

    /**
     * Parse the vocabulary JSON file and load it into this parser.
     * The file must contain an "idx_to_token" array and a "token_to_idx" object.
     * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens
     * @param jsonFile the filePath of the vocab.json
     * @throws Exception if the file cannot be read or does not have the expected structure
     */
    void parseJSON(String jsonFile) throws Exception {
        JsonObject jsonObject;
        // try-with-resources closes the file even when parsing fails; the charset is
        // explicit because the vocabulary contains non-ASCII tokens.
        try (java.io.Reader reader = new java.io.InputStreamReader(
                new java.io.FileInputStream(jsonFile),
                java.nio.charset.StandardCharsets.UTF_8)) {
            jsonObject = new Gson().fromJson(reader, JsonObject.class);
        }
        if (jsonObject == null) {
            throw new IllegalArgumentException("Vocabulary file is empty: " + jsonFile);
        }
        JsonArray arr = jsonObject.getAsJsonArray("idx_to_token");
        JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx");
        if (arr == null || preMap == null) {
            throw new IllegalArgumentException(
                "Vocabulary file must contain \"idx_to_token\" and \"token_to_idx\": " + jsonFile);
        }
        // ArrayList gives constant-time get(index), which idx2token() relies on.
        List<String> indexToToken = new ArrayList<>(arr.size());
        for (JsonElement element : arr) {
            indexToToken.add(element.getAsString());
        }
        Map<String, Integer> tokenToIndex = new HashMap<>();
        for (Map.Entry<String, JsonElement> entry : preMap.entrySet()) {
            tokenToIndex.put(entry.getKey(), entry.getValue().getAsInt());
        }
        // Publish only after the whole file parsed, so a failure leaves the old state intact.
        idx2token = indexToToken;
        token2idx = tokenToIndex;
    }

    /**
     * Tokenize the input, split on all kinds of whitespace and
     * separate the end of sentence symbols: . , ? !
     * Only punctuation at the end of a word is separated (each symbol becomes its own
     * token); punctuation inside a word, as in "u.s.a" or "3.5", is left untouched.
     * @param input The input string
     * @return mutable List of tokens, never containing an empty token
     */
    List<String> tokenizer(String input) {
        List<String> finalResult = new ArrayList<>();
        for (String item : input.split("\\s+")) {
            if (item.isEmpty()) {
                continue;
            }
            // Find where the trailing run of punctuation starts.
            int wordEnd = item.length();
            while (wordEnd > 0 && SENTENCE_PUNCTUATION.indexOf(item.charAt(wordEnd - 1)) >= 0) {
                wordEnd--;
            }
            // A word made only of punctuation has no leading part to add.
            if (wordEnd > 0) {
                finalResult.add(item.substring(0, wordEnd));
            }
            for (int i = wordEnd; i < item.length(); i++) {
                finalResult.add(String.valueOf(item.charAt(i)));
            }
        }
        return finalResult;
    }

    /**
     * Pad the tokens to the required length
     * @param tokens input tokens
     * @param padItem things to pad at the end
     * @param num total length after padding
     * @return the input list itself when it already has at least {@code num} elements,
     *         otherwise a new list holding the tokens followed by the padding
     */
    <E> List<E> pad(List<E> tokens, E padItem, int num) {
        if (tokens.size() >= num) {
            return tokens;
        }
        List<E> padded = new ArrayList<>(num);
        padded.addAll(tokens);
        for (int i = tokens.size(); i < num; i++) {
            padded.add(padItem);
        }
        return padded;
    }

    /**
     * Convert tokens to indexes; tokens missing from the vocabulary map to the
     * index of [UNK]
     * @param tokens input tokens
     * @return List of indexes
     * @throws IllegalStateException if the vocabulary has not been loaded with parseJSON
     */
    List<Integer> token2idx(List<String> tokens) {
        requireVocabulary();
        Integer unknownIndex = token2idx.get("[UNK]");
        List<Integer> indexes = new ArrayList<>(tokens.size());
        for (String token : tokens) {
            Integer index = token2idx.get(token);
            indexes.add(index != null ? index : unknownIndex);
        }
        return indexes;
    }

    /**
     * Convert indexes to tokens
     * @param indexes List of indexes
     * @return List of tokens
     * @throws IllegalStateException if the vocabulary has not been loaded with parseJSON
     */
    List<String> idx2token(List<Integer> indexes) {
        requireVocabulary();
        List<String> tokens = new ArrayList<>(indexes.size());
        for (int index : indexes) {
            tokens.add(idx2token.get(index));
        }
        return tokens;
    }

    /** Fail with a clear message instead of a bare NullPointerException. */
    private void requireVocabulary() {
        if (token2idx == null || idx2token == null) {
            throw new IllegalStateException("Vocabulary not loaded; call parseJSON first");
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

lanking520 marked this conversation as resolved.
Show resolved Hide resolved
package org.apache.mxnetexamples.javaapi.infer.bert;

import org.apache.mxnet.infer.javaapi.Predictor;
import org.apache.mxnet.javaapi.*;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

/**
* This is an example of using BERT to do the general Question and Answer inference jobs
* Users can provide a question with a paragraph contains answer to the model and
* the model will be able to find the best answer from the answer paragraph
*/
public class BertQA {
@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
private String modelPathPrefix = "/model/static_bert_qa";
@Option(name = "--model-epoch", usage = "Epoch number of the model")
private int epoch = 2;
@Option(name = "--model-vocab", usage = "the vocabulary used in the model")
private String modelVocab = "/model/vocab.json";
@Option(name = "--input-question", usage = "the input question")
private String inputQ = "When did BBC Japan start broadcasting?";
@Option(name = "--input-answer", usage = "the input answer")
private String inputA =
"BBC Japan was a general entertainment Channel.\n" +
" Which operated between December 2004 and April 2006.\n" +
"It ceased operations after its Japanese distributor folded.";
@Option(name = "--seq-length", usage = "the maximum length of the sequence")
private int seqLength = 384;

private final static Logger logger = LoggerFactory.getLogger(BertQA.class);
private static NDArray$ NDArray = NDArray$.MODULE$;

private static int argmax(float[] prob) {
int maxIdx = 0;
for (int i = 0; i < prob.length; i++) {
if (prob[maxIdx] < prob[i]) maxIdx = i;
}
return maxIdx;
}

/**
* Do the post processing on the output, apply softmax to get the probabilities
* reshape and get the most probable index
* @param result prediction result
* @param tokens word tokens
* @return Answers clipped from the original paragraph
*/
static List<String> postProcessing(NDArray result, List<String> tokens) {
NDArray[] output = NDArray.split(
NDArray.new splitParam(result, 2).setAxis(2));
// Get the formatted logits result
NDArray startLogits = output[0].reshape(new int[]{0, -3});
NDArray endLogits = output[1].reshape(new int[]{0, -3});
// Get Probability distribution
float[] startProb = NDArray.softmax(
NDArray.new softmaxParam(startLogits))[0].toArray();
float[] endProb = NDArray.softmax(
NDArray.new softmaxParam(endLogits))[0].toArray();
int startIdx = argmax(startProb);
int endIdx = argmax(endProb);
return tokens.subList(startIdx, endIdx + 1);
}

public static void main(String[] args) throws Exception{
BertQA inst = new BertQA();
CmdLineParser parser = new CmdLineParser(inst);
parser.parseArgument(args);
BertDataParser util = new BertDataParser();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit : util --> dataparser just a more meaningful variable name :)

Context context = Context.cpu();
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
context = Context.gpu();
}
// pre-processing - tokenize sentence
List<String> tokenQ = util.tokenizer(inst.inputQ.toLowerCase());
List<String> tokenA = util.tokenizer(inst.inputA.toLowerCase());
int validLength = tokenQ.size() + tokenA.size();
logger.info("Valid length: " + validLength);
// generate token types [0000...1111....0000]
List<Float> QAEmbedded = new ArrayList<>();
util.pad(QAEmbedded, 0f, tokenQ.size()).addAll(
util.pad(new ArrayList<Float>(), 1f, tokenA.size())
);
List<Float> tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength);
// make BERT pre-processing standard
tokenQ.add("[SEP]");
tokenQ.add(0, "[CLS]");
tokenA.add("[SEP]");
tokenQ.addAll(tokenA);
List<String> tokens = util.pad(tokenQ, "[PAD]", inst.seqLength);
logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray()));
// pre-processing - token to index translation
util.parseJSON(inst.modelVocab);
List<Integer> indexes = util.token2idx(tokens);
List<Float> indexesFloat = new ArrayList<>();
for (int integer : indexes) {
indexesFloat.add((float) integer);
}
// Preparing the input data
List<NDArray> inputBatch = Arrays.asList(
new NDArray(indexesFloat,
new Shape(new int[]{1, inst.seqLength}), context),
new NDArray(tokenTypes,
new Shape(new int[]{1, inst.seqLength}), context),
new NDArray(new float[] { validLength },
new Shape(new int[]{1}), context)
);
// Build the model
List<Context> contexts = new ArrayList<>();
contexts.add(context);
List<DataDesc> inputDescs = Arrays.asList(
new DataDesc("data0",
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
new DataDesc("data1",
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
new DataDesc("data2",
new Shape(new int[]{1}), DType.Float32(), Layout.N())
);
Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch);
// Start prediction
NDArray result = bertQA.predictWithNDArray(inputBatch).get(0);
List<String> answer = postProcessing(result, tokens);
logger.info("Question: " + inst.inputQ);
logger.info("Answer paragraph: " + inst.inputA);
logger.info("Answer: " + Arrays.toString(answer.toArray()));
}
}
Loading