From f36f1bf3a154cbe660b6519df17a113e9280b2be Mon Sep 17 00:00:00 2001
From: Qiao Zhang <zhangqiaorjc@google.com>
Date: Tue, 12 Jan 2021 14:02:32 -0800
Subject: [PATCH] Add example code to save JAX program and run using C++
 runtime.

---
 examples/jax_cpp/BUILD   | 21 +++++++++++++
 examples/jax_cpp/main.cc | 67 ++++++++++++++++++++++++++++++++++++++++
 examples/jax_cpp/prog.py |  4 +++
 3 files changed, 92 insertions(+)
 create mode 100644 examples/jax_cpp/BUILD
 create mode 100644 examples/jax_cpp/main.cc
 create mode 100644 examples/jax_cpp/prog.py

diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD
new file mode 100644
index 000000000000..a36a8c5bef8d
--- /dev/null
+++ b/examples/jax_cpp/BUILD
@@ -0,0 +1,21 @@
+load(
+    "@org_tensorflow//tensorflow:tensorflow.bzl",
+    "tf_cc_binary",
+)
+
+tf_cc_binary(
+    name = "main",
+    srcs = ["main.cc"],
+    deps = [
+        "@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
+        "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "@org_tensorflow//tensorflow/compiler/xla/pjrt:cpu_device",
+        "@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit",
+        "@org_tensorflow//tensorflow/compiler/xla:status",
+        "@org_tensorflow//tensorflow/compiler/xla:statusor",
+        "@org_tensorflow//tensorflow/compiler/xla:literal",
+        "@org_tensorflow//tensorflow/compiler/xla:literal_util",
+        "@org_tensorflow//tensorflow/compiler/xla:shape_util",
+        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
+    ],
+)
diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc
new file mode 100644
index 000000000000..5f5528107e21
--- /dev/null
+++ b/examples/jax_cpp/main.cc
@@ -0,0 +1,67 @@
+// An example for reading a HloModule from a HloProto file and execute the
+// module on PJRT CPU client.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+int main(int argc, char** argv) {
+  tensorflow::port::InitMain("", &argc, &argv);
+
+  // Load HloModule from file.
+  std::string hlo_filename = "/tmp/fn_hlo.txt";
+  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
+      [](xla::HloModuleConfig* config) { config->set_seed(42); };
+  std::unique_ptr<xla::HloModule> test_module =
+      LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(),
+                         "txt", config_modifier_hook)
+          .ValueOrDie();
+  const xla::HloModuleProto test_module_proto = test_module->ToProto();
+
+  // Run it using JAX C++ Runtime (PJRT).
+
+  // Get a CPU client.
+  std::unique_ptr<xla::PjRtClient> client =
+      xla::GetCpuClient(/*asynchronous=*/true).ValueOrDie();
+
+  // Compile XlaComputation to PjRtExecutable.
+  xla::XlaComputation xla_computation(test_module_proto);
+  xla::CompileOptions compile_options;
+  std::unique_ptr<xla::PjRtExecutable> executable =
+      client->Compile(xla_computation, compile_options).ValueOrDie();
+
+  // Prepare inputs.
+  xla::Literal literal_x =
+      xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
+  xla::Literal literal_y =
+      xla::LiteralUtil::CreateR2<float>({{1.0f, 1.0f}, {1.0f, 1.0f}});
+  std::unique_ptr<xla::PjRtBuffer> param_x =
+      client->BufferFromHostLiteral(literal_x, client->local_devices()[0])
+          .ValueOrDie();
+  std::unique_ptr<xla::PjRtBuffer> param_y =
+      client->BufferFromHostLiteral(literal_y, client->local_devices()[0])
+          .ValueOrDie();
+
+  // Execute on CPU.
+  xla::ExecuteOptions execute_options;
+  // One vector<buffer> for each device.
+  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
+      executable->Execute({{param_x.get(), param_y.get()}}, execute_options)
+          .ValueOrDie();
+
+  // Get result.
+  std::shared_ptr<xla::Literal> result_literal =
+      results[0][0]->ToLiteral().ValueOrDie();
+  LOG(INFO) << "result = " << *result_literal;
+  return 0;
+}
diff --git a/examples/jax_cpp/prog.py b/examples/jax_cpp/prog.py
new file mode 100644
index 000000000000..a3708c90c75d
--- /dev/null
+++ b/examples/jax_cpp/prog.py
@@ -0,0 +1,4 @@
+import jax.numpy as jnp
+
+def fn(x, y, z):
+  return jnp.dot(x, y) / z