From 67e1fe8dcae5a74c0e32b786a16ca6baaea98048 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Wed, 12 Oct 2022 17:17:23 -0700 Subject: [PATCH 1/6] Adds a Java RunInference example --- examples/java/build.gradle | 1 + .../SklearnMnistClassification.java | 155 ++++++++++++++++++ examples/multi-language/README.md | 126 +++++++++++++- examples/multi-language/build.gradle | 10 +- .../PythonDataframeWordCount.java | 1 - .../examples/generate-sources.sh | 10 ++ .../resources/archetype-resources/pom.xml | 7 + 7 files changed, 304 insertions(+), 6 deletions(-) create mode 100644 examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java diff --git a/examples/java/build.gradle b/examples/java/build.gradle index 2a56e27b6d34..13b2518bf382 100644 --- a/examples/java/build.gradle +++ b/examples/java/build.gradle @@ -57,6 +57,7 @@ dependencies { implementation library.java.kafka_clients implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(":sdks:java:extensions:google-cloud-platform-core") + implementation project(":sdks:java:extensions:python") implementation project(":sdks:java:io:google-cloud-platform") implementation project(":sdks:java:io:kafka") implementation project(":sdks:java:extensions:ml") diff --git a/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java b/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java new file mode 100644 index 000000000000..c2dddd8cf515 --- /dev/null +++ b/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.examples.multilanguage; + +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.extensions.python.transforms.RunInference; +import org.apache.beam.sdk.io.FileIO; +import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation.Required; +import org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.transforms.Convert; +import org.apache.beam.sdk.transforms.DoFn.Element; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.transforms.Values; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter; + +/** + * 
An example Java MUlti-language pipeline that Performs image classification on handwritten digits + * from the MNIST database. + * + *

For more details and instructions for running this please see here. + */ +public class SklearnMnistClassification { + + private String getModelLoaderScript() { + String s = "from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy\n"; + s = s + "from apache_beam.ml.inference.base import KeyedModelHandler\n"; + s = s + "def get_model_handler(model_uri):\n"; + s = s + " return KeyedModelHandler(SklearnModelHandlerNumpy(model_uri))\n"; + + return s; + } + + static class FilterFn implements SerializableFunction { + + @Override + public Boolean apply(String input) { + return !input.startsWith("label"); + } + } + + static class KVFn extends SimpleFunction>> { + + @Override + public KV> apply(String input) { + String[] data = Splitter.on(',').splitToList(input).toArray(new String[]{}); + Long label = Long.valueOf(data[0]); + List pixels = new ArrayList(); + for (int i = 1; i < data.length; i++) { + pixels.add(Long.valueOf(data[i])); + } + + return KV.of(label, pixels); + } + } + + static class FormatOutput extends SimpleFunction, String> { + + @Override + public String apply(KV input) { + return input.getKey() + " was mapped to " + + input.getValue().getString("inference"); + } + } + + + void runExample(SklearnMnistClassificationOptions options, String expansionService) { + Schema schema = + Schema.of( + Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)), + Schema.Field.of("inference", FieldType.STRING)); + + Pipeline pipeline = Pipeline.create(options); + PCollection>> col = + pipeline + .apply(TextIO.read().from(options.getInput())) + .apply(Filter.by(new FilterFn())) + .apply(MapElements.via(new KVFn())); + col.apply( + RunInference.ofKVs(getModelLoaderScript(), schema, VarLongCoder.of()) + .withKwarg("model_uri", options.getModelPath()) + .withExpansionService(expansionService)) + .apply(MapElements.via(new FormatOutput())).apply(TextIO.write().to(options.getOutput())); + + pipeline.run().waitUntilFinish(); + } + + public 
interface SklearnMnistClassificationOptions extends PipelineOptions { + + @Description("Path to an input file that contains labels and pixels to feed into the model") + @Default.String("gs://apache-beam-samples/multi-language/mnist/example_input.csv") + String getInput(); + + void setInput(String value); + + @Description("Path for storing the output") + @Required + String getOutput(); + + void setOutput(String value); + + @Description( + "Path to a model file that contains the pickled file of a scikit-learn model trained on MNIST data") + @Default.String("gs://apache-beam-samples/multi-language/mnist/example_model") + String getModelPath(); + + void setModelPath(String value); + + /** Set this option to specify Python expansion service URL. */ + @Description("URL of Python expansion service") + @Default.String("") + String getExpansionService(); + + void setExpansionService(String value); + } + + public static void main(String[] args) { + SklearnMnistClassificationOptions options = + PipelineOptionsFactory.fromArgs(args).as(SklearnMnistClassificationOptions.class); + SklearnMnistClassification example = new SklearnMnistClassification(); + example.runExample(options, options.getExpansionService()); + } +} diff --git a/examples/multi-language/README.md b/examples/multi-language/README.md index e181c602e342..54e40a2099c3 100644 --- a/examples/multi-language/README.md +++ b/examples/multi-language/README.md @@ -22,29 +22,147 @@ This project provides examples of Apache Beam [multi-language pipelines](https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines): +## Using Java transforms from Python + * **python/addprefix** - A Python pipeline that reads a text file and attaches a prefix on the Java side to each input. * **python/javacount** - A Python pipeline that counts words using the Java `Count.perElement()` transform. * **python/javadatagenerator** - A Python pipeline that produces a set of strings generated from Java. 
This example demonstrates the `JavaExternalTransform` API. -## Instructions for running the pipelines +### Instructions for running the pipelines -### 1) Start the expansion service +#### 1) Start the expansion service 1. Download the latest 'beam-examples-multi-language' JAR. Starting with Apache Beam 2.36.0, you can find it in [the Maven Central Repository](https://search.maven.org/search?q=g:org.apache.beam). 2. Run the following command, replacing `` and `` with valid values: `java -jar beam-examples-multi-language-.jar --javaClassLookupAllowlistFile='*'` -### 2) Set up a Python virtual environment for Beam +#### 2) Set up a Python virtual environment for Beam 1. See [the Python quickstart](https://beam.apache.org/get-started/quickstart-py/) for more information. -### 3) Execute the Python pipeline +#### 3) Execute the Python pipeline 1. In a new shell, run a pipeline in the **python** directory using a Beam runner that supports multi-language pipelines. The Python files contain details about the actual commands to run. +## Using Python transforms from Java + +### Sklearn Mnist Classification + +Performs image classification on handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) +database. + +Please see [here](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference) for +context and information regarding the corresponding Python pipeline. + +Please note that the Java pipeline is +[available in the Beam Java examples module](https://github.com/apache/beam/tree/master/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java). + +#### Setup + +* Obtain/generate a csv input file that contains labels and pixels to feed into the model and store it in +GCS. And example input is available [here](TODO). + +* Create a model file that contains the pickled file of a scikit-learn model +trained on MNIST data and store it in GCS.
An example model file is available [here](TODO). + +* Perform Beam runner specific setup according to instructions +[here](https://beam.apache.org/get-started/quickstart-java/#run-a-pipeline). + +Following instructions are for running the pipeline with the Dataflow runner. For other portable runners, +please modify the instructions according to the guidelines +[here](https://beam.apache.org/documentation/sdks/java-multi-language-pipelines/#run-with-directrunner) + +#### Instructions for running the Java pipeline on released Beam (Beam 2.43.0 and later). + +* Checkout the Beam examples Maven archetype for the relevant Beam version. + +``` +export BEAM_VERSION= + +mvn archetype:generate \ + -DarchetypeGroupId=org.apache.beam \ + -DarchetypeArtifactId=beam-sdks-java-maven-archetypes-examples \ + -DarchetypeVersion=$BEAM_VERSION \ + -DgroupId=org.example \ + -DartifactId=multi-language-beam \ + -Dversion="0.1" \ + -Dpackage=org.apache.beam.examples \ + -DinteractiveMode=false +``` + +* Run the pipeline. + +``` +export GCP_PROJECT= +export GCP_BUCKET= +export GCP_REGION= + +mvn compile exec:java -Dexec.mainClass=org.apache.beam.examples.multilanguage.SklearnMnistClassification \ + -Dexec.args="--runner=DataflowRunner --project=$GCP_PROJECT \ + --region=us-central1 \ + --gcpTempLocation=gs://$GCP_BUCKET/multi-language-beam/tmp \ + --output=gs://$GCP_BUCKET/multi-language-beam/output" \ + -Pdataflow-runner +``` + +* Inspect the output. Each line has data separated by a comma ",". The first item is the actual label of +the digit. The second item is the predicted label of the digit. + +``` +gsutil cat gs://$GCP_BUCKET/multi-language-beam/output* +``` + +#### Instructions for running the Java pipeline at HEAD (Beam 2.41.0 and 2.42.0). + +* Make sure that Docker is installed and available on your system. + +* Build and push Python and Java Docker containers. 
+ +``` +export DOCKER_ROOT= + +./gradlew :sdks:python:container:py38:docker -Pdocker-repository-root=$DOCKER_ROOT -Pdocker-tag=latest + +docker push $DOCKER_ROOT/beam_python3.8_sdk:latest + +./gradlew :sdks:java:container:java11:docker -Pdocker-repository-root=$DOCKER_ROOT -Pdocker-tag=latest + +docker push $DOCKER_ROOT/beam_java11_sdk:latest +``` + +* Run the pipeline using the following Gradle command (this guide assumes Dataflow runner). +Note that we override both the Java and Python SDK harness containers here. + +``` +export GCP_PROJECT= +export GCP_BUCKET= +export GCP_REGION= + +./gradlew :examples:multi-language:sklearnMinstClassification --args=" \ +--runner=DataflowRunner \ +--project=$GCP_PROJECT \ +--gcpTempLocation=gs://$GCP_BUCKET/multi-language-beam/tmp \ +--output=gs://$GCP_BUCKET/multi-language-beam/output \ +--sdkContainerImage=$DOCKER_ROOT/beam_java11_sdk:latest \ +--sdkHarnessContainerImageOverrides=.*python.*,$DOCKER_ROOT/beam_python3.8_sdk:latest \ +--region=${GCP_REGION}" +``` + +* Inspect the output. Each line has data separated by a comma ",". The first item is the actual label +of the digit. The second item is the predicted label of the digit. 
+ +``` +gsutil cat gs://$GCP_BUCKET/multi-language-beam/output* +``` + + + + + + diff --git a/examples/multi-language/build.gradle b/examples/multi-language/build.gradle index e53505b1ea70..03a662038307 100644 --- a/examples/multi-language/build.gradle +++ b/examples/multi-language/build.gradle @@ -33,6 +33,7 @@ ext.summary = "Java Classes for Multi-language Examples" dependencies { implementation library.java.vendored_guava_26_0_jre + implementation project(":examples:java") implementation project(path: ":sdks:java:core", configuration: "shadow") runtimeOnly project(path: ":runners:direct-java", configuration: "shadow") runtimeOnly project(path: ":runners:google-cloud-dataflow-java") @@ -47,4 +48,11 @@ task pythonDataframeWordCount(type: JavaExec) { description "Run the Java word count example using external Python DataframeTransform" mainClass = "org.apache.beam.examples.multilanguage.PythonDataframeWordCount" classpath = sourceSets.main.runtimeClasspath -} \ No newline at end of file +} + +task sklearnMinstClassification(type: JavaExec) { + description "Run the Java pipeline that performns image classification on handwritten digits from the MNIST database" + mainClass = "org.apache.beam.examples.multilanguage.SklearnMnistClassification" + classpath = sourceSets.main.runtimeClasspath +} + diff --git a/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/PythonDataframeWordCount.java b/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/PythonDataframeWordCount.java index 6f037e7d8f82..dc90c5063dee 100644 --- a/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/PythonDataframeWordCount.java +++ b/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/PythonDataframeWordCount.java @@ -49,7 +49,6 @@ * ./gradlew :examples:multi-language:pythonDataframeWordCount --args=" \ * --runner=DataflowRunner \ * --output=gs://{$OUTPUT_BUCKET}/count \ - * --experiments=use_runner_v2 
\ * --sdkHarnessContainerImageOverrides=.*python.*,gcr.io/apache-beam-testing/beam-sdk/beam_python{$PYTHON_VERSION}_sdk:latest" * } */ diff --git a/sdks/java/maven-archetypes/examples/generate-sources.sh b/sdks/java/maven-archetypes/examples/generate-sources.sh index 5df4c1f77616..f22f0f23890f 100755 --- a/sdks/java/maven-archetypes/examples/generate-sources.sh +++ b/sdks/java/maven-archetypes/examples/generate-sources.sh @@ -70,6 +70,16 @@ rsync -a \ "${EXAMPLES_ROOT}"/src/test/java/org/apache/beam/examples/complete/game/ \ "${ARCHETYPE_ROOT}/src/test/java/complete/game" +# +# Copy the Java multi-language examples. +# + +mkdir -p "${ARCHETYPE_ROOT}/src/test/java/multilanguage/" + +rsync -a \ + "${EXAMPLES_ROOT}"/src/main/java/org/apache/beam/examples/multilanguage/ \ + "${ARCHETYPE_ROOT}/src/main/java/multilanguage" + # # Replace 'package org.apache.beam.examples' with 'package ${package}' in all Java code # diff --git a/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml b/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml index 6517f56a95d3..50515b812078 100644 --- a/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml +++ b/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml @@ -357,6 +357,13 @@ ${beam.version} + + + org.apache.beam + beam-sdks-java-extensions-python + ${beam.version} + + com.google.api-client From f88c65285f667820319578506cca8ad586a4a01d Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Wed, 12 Oct 2022 18:05:44 -0700 Subject: [PATCH 2/6] Fixes lint --- .../examples/multilanguage/SklearnMnistClassification.java | 6 ------ examples/multi-language/README.md | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java 
b/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java index c2dddd8cf515..37557a852308 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java +++ b/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java @@ -22,24 +22,18 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.extensions.python.transforms.RunInference; -import org.apache.beam.sdk.io.FileIO; import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.Validation.Required; -import org.apache.beam.sdk.schemas.JavaFieldSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.annotations.DefaultSchema; -import org.apache.beam.sdk.schemas.transforms.Convert; -import org.apache.beam.sdk.transforms.DoFn.Element; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SimpleFunction; -import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; diff --git a/examples/multi-language/README.md b/examples/multi-language/README.md index 54e40a2099c3..4b41594f9c4b 100644 --- a/examples/multi-language/README.md +++ b/examples/multi-language/README.md @@ -65,7 +65,7 @@ Please note that the Java pipeline is #### Setup -* Obtain/generate a csv input file that contains labels and pixels to feed into the model and store it in +* Obtain/generate a csv input file that contains labels 
and pixels to feed into the model and store it in GCS. And example input is available [here](TODO). * Create a model file that contains the pickled file of a scikit-learn model From 9a078e79a5b4b9aa6bd97c583945740c39d3fe56 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Wed, 12 Oct 2022 18:24:38 -0700 Subject: [PATCH 3/6] Fix spotless --- .../multilanguage/SklearnMnistClassification.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java b/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java index 37557a852308..cef6848ca8be 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java +++ b/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java @@ -69,7 +69,7 @@ static class KVFn extends SimpleFunction>> { @Override public KV> apply(String input) { - String[] data = Splitter.on(',').splitToList(input).toArray(new String[]{}); + String[] data = Splitter.on(',').splitToList(input).toArray(new String[] {}); Long label = Long.valueOf(data[0]); List pixels = new ArrayList(); for (int i = 1; i < data.length; i++) { @@ -84,12 +84,10 @@ static class FormatOutput extends SimpleFunction, String> { @Override public String apply(KV input) { - return input.getKey() + " was mapped to " + - input.getValue().getString("inference"); + return input.getKey() + " was mapped to " + input.getValue().getString("inference"); } } - void runExample(SklearnMnistClassificationOptions options, String expansionService) { Schema schema = Schema.of( @@ -106,7 +104,8 @@ void runExample(SklearnMnistClassificationOptions options, String expansionServi RunInference.ofKVs(getModelLoaderScript(), schema, VarLongCoder.of()) .withKwarg("model_uri", options.getModelPath()) .withExpansionService(expansionService)) - 
.apply(MapElements.via(new FormatOutput())).apply(TextIO.write().to(options.getOutput())); + .apply(MapElements.via(new FormatOutput())) + .apply(TextIO.write().to(options.getOutput())); pipeline.run().waitUntilFinish(); } From 6e72c8700340248b86a293cc8e2ccb25b1e126d8 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Thu, 13 Oct 2022 10:25:01 -0700 Subject: [PATCH 4/6] Fix Java PreCommit --- examples/multi-language/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/multi-language/build.gradle b/examples/multi-language/build.gradle index 03a662038307..61fdb686f4eb 100644 --- a/examples/multi-language/build.gradle +++ b/examples/multi-language/build.gradle @@ -33,8 +33,8 @@ ext.summary = "Java Classes for Multi-language Examples" dependencies { implementation library.java.vendored_guava_26_0_jre - implementation project(":examples:java") implementation project(path: ":sdks:java:core", configuration: "shadow") + runtimeOnly project(path: ":examples:java") runtimeOnly project(path: ":runners:direct-java", configuration: "shadow") runtimeOnly project(path: ":runners:google-cloud-dataflow-java") runtimeOnly project(path: ":runners:portability:java") From f088110cc6254c5dd4cb4e070a4a0ec8a4d8cd09 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Sun, 16 Oct 2022 21:21:56 -0700 Subject: [PATCH 5/6] Address reviewer comments --- examples/multi-language/README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/multi-language/README.md b/examples/multi-language/README.md index 4b41594f9c4b..127ab8c30eb2 100644 --- a/examples/multi-language/README.md +++ b/examples/multi-language/README.md @@ -66,10 +66,16 @@ Please note that the Java pipeline is #### Setup * Obtain/generate a csv input file that contains labels and pixels to feed into the model and store it in -GCS. And example input is available [here](TODO). +GCS. 
An example input is available +[here](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_input.csv). * Create a model file that contains the pickled file of a scikit-learn model -trained on MNIST data and store it in GCS. An example model file is available [here](TODO). +trained on MNIST data and store it in GCS. An example model file is available +[here](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_model). +This model was generated by running the program given +[here](https://python-course.eu/machine-learning/training-and-testing-with-mnist.php) +on the +[example input dataset](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_input.csv). * Perform Beam runner specific setup according to instructions [here](https://beam.apache.org/get-started/quickstart-java/#run-a-pipeline). @@ -160,9 +166,3 @@ of the digit. The second item is the predicted label of the digit. ``` gsutil cat gs://$GCP_BUCKET/multi-language-beam/output* ``` - - - - - - From 76d6c192b31297a7c8f44350e86a6a5f6f7fa048 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Mon, 17 Oct 2022 08:25:52 -0700 Subject: [PATCH 6/6] Addresses reviewer comments --- .../SklearnMnistClassification.java | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java b/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java index cef6848ca8be..4668ec1b41ef 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java +++ b/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java @@ -40,7 +40,7 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter; /** - * An example Java MUlti-language pipeline that Performs image classification on
handwritten digits + * An example Java Multi-language pipeline that Performs image classification on handwritten digits * from the MNIST database. * *

For more details and instructions for running this please see { + /** Filters out the header of the dataset that should not be used for the computation. */ + static class FilterNonRecordsFn implements SerializableFunction { @Override public Boolean apply(String input) { @@ -65,7 +71,12 @@ public Boolean apply(String input) { } } - static class KVFn extends SimpleFunction>> { + /** + * Separates our input records to label and data. Each input record is a set of comma separated + * string digits where first digit is the label and rest are data (pixels that represent the + * digit). + */ + static class RecordsToLabeledPixelsFn extends SimpleFunction>> { @Override public KV> apply(String input) { @@ -80,15 +91,17 @@ public KV> apply(String input) { } } + /** Formats the output to a mapping from the expected digit to the inferred digit. */ static class FormatOutput extends SimpleFunction, String> { @Override public String apply(KV input) { - return input.getKey() + " was mapped to " + input.getValue().getString("inference"); + return input.getKey() + "," + input.getValue().getString("inference"); } } void runExample(SklearnMnistClassificationOptions options, String expansionService) { + // Schema of the output PCollection Row type to be provided to the RunInference transform. Schema schema = Schema.of( Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)), @@ -98,8 +111,8 @@ void runExample(SklearnMnistClassificationOptions options, String expansionServi PCollection>> col = pipeline .apply(TextIO.read().from(options.getInput())) - .apply(Filter.by(new FilterFn())) - .apply(MapElements.via(new KVFn())); + .apply(Filter.by(new FilterNonRecordsFn())) + .apply(MapElements.via(new RecordsToLabeledPixelsFn())); col.apply( RunInference.ofKVs(getModelLoaderScript(), schema, VarLongCoder.of()) .withKwarg("model_uri", options.getModelPath())