-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example code to save JAX program and run using C++ runtime.
- Loading branch information
1 parent
3ac809e
commit d8faa44
Showing
3 changed files
with
158 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright 2018 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
licenses(["notice"]) # Apache 2 | ||
|
||
load( | ||
"@org_tensorflow//tensorflow:tensorflow.bzl", | ||
"tf_cc_binary", | ||
) | ||
|
||
tf_cc_binary( | ||
name = "main", | ||
srcs = ["main.cc"], | ||
deps = [ | ||
"@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit", | ||
"@org_tensorflow//tensorflow/compiler/xla:literal", | ||
"@org_tensorflow//tensorflow/compiler/xla:literal_util", | ||
"@org_tensorflow//tensorflow/compiler/xla:shape_util", | ||
"@org_tensorflow//tensorflow/compiler/xla:status", | ||
"@org_tensorflow//tensorflow/compiler/xla:statusor", | ||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:cpu_device", | ||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client", | ||
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc", | ||
"@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/* Copyright 2019 Google LLC | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
// An example for reading a HloModule from a HloProto file and execute the | ||
// module on PJRT CPU client. | ||
// | ||
// To build a HloModule, | ||
// | ||
// $ python3 jax/tools/jax_to_hlo.py \ | ||
// --fn examples.jax_cpp.prog.fn \ | ||
// --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \ | ||
// --constants '{"z": 2.0}' \ | ||
// --hlo_text_dest /tmp/fn_hlo.txt \ | ||
// --hlo_proto_dest /tmp/fn_hlo.pb | ||
// | ||
// To load and run the HloModule, | ||
// | ||
// $ bazel build examples/jax_cpp:main --experimental_repo_remote_exec --check_visibility=false | ||
// $ bazel-bin/examples/jax_cpp/main | ||
// 2021-01-12 15:35:28.316880: I examples/jax_cpp/main.cc:65] result = ( | ||
// f32[2,2] { | ||
// { 1.5, 1.5 }, | ||
// { 3.5, 3.5 } | ||
// } | ||
// ) | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "tensorflow/compiler/xla/literal.h" | ||
#include "tensorflow/compiler/xla/literal_util.h" | ||
#include "tensorflow/compiler/xla/pjrt/cpu_device.h" | ||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" | ||
#include "tensorflow/compiler/xla/status.h" | ||
#include "tensorflow/compiler/xla/statusor.h" | ||
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h" | ||
#include "tensorflow/core/platform/init_main.h" | ||
#include "tensorflow/core/platform/logging.h" | ||
|
||
int main(int argc, char** argv) { | ||
tensorflow::port::InitMain("", &argc, &argv); | ||
|
||
// Load HloModule from file. | ||
std::string hlo_filename = "/tmp/fn_hlo.txt"; | ||
std::function<void(xla::HloModuleConfig*)> config_modifier_hook = | ||
[](xla::HloModuleConfig* config) { config->set_seed(42); }; | ||
std::unique_ptr<xla::HloModule> test_module = | ||
LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(), | ||
"txt", config_modifier_hook) | ||
.ValueOrDie(); | ||
const xla::HloModuleProto test_module_proto = test_module->ToProto(); | ||
|
||
// Run it using JAX C++ Runtime (PJRT). | ||
|
||
// Get a CPU client. | ||
std::unique_ptr<xla::PjRtClient> client = | ||
xla::GetCpuClient(/*asynchronous=*/true).ValueOrDie(); | ||
|
||
// Compile XlaComputation to PjRtExecutable. | ||
xla::XlaComputation xla_computation(test_module_proto); | ||
xla::CompileOptions compile_options; | ||
std::unique_ptr<xla::PjRtExecutable> executable = | ||
client->Compile(xla_computation, compile_options).ValueOrDie(); | ||
|
||
// Prepare inputs. | ||
xla::Literal literal_x = | ||
xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}); | ||
xla::Literal literal_y = | ||
xla::LiteralUtil::CreateR2<float>({{1.0f, 1.0f}, {1.0f, 1.0f}}); | ||
std::unique_ptr<xla::PjRtBuffer> param_x = | ||
client->BufferFromHostLiteral(literal_x, client->local_devices()[0]) | ||
.ValueOrDie(); | ||
std::unique_ptr<xla::PjRtBuffer> param_y = | ||
client->BufferFromHostLiteral(literal_y, client->local_devices()[0]) | ||
.ValueOrDie(); | ||
|
||
// Execute on CPU. | ||
xla::ExecuteOptions execute_options; | ||
// One vector<buffer> for each device. | ||
std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results = | ||
executable->Execute({{param_x.get(), param_y.get()}}, execute_options) | ||
.ValueOrDie(); | ||
|
||
// Get result. | ||
std::shared_ptr<xla::Literal> result_literal = | ||
results[0][0]->ToLiteral().ValueOrDie(); | ||
LOG(INFO) << "result = " << *result_literal; | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2018 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Example function to be jitted.""" | ||
import jax.numpy as jnp | ||
|
||
def fn(x, y, z): | ||
return jnp.dot(x, y) / z |