Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a Java RunInference example #23619

Merged
merged 6 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dependencies {
implementation library.java.kafka_clients
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation project(":sdks:java:extensions:google-cloud-platform-core")
implementation project(":sdks:java:extensions:python")
implementation project(":sdks:java:io:google-cloud-platform")
implementation project(":sdks:java:io:kafka")
implementation project(":sdks:java:extensions:ml")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.examples.multilanguage;

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.extensions.python.transforms.RunInference;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.Validation.Required;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;

/**
* An example Java MUlti-language pipeline that Performs image classification on handwritten digits
chamikaramj marked this conversation as resolved.
Show resolved Hide resolved
* from the <a href="https://en.wikipedia.org/wiki/MNIST_database">MNIST</a> database.
*
* <p>For more details and instructions for running this please see <a
* href="https://github.com/apache/beam/tree/master/examples/multi-language">here</a>.
*/
public class SklearnMnistClassification {

private String getModelLoaderScript() {
chamikaramj marked this conversation as resolved.
Show resolved Hide resolved
String s = "from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy\n";
s = s + "from apache_beam.ml.inference.base import KeyedModelHandler\n";
s = s + "def get_model_handler(model_uri):\n";
s = s + " return KeyedModelHandler(SklearnModelHandlerNumpy(model_uri))\n";

return s;
}

static class FilterFn implements SerializableFunction<String, Boolean> {
chamikaramj marked this conversation as resolved.
Show resolved Hide resolved

@Override
public Boolean apply(String input) {
return !input.startsWith("label");
}
}

static class KVFn extends SimpleFunction<String, KV<Long, Iterable<Long>>> {
chamikaramj marked this conversation as resolved.
Show resolved Hide resolved

@Override
public KV<Long, Iterable<Long>> apply(String input) {
chamikaramj marked this conversation as resolved.
Show resolved Hide resolved
String[] data = Splitter.on(',').splitToList(input).toArray(new String[] {});
Long label = Long.valueOf(data[0]);
List<Long> pixels = new ArrayList<Long>();
for (int i = 1; i < data.length; i++) {
pixels.add(Long.valueOf(data[i]));
}

return KV.of(label, pixels);
}
}

static class FormatOutput extends SimpleFunction<KV<Long, Row>, String> {

@Override
public String apply(KV<Long, Row> input) {
return input.getKey() + " was mapped to " + input.getValue().getString("inference");
}
}

void runExample(SklearnMnistClassificationOptions options, String expansionService) {
Schema schema =
chamikaramj marked this conversation as resolved.
Show resolved Hide resolved
Schema.of(
Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)),
Schema.Field.of("inference", FieldType.STRING));

Pipeline pipeline = Pipeline.create(options);
PCollection<KV<Long, Iterable<Long>>> col =
pipeline
.apply(TextIO.read().from(options.getInput()))
.apply(Filter.by(new FilterFn()))
.apply(MapElements.via(new KVFn()));
col.apply(
RunInference.ofKVs(getModelLoaderScript(), schema, VarLongCoder.of())
.withKwarg("model_uri", options.getModelPath())
.withExpansionService(expansionService))
.apply(MapElements.via(new FormatOutput()))
.apply(TextIO.write().to(options.getOutput()));

pipeline.run().waitUntilFinish();
}

public interface SklearnMnistClassificationOptions extends PipelineOptions {

@Description("Path to an input file that contains labels and pixels to feed into the model")
@Default.String("gs://apache-beam-samples/multi-language/mnist/example_input.csv")
String getInput();

void setInput(String value);

@Description("Path for storing the output")
@Required
String getOutput();

void setOutput(String value);

@Description(
"Path to a model file that contains the pickled file of a scikit-learn model trained on MNIST data")
@Default.String("gs://apache-beam-samples/multi-language/mnist/example_model")
String getModelPath();

void setModelPath(String value);

/** Set this option to specify Python expansion service URL. */
@Description("URL of Python expansion service")
@Default.String("")
damccorm marked this conversation as resolved.
Show resolved Hide resolved
String getExpansionService();

void setExpansionService(String value);
}

public static void main(String[] args) {
SklearnMnistClassificationOptions options =
PipelineOptionsFactory.fromArgs(args).as(SklearnMnistClassificationOptions.class);
SklearnMnistClassification example = new SklearnMnistClassification();
example.runExample(options, options.getExpansionService());
}
}
126 changes: 122 additions & 4 deletions examples/multi-language/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,147 @@
This project provides examples of Apache Beam
[multi-language pipelines](https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines):

