diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala new file mode 100644 index 000000000000..cfe290c1aff7 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.javaapi + +/** + * Layout definition of DataDesc + * N Batch size + * C channels + * H Height + * W Weight + * T sequence length + * __undefined__ default value of Layout + */ +object Layout { + val UNDEFINED: String = org.apache.mxnet.Layout.UNDEFINED + val NCHW: String = org.apache.mxnet.Layout.NCHW + val NTC: String = org.apache.mxnet.Layout.NTC + val NT: String = org.apache.mxnet.Layout.NT + val N: String = org.apache.mxnet.Layout.N +} diff --git a/scala-package/examples/scripts/infer/bert/get_bert_data.sh b/scala-package/examples/scripts/infer/bert/get_bert_data.sh index 14cc78c4dce7..609aae27cc66 100755 --- a/scala-package/examples/scripts/infer/bert/get_bert_data.sh +++ b/scala-package/examples/scripts/infer/bert/get_bert_data.sh @@ -25,9 +25,6 @@ data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa/ if [ ! -d "$data_path" ]; then mkdir -p "$data_path" -fi - -if [ ! -f "$data_path" ]; then curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java similarity index 99% rename from scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java rename to scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java index 820d14b03cb7..86354477fcbb 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertUtil.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java @@ -25,7 +25,7 @@ import com.google.gson.JsonElement; import com.google.gson.JsonObject; -public class BertUtil { +public class BertDataParser { private Map token2idx; private List idx2token; diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java index a6c84865e1b5..3254faeb08d1 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -65,10 +65,8 @@ static List postProcessing(NDArray result, List tokens) { NDArray[] output = NDArray.split( NDArray.new splitParam(result, 2).setAxis(2)); // Get the formatted logits result - NDArray startLogits = NDArray.reshape( - NDArray.new reshapeParam(output[0]).setShape(new Shape(new int[]{0, -3})))[0]; - NDArray endLogits = NDArray.reshape( - NDArray.new reshapeParam(output[1]).setShape(new Shape(new int[]{0, -3})))[0]; + NDArray startLogits = output[0].reshape(new int[]{0, -3}); + NDArray endLogits = output[1].reshape(new int[]{0, -3}); // Get Probability distribution float[] startProb = NDArray.softmax( NDArray.new softmaxParam(startLogits))[0].toArray(); @@ -83,7 +81,7 @@ public static void main(String[] args) throws Exception{ BertQA inst = new BertQA(); CmdLineParser parser = new CmdLineParser(inst); parser.parseArgument(args); - BertUtil util = new BertUtil(); + BertDataParser util = new BertDataParser(); Context context = Context.cpu(); if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { @@ -115,26 +113,25 @@ public static void main(String[] args) throws Exception{ indexesFloat.add((float) integer); } // Preparing the input data - NDArray inputs = new NDArray(indexesFloat, - new Shape(new int[]{1, inst.seqLength}), context); - NDArray tokenTypesND = new NDArray(tokenTypes, - new Shape(new int[]{1, inst.seqLength}), context); - NDArray validLengthND = new NDArray(new float[] {(float) validLength}, - new Shape(new int[]{1}), context); - List inputBatch = new ArrayList<>(); - inputBatch.add(inputs); - inputBatch.add(tokenTypesND); - inputBatch.add(validLengthND); + List inputBatch = Arrays.asList( + new NDArray(indexesFloat, + new Shape(new int[]{1, inst.seqLength}), context), + new NDArray(tokenTypes, + new Shape(new int[]{1, inst.seqLength}), context), + new NDArray(new float[] { validLength }, + new Shape(new int[]{1}), context) + ); // Build the model - List inputDescs = new ArrayList<>(); List contexts = new ArrayList<>(); contexts.add(context); - inputDescs.add(new DataDesc("data0", - new Shape(new int[]{1, inst.seqLength}), DType.Float32(), "NT")); - inputDescs.add(new DataDesc("data1", - new Shape(new int[]{1, inst.seqLength}), DType.Float32(), "NT")); - inputDescs.add(new DataDesc("data2", - new Shape(new int[]{1}), DType.Float32(), "N")); + List inputDescs = Arrays.asList( + new DataDesc("data0", + new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), + new DataDesc("data1", + new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), + new DataDesc("data2", + new Shape(new int[]{1}), DType.Float32(), Layout.N()) + ); Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch); // Start prediction NDArray result = bertQA.predictWithNDArray(inputBatch).get(0); diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md index 49e1aa6d3b7d..4e65406e8541 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md @@ -46,8 +46,6 @@ From the `scala-package/examples/scripts/infer/bert/` folder run: ./get_bert_data.sh ``` -**Note**: You may need to run `chmod +x get_bert_data.sh` before running this script. - ### Step 2: Setup data path of the model ### Setup Datapath and Parameters @@ -80,7 +78,7 @@ To learn more about how BERT works in MXNet, please follow this [tutorial](https The model was extracted from the GluonNLP with static length settings. -[Download link for the scrtipt](https://gluon-nlp.mxnet.io/_downloads/bert.zip) +[Download link for the script](https://gluon-nlp.mxnet.io/_downloads/bert.zip) The original description can be found in [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1). ```bash