-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example code to save JAX program and run using C++ runtime.
- Loading branch information
1 parent
3ac809e
commit f36f1bf
Showing
3 changed files
with
92 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
load( | ||
"@org_tensorflow//tensorflow:tensorflow.bzl", | ||
"tf_cc_binary", | ||
) | ||
|
||
tf_cc_binary( | ||
name = "main", | ||
srcs = ["main.cc"], | ||
deps = [ | ||
"@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader", | ||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client", | ||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:cpu_device", | ||
"@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit", | ||
"@org_tensorflow//tensorflow/compiler/xla:status", | ||
"@org_tensorflow//tensorflow/compiler/xla:statusor", | ||
"@org_tensorflow//tensorflow/compiler/xla:literal", | ||
"@org_tensorflow//tensorflow/compiler/xla:literal_util", | ||
"@org_tensorflow//tensorflow/compiler/xla:shape_util", | ||
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// An example for reading a HloModule from a HloProto file and execute the | ||
// module on PJRT CPU client. | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "tensorflow/compiler/xla/literal.h" | ||
#include "tensorflow/compiler/xla/literal_util.h" | ||
#include "tensorflow/compiler/xla/pjrt/cpu_device.h" | ||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" | ||
#include "tensorflow/compiler/xla/status.h" | ||
#include "tensorflow/compiler/xla/statusor.h" | ||
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h" | ||
#include "tensorflow/core/platform/init_main.h" | ||
#include "tensorflow/core/platform/logging.h" | ||
|
||
int main(int argc, char** argv) { | ||
tensorflow::port::InitMain("", &argc, &argv); | ||
|
||
// Load HloModule from file. | ||
std::string hlo_filename = "/tmp/fn_hlo.txt"; | ||
std::function<void(xla::HloModuleConfig*)> config_modifier_hook = | ||
[](xla::HloModuleConfig* config) { config->set_seed(42); }; | ||
std::unique_ptr<xla::HloModule> test_module = | ||
LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(), | ||
"txt", config_modifier_hook) | ||
.ValueOrDie(); | ||
const xla::HloModuleProto test_module_proto = test_module->ToProto(); | ||
|
||
// Run it using JAX C++ Runtime (PJRT). | ||
|
||
// Get a CPU client. | ||
std::unique_ptr<xla::PjRtClient> client = | ||
xla::GetCpuClient(/*asynchronous=*/true).ValueOrDie(); | ||
|
||
// Compile XlaComputation to PjRtExecutable. | ||
xla::XlaComputation xla_computation(test_module_proto); | ||
xla::CompileOptions compile_options; | ||
std::unique_ptr<xla::PjRtExecutable> executable = | ||
client->Compile(xla_computation, compile_options).ValueOrDie(); | ||
|
||
// Prepare inputs. | ||
xla::Literal literal_x = | ||
xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}); | ||
xla::Literal literal_y = | ||
xla::LiteralUtil::CreateR2<float>({{1.0f, 1.0f}, {1.0f, 1.0f}}); | ||
std::unique_ptr<xla::PjRtBuffer> param_x = | ||
client->BufferFromHostLiteral(literal_x, client->local_devices()[0]) | ||
.ValueOrDie(); | ||
std::unique_ptr<xla::PjRtBuffer> param_y = | ||
client->BufferFromHostLiteral(literal_y, client->local_devices()[0]) | ||
.ValueOrDie(); | ||
|
||
// Execute on CPU. | ||
xla::ExecuteOptions execute_options; | ||
// One vector<buffer> for each device. | ||
std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results = | ||
executable->Execute({{param_x.get(), param_y.get()}}, execute_options) | ||
.ValueOrDie(); | ||
|
||
// Get result. | ||
std::shared_ptr<xla::Literal> result_literal = | ||
results[0][0]->ToLiteral().ValueOrDie(); | ||
LOG(INFO) << "result = " << *result_literal; | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import jax.numpy as jnp | ||
|
||
def fn(x, y, z): | ||
return jnp.dot(x, y) / z |