## Using Java transforms from Python

* **python/addprefix** - A Python pipeline that reads a text file and attaches a prefix on the Java side to each input.
* **python/javacount** - A Python pipeline that counts words using the Java `Count.perElement()` transform.
* **python/javadatagenerator** - A Python pipeline that produces a set of strings generated from Java.
This example demonstrates the `JavaExternalTransform` API.

## Instructions for running the pipelines
### Instructions for running the pipelines

### 1) Start the expansion service
#### 1) Start the expansion service

1. Download the latest 'beam-examples-multi-language' JAR. Starting with Apache Beam 2.36.0,
you can find it in [the Maven Central Repository](https://search.maven.org/search?q=g:org.apache.beam).
2. Run the following command, replacing `<version>` and `<port>` with valid values:
`java -jar beam-examples-multi-language-<version>.jar <port> --javaClassLookupAllowlistFile='*'`

### 2) Set up a Python virtual environment for Beam
#### 2) Set up a Python virtual environment for Beam

1. See [the Python quickstart](https://beam.apache.org/get-started/quickstart-py/)
for more information.

### 3) Execute the Python pipeline
#### 3) Execute the Python pipeline

1. In a new shell, run a pipeline in the **python** directory using a Beam runner that supports
multi-language pipelines.

The Python files contain details about the actual commands to run.

## Using Python transforms from Java

### Sklearn Mnist Classification

Performs image classification on handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database)
database.

Please see [here](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference) for
context and information regarding the corresponding Python pipeline.

Please note that the Java pipeline is
[availalble in the Beam Java examples module](https://github.com/apache/beam/tree/master/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java).

#### Setup

* Obtain/generate a csv input file that contains labels and pixels to feed into the model and store it in
GCS. An example input is available
[here](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_input.csv).

* Create a model file that contains the pickled file of a scikit-learn model
trained on MNIST data and store it in GCS. An example model file is available
[here](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_model).
This model was generated by by running the program given
[here](https://python-course.eu/machine-learning/training-and-testing-with-mnist.php)
on the
[example input dataset](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_input.csv).

* Perform Beam runner specific setup according to instructions
[here](https://beam.apache.org/get-started/quickstart-java/#run-a-pipeline).

Following instructions are for running the pipeline with the Dataflow runner. For other portable runners,
please modify the instructions according to the guidelines
[here](https://beam.apache.org/documentation/sdks/java-multi-language-pipelines/#run-with-directrunner)

#### Instructions for running the Java pipeline on released Beam (Beam 2.43.0 and later).

* Checkout the Beam examples Maven archetype for the relevant Beam version.

```
export BEAM_VERSION=<Beam version>

mvn archetype:generate \
-DarchetypeGroupId=org.apache.beam \
-DarchetypeArtifactId=beam-sdks-java-maven-archetypes-examples \
-DarchetypeVersion=$BEAM_VERSION \
-DgroupId=org.example \
-DartifactId=multi-language-beam \
-Dversion="0.1" \
-Dpackage=org.apache.beam.examples \
-DinteractiveMode=false
```

* Run the pipeline.

```
export GCP_PROJECT=<GCP project>
export GCP_BUCKET=<GCP bucket>
export GCP_REGION=<GCP region>

mvn compile exec:java -Dexec.mainClass=org.apache.beam.examples.multilanguage.SklearnMnistClassification \
-Dexec.args="--runner=DataflowRunner --project=$GCP_PROJECT \
--region=us-central1 \
--gcpTempLocation=gs://$GCP_BUCKET/multi-language-beam/tmp \
--output=gs://$GCP_BUCKET/multi-language-beam/output" \
-Pdataflow-runner
```

* Inspect the output. Each line has data separated by a comma ",". The first item is the actual label of
the digit. The second item is the predicted label of the digit.

```
gsutil cat gs://$GCP_BUCKET/multi-language-beam/output*
```

#### Instructions for running the Java pipeline at HEAD (Beam 2.41.0 and 2.42.0).

* Make sure that Docker is installed and available on your system.

* Build and push Python and Java Docker containers.

```
export DOCKER_ROOT=<Docker root>

./gradlew :sdks:python:container:py38:docker -Pdocker-repository-root=$DOCKER_ROOT -Pdocker-tag=latest

docker push $DOCKER_ROOT/beam_python3.8_sdk:latest

./gradlew :sdks:java:container:java11:docker -Pdocker-repository-root=$DOCKER_ROOT -Pdocker-tag=latest

docker push $DOCKER_ROOT/beam_java11_sdk:latest
```

* Run the pipeline using the following Gradle command (this guide assumes Dataflow runner).
Note that we override both the Java and Python SDK harness containers here.

```
export GCP_PROJECT=<GCP project>
export GCP_BUCKET=<GCP bucket>
export GCP_REGION=<GCP region>

./gradlew :examples:multi-language:sklearnMinstClassification --args=" \
--runner=DataflowRunner \
--project=$GCP_PROJECT \
--gcpTempLocation=gs://$GCP_BUCKET/multi-language-beam/tmp \
--output=gs://$GCP_BUCKET/multi-language-beam/output \
--sdkContainerImage=$DOCKER_ROOT/beam_java11_sdk:latest \
--sdkHarnessContainerImageOverrides=.*python.*,$DOCKER_ROOT/beam_python3.8_sdk:latest \
--region=${GCP_REGION}"
```

* Inspect the output. Each line has data separated by a comma ",". The first item is the actual label
of the digit. The second item is the predicted label of the digit.

```
gsutil cat gs://$GCP_BUCKET/multi-language-beam/output*
```
10 changes: 9 additions & 1 deletion examples/multi-language/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ext.summary = "Java Classes for Multi-language Examples"
dependencies {
implementation library.java.vendored_guava_26_0_jre
implementation project(path: ":sdks:java:core", configuration: "shadow")
runtimeOnly project(path: ":examples:java")
runtimeOnly project(path: ":runners:direct-java", configuration: "shadow")
runtimeOnly project(path: ":runners:google-cloud-dataflow-java")
runtimeOnly project(path: ":runners:portability:java")
Expand All @@ -47,4 +48,11 @@ task pythonDataframeWordCount(type: JavaExec) {
description "Run the Java word count example using external Python DataframeTransform"
mainClass = "org.apache.beam.examples.multilanguage.PythonDataframeWordCount"
classpath = sourceSets.main.runtimeClasspath
}
}

task sklearnMinstClassification(type: JavaExec) {
description "Run the Java pipeline that performns image classification on handwritten digits from the MNIST database"
mainClass = "org.apache.beam.examples.multilanguage.SklearnMnistClassification"
classpath = sourceSets.main.runtimeClasspath
}

Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
* ./gradlew :examples:multi-language:pythonDataframeWordCount --args=" \
* --runner=DataflowRunner \
* --output=gs://{$OUTPUT_BUCKET}/count \
* --experiments=use_runner_v2 \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to remove this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we don't need this anymore.

* --sdkHarnessContainerImageOverrides=.*python.*,gcr.io/apache-beam-testing/beam-sdk/beam_python{$PYTHON_VERSION}_sdk:latest"
* }</pre>
*/
Expand Down
10 changes: 10 additions & 0 deletions sdks/java/maven-archetypes/examples/generate-sources.sh
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ rsync -a \
"${EXAMPLES_ROOT}"/src/test/java/org/apache/beam/examples/complete/game/ \
"${ARCHETYPE_ROOT}/src/test/java/complete/game"

#
# Copy the Java multi-language examples.
#

mkdir -p "${ARCHETYPE_ROOT}/src/test/java/multilanguage/"

rsync -a \
"${EXAMPLES_ROOT}"/src/main/java/org/apache/beam/examples/multilanguage/ \
"${ARCHETYPE_ROOT}/src/main/java/multilanguage"

#
# Replace 'package org.apache.beam.examples' with 'package ${package}' in all Java code
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,13 @@
<version>${beam.version}</version>
</dependency>

<!-- Adds a dependency on the Python Multi-language pipelines API module. -->
<dependency>
<groupId>org.apache.beam</groupId>
<artifactId>beam-sdks-java-extensions-python</artifactId>
<version>${beam.version}</version>
</dependency>

<!-- Dependencies below this line are specific dependencies needed by the examples code. -->
<dependency>
<groupId>com.google.api-client</groupId>
Expand Down