Add example code to save JAX program and run using C++ runtime.
zhangqiaorjc committed Jan 12, 2021
1 parent 3ac809e commit f36f1bf
Showing 3 changed files with 92 additions and 0 deletions.
21 changes: 21 additions & 0 deletions examples/jax_cpp/BUILD
@@ -0,0 +1,21 @@
load(
    "@org_tensorflow//tensorflow:tensorflow.bzl",
    "tf_cc_binary",
)

tf_cc_binary(
    name = "main",
    srcs = ["main.cc"],
    deps = [
        "@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
        "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
        "@org_tensorflow//tensorflow/compiler/xla/pjrt:cpu_device",
        "@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit",
        "@org_tensorflow//tensorflow/compiler/xla:status",
        "@org_tensorflow//tensorflow/compiler/xla:statusor",
        "@org_tensorflow//tensorflow/compiler/xla:literal",
        "@org_tensorflow//tensorflow/compiler/xla:literal_util",
        "@org_tensorflow//tensorflow/compiler/xla:shape_util",
        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
    ],
)
67 changes: 67 additions & 0 deletions examples/jax_cpp/main.cc
@@ -0,0 +1,67 @@
// An example of reading an HloModule from an HloProto file and executing the
// module on a PJRT CPU client.

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"

int main(int argc, char** argv) {
  tensorflow::port::InitMain("", &argc, &argv);

  // Load HloModule from file.
  std::string hlo_filename = "/tmp/fn_hlo.txt";
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };
  std::unique_ptr<xla::HloModule> test_module =
      LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(),
                         "txt", config_modifier_hook)
          .ValueOrDie();
  const xla::HloModuleProto test_module_proto = test_module->ToProto();

  // Run it using JAX C++ Runtime (PJRT).

  // Get a CPU client.
  std::unique_ptr<xla::PjRtClient> client =
      xla::GetCpuClient(/*asynchronous=*/true).ValueOrDie();

  // Compile XlaComputation to PjRtExecutable.
  xla::XlaComputation xla_computation(test_module_proto);
  xla::CompileOptions compile_options;
  std::unique_ptr<xla::PjRtExecutable> executable =
      client->Compile(xla_computation, compile_options).ValueOrDie();

  // Prepare inputs.
  xla::Literal literal_x =
      xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
  xla::Literal literal_y =
      xla::LiteralUtil::CreateR2<float>({{1.0f, 1.0f}, {1.0f, 1.0f}});
  std::unique_ptr<xla::PjRtBuffer> param_x =
      client->BufferFromHostLiteral(literal_x, client->local_devices()[0])
          .ValueOrDie();
  std::unique_ptr<xla::PjRtBuffer> param_y =
      client->BufferFromHostLiteral(literal_y, client->local_devices()[0])
          .ValueOrDie();

  // Execute on CPU.
  xla::ExecuteOptions execute_options;
  // One vector<buffer> for each device.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
      executable->Execute({{param_x.get(), param_y.get()}}, execute_options)
          .ValueOrDie();

  // Get result.
  std::shared_ptr<xla::Literal> result_literal =
      results[0][0]->ToLiteral().ValueOrDie();
  LOG(INFO) << "result = " << *result_literal;
  return 0;
}
4 changes: 4 additions & 0 deletions examples/jax_cpp/prog.py
@@ -0,0 +1,4 @@
import jax.numpy as jnp

def fn(x, y, z):
  return jnp.dot(x, y) / z
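
prog.py only defines fn, while main.cc above expects a lowered HLO text file at /tmp/fn_hlo.txt whose entry computation takes the two 2x2 f32 buffers it feeds in. A minimal sketch of how such a file could be produced with jax.xla_computation follows; the input shapes, the constant value chosen for z, and the output path are illustrative assumptions, not part of this commit.

# Hypothetical sketch (not part of this commit): trace fn and dump its HLO
# text so that examples/jax_cpp/main.cc can load it from /tmp/fn_hlo.txt.
import jax
import jax.numpy as jnp
import numpy as np

def fn(x, y, z):
  return jnp.dot(x, y) / z

# main.cc passes exactly two 2x2 float buffers, so z is fixed here as a
# compile-time constant rather than exposed as a third HLO parameter.
x = np.ones((2, 2), dtype=np.float32)
y = np.ones((2, 2), dtype=np.float32)
computation = jax.xla_computation(lambda a, b: fn(a, b, 2.0))(x, y)

# Write the HLO text in the "txt" format that hlo_module_loader parses.
with open("/tmp/fn_hlo.txt", "w") as f:
  f.write(computation.as_hlo_text())

With JAX installed, running a script like this once produces /tmp/fn_hlo.txt, after which the tf_cc_binary above can be built and run against it.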
