
Examples of Running a JAX function in C++ #5337

Closed
drebain opened this issue Jan 7, 2021 · 14 comments · Fixed by #5383

@drebain

drebain commented Jan 7, 2021

Are there any examples available of running jit functions defined in Python from C++? I see that there is an interface for generating something usable by XLA, but it is a bit unclear how to use the result when the function depends on variables/weights (e.g. a flax module).

@zhangqiaorjc I am told you have some knowledge of this?

Thanks

@zhangqiaorjc zhangqiaorjc self-assigned this Jan 7, 2021
@zhangqiaorjc
Collaborator

zhangqiaorjc commented Jan 8, 2021

I'll show an example for a simple JAX function. Let me know if it suffices for your use case.

There are two steps involved here:

  1. Save a JAX program as an xla::HloModule
  2. Load the HloModule from file, and run it using the C++ Runtime.

Step 1: Use jax/tools/jax_to_hlo.py to save a JAX program.

Suppose we have a dummy JAX program jax/tools/prog.py,

import jax.numpy as jnp

def fn(x, y, z):
  return jnp.dot(x, y) / z

Let's convert it to HLO, with input shapes and constants provided (see usage in jax_to_hlo.py). The command-line options roughly correspond to jax.api.jit options.

$ python3 jax_to_hlo.py \
--fn prog.fn \
--input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
--constants '{"z": 2.0}' \
--hlo_text_dest /tmp/fn_hlo.txt \
--hlo_proto_dest /tmp/fn_hlo.pb

Pay special attention to the order of parameters specified in --input_shapes. If we had specified --input_shapes '[("y", "f32[2,2]"), ("x", "f32[2,2]")]', the positions of x and y would have been swapped (y would become parameter 0), even though they are named in the tuples.

Let's see the saved HloModule

$ cat /tmp/fn_hlo.txt
HloModule xla_computation_ordered_wrapper.9

ENTRY xla_computation_ordered_wrapper.9 {
  constant.3 = pred[] constant(false)
  parameter.1 = f32[2,2]{1,0} parameter(0)
  parameter.2 = f32[2,2]{1,0} parameter(1)
  dot.4 = f32[2,2]{1,0} dot(parameter.1, parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  constant.5 = f32[] constant(2)
  broadcast.6 = f32[2,2]{1,0} broadcast(constant.5), dimensions={}
  divide.7 = f32[2,2]{1,0} divide(dot.4, broadcast.6)
  ROOT tuple.8 = (f32[2,2]{1,0}) tuple(divide.7)
}

Note the single output with shape f32[2,2]{1,0}.

Step 2: Use the PJRT runtime to run the saved HloModule with user-provided input values.

Note that there are multiple C++ runtime APIs that we can use to run an HloModule -- HloRunner, LocalClient/LocalExecutable, or PJRT -- all of which live in the TensorFlow tree. Since I'm most familiar with JAX's runtime API, I'll show an example using PJRT (see tensorflow/compiler/xla/pjrt/pjrt_client.h).

Suppose we have the following BUILD and cc files

/Users/zhangqiaorjc/repo/tensorflow/tensorflow/compiler/xla/examples
└── jax_cpp
    ├── BUILD
    └── main.cc

The BUILD file looks like

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_binary",
)

tf_cc_binary(
    name = "main",
    srcs = ["main.cc"],
    deps = [
      "//tensorflow/compiler/xla/tools:hlo_module_loader",
      "//tensorflow/compiler/xla/pjrt:pjrt_client",
      "//tensorflow/compiler/xla/pjrt:cpu_device",
      "//tensorflow/compiler/xla:status",
      "//tensorflow/compiler/xla:statusor",
      "//tensorflow/compiler/xla:literal",
      "//tensorflow/compiler/xla:literal_util",
      "//tensorflow/compiler/xla:shape_util",
      "//tensorflow/compiler/xla/service:hlo_proto_cc",
    ],
)

The cc file looks like

// An example of loading an HloModule from an HLO text file and executing
// the module on the PJRT CPU client.

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"

int main(int argc, char** argv) {
  tensorflow::port::InitMain("", &argc, &argv);

  // Load HloModule from file.
  std::string hlo_filename = "/tmp/fn_hlo.txt";
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };
  std::unique_ptr<xla::HloModule> test_module =
      xla::LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(),
                         "txt", config_modifier_hook)
          .ValueOrDie();
  const xla::HloModuleProto test_module_proto = test_module->ToProto();

  // Run it using JAX C++ Runtime (PJRT).

  // Get a CPU client.
  std::unique_ptr<xla::PjRtClient> client =
      xla::GetCpuClient(/*asynchronous=*/true).ValueOrDie();

  // Compile XlaComputation to PjRtExecutable.
  xla::XlaComputation xla_computation(test_module_proto);
  xla::CompileOptions compile_options;
  std::unique_ptr<xla::PjRtExecutable> executable =
      client->Compile(xla_computation, compile_options).ValueOrDie();

  // Prepare inputs.
  xla::Literal literal_x =
      xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
  xla::Literal literal_y =
      xla::LiteralUtil::CreateR2<float>({{1.0f, 1.0f}, {1.0f, 1.0f}});
  std::unique_ptr<xla::PjRtBuffer> param_x =
      client->BufferFromHostLiteral(literal_x, client->local_devices()[0])
          .ValueOrDie();
  std::unique_ptr<xla::PjRtBuffer> param_y =
      client->BufferFromHostLiteral(literal_y, client->local_devices()[0])
          .ValueOrDie();

  // Execute on CPU.
  xla::ExecuteOptions execute_options;
  // One vector<buffer> for each device.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
      executable->Execute({{param_x.get(), param_y.get()}}, execute_options)
          .ValueOrDie();

  // Get result.
  std::shared_ptr<xla::Literal> result_literal =
      results[0][0]->ToLiteral().ValueOrDie();
  LOG(INFO) << "result = " << *result_literal;
  return 0;
}

To run it

$ bazel run -c opt :main

2021-01-07 17:30:23.472798: I tensorflow/compiler/xla/examples/jax_cpp/main.cc:69] result = (
f32[2,2] {
  { 1.5, 1.5 },
  { 3.5, 3.5 }
}
)

@gnecula
Collaborator

gnecula commented Jan 8, 2021

Do you think it makes sense to link this from the FAQ?

@drebain
Author

drebain commented Jan 8, 2021

Awesome! This will be very helpful, thank you.

I only have one question: if I understand correctly, the arguments to PjRtExecutable::Execute correspond to a list of arrays/tensors, so to run a neural network with weights you either need to pass the weights as individual tensor arguments or compile them into the HLO.

Is one of these approaches preferred, or would they be more or less equivalent in terms of performance?

Thanks again!

@zhangqiaorjc
Collaborator

@gnecula we could turn this into a developer doc to make it more discoverable?

@drebain For the question on NN weights as constants or parameters, I believe making them constants will improve dispatch performance. There may be issues with increased memory consumption. I would recommend trying both and measuring performance and memory consumption.
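
As a rough sketch of what the two options look like for a toy network (the names apply and weights below are hypothetical, not part of the example above):

# Option A: weights as parameters -- expose a function of (weights, x) and
# list the weight shapes in --input_shapes alongside "x".
# Option B: weights as constants -- close over the trained values so the
# converted HLO only takes "x" as a parameter.
import functools
import jax.numpy as jnp

def apply(weights, x):
  # Hypothetical single-layer network.
  return jnp.dot(x, weights["w"]) + weights["b"]

weights = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}

fn_with_weight_params = apply                                  # Option A
fn_with_weight_constants = functools.partial(apply, weights)   # Option B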

Do let me know if you find any part of the example unclear, so I can make a developer doc that others can use (running a JAX program via the C++ API seems to have come up a few times).

Just a clarification on the PjRtExecutable::Execute signature since that part may be confusing to other readers.

  // Executes on devices addressable by the client. Requires executable has a
  // device_assignment and all devices in the device_assignment are addressable
  // by the client.
  virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
  Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
          const ExecuteOptions& options) = 0;

The inner vector corresponds to the parameters of the HLO computation. It's a Span<vector> because there is one set of parameters per device.

If it's only a single device, then you just pass in {{param_x.get(), param_y.get()}}. Note the double braces.
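
For example, a hypothetical two-device call would look roughly like this (the dev0_*/dev1_* buffer names are made up; it assumes the executable was compiled with a two-device device_assignment and that each buffer lives on the matching device):

// Hypothetical two-device dispatch; buffer names are illustrative only.
std::vector<std::vector<xla::PjRtBuffer*>> args = {
    {dev0_x.get(), dev0_y.get()},   // parameters for device 0
    {dev1_x.get(), dev1_y.get()}};  // parameters for device 1
auto per_device_results =
    executable->Execute(args, execute_options).ValueOrDie();
// per_device_results[i][j] is output j computed on device i.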

@zhangqiaorjc
Collaborator

@drebain A bit of nuance on constants vs. parameters: there seem to be some nontrivial tradeoffs (they may depend on the platform too).

Some XLA experts recommend turning small arrays or scalars into constants while staying away from large constants.

Unfortunately, there's no easy rule here, so you will have to experiment to find out.

@drebain
Copy link
Author

drebain commented Jan 8, 2021

I see. I guess it makes sense that this would depend on the size of the constants. I will try both on a medium-sized MLP network and report back on which is faster for future reference.

@drebain
Copy link
Author

drebain commented Jan 15, 2021

Ok, as promised I have done some basic benchmarks and it seems that compiling network weights in as constants gets me a roughly 30% performance increase over passing them as arguments. This is with a 6-layer, 256-unit MLP, running on an RTX 3090.

@zhangqiaorjc
Collaborator

For GPU, you may need the following changes:

  1. Header include:
     #include "tensorflow/compiler/xla/pjrt/cpu_device.h"
     -->
     #include "tensorflow/compiler/xla/pjrt/gpu_device.h"

  2. Client creation: xla::GetCpuClient --> xla::GetGpuClient

  3. BUILD dep change:
     "//tensorflow/compiler/xla/pjrt:cpu_device",
     -->
     "//tensorflow/compiler/xla/pjrt:gpu_device",

@zhangqiaorjc
Collaborator

PJRT is the lowest-level runtime API, and is only advisable if you want to avoid direct TF dependencies.

Otherwise, please use jax2tf and the usual TF C++ server for running a SavedModel.

https://github.com/google/jax/tree/main/jax/experimental/jax2tf
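
For instance, a minimal sketch of that path for the toy fn from the earlier example (the path and shapes are illustrative, and this assumes TensorFlow is installed):

# Convert the JAX function to a TF function and export it as a SavedModel
# that the usual TF C++ serving stack can load.
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def fn(x, y):
  return jnp.dot(x, y) / 2.0

module = tf.Module()
module.f = tf.function(
    jax2tf.convert(fn),
    autograph=False,
    input_signature=[
        tf.TensorSpec([2, 2], tf.float32),
        tf.TensorSpec([2, 2], tf.float32),
    ],
)
tf.saved_model.save(module, "/tmp/fn_saved_model")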

@lucagrementieri

Thank you for the script, it's very useful!
At the end of the script the result is just printed; instead, I would like to convert it to an Eigen::Matrix object so I can use it.

Do you know how to cast an xla::Literal to an Eigen::Matrix?
Reading the header file for xla::Literal, I have found no specific function to do it, but maybe I missed the relevant piece of code.
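
One idea I am considering (unverified, and assuming the result keeps the default row-major {1,0} layout from the example above) is to view the literal's host buffer through an Eigen::Map instead of copying element by element:

#include "Eigen/Dense"

// Unverified sketch: view the f32[2,2] literal's buffer as a 2x2 row-major
// Eigen matrix without copying. The example's result literal is a one-element
// tuple, hence the {0} shape index; a non-tuple literal would use data<float>().
Eigen::Map<const Eigen::Matrix<float, 2, 2, Eigen::RowMajor>> result_matrix(
    result_literal->data<float>(/*shape_index=*/{0}).data());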

@xueeinstein

@zhangqiaorjc Thank you for the example, it's very useful! I tried it using the newest codebase, but the output is

2022-06-18 22:26:10.013679: I jax_cpp/main.cc:68] result = (
f32[2,2] {
  { 2, 3 },
  { 2, 3 }
}
)

It looks like xla::LiteralUtil::CreateR2 creates a column-major matrix now, which is different from your previous result.

Did this API change recently? Is there any build option to switch to a row-major matrix, like in Python?

@joaospinto
Contributor

joaospinto commented Jun 25, 2024

I'm trying to determine to what extent something like what @zhangqiaorjc shared above could be used to deploy JAX code into real-time/low-memory/restricted (e.g. no dynamic memory allocation) environments.

For example, would it be possible to "pre-compute" client->Compile, and auto-generate a Bazel library with functionality equivalent to the executable->Execute method of the xla::PjRtExecutable object?

Does the executable->Execute method perform any dynamic memory allocation?

@joaospinto
Contributor

^ This is discussed here: #22184

@kranipa

kranipa commented Jul 5, 2024

I am trying to build the JAX C++ example for the GPU backend, but it runs into a linking error.
I am using se_gpu_pjrt_client for the example.

The BUILD file is as follows:

load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("//xla:xla.bzl", "xla_cc_binary")


xla_cc_binary(
    name = "main_gpu",
    srcs = ["main_gpu.cc"],
  
    tags = [
        "manual",
    ],
    deps = [
        "//xla/pjrt/gpu:gpu_topology",
        "//xla/pjrt/gpu:se_gpu_pjrt_client",
        "//xla:literal",
        "//xla:literal_util",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:xla_data_proto_cc",
        "//xla/client:xla_computation",
        "//xla/ffi",
        "//xla/ffi:ffi_api",
        "//xla/pjrt:host_memory_spaces",
        "//xla/pjrt:pjrt_client",
        "//xla/pjrt:pjrt_executable",
        "//xla/pjrt:pjrt_future",
        "//xla/pjrt:pjrt_stream_executor_client",
        "//xla/pjrt/distributed:in_memory_key_value_store",
        "//xla/service:gpu_plugin",
        "//xla/service:hlo_parser",
        "//xla/service:platform_util",
        "//xla/stream_executor",
        "//xla/stream_executor:executor_cache",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/time",
        "@tsl//tsl/platform:env",
        "@tsl//tsl/platform:errors",
        "@tsl//tsl/platform:status",
        "@tsl//tsl/platform:statusor",
        "//xla/tools:hlo_module_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:platform_port",
    ],
)

Build Commands

./configure.py --backend=CUDA --host_compiler=gcc --nccl

bazel build --verbose_failures xla/jax_pjrt_demo:main_gpu

It runs into the following linking error:

  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc @bazel-out/k8-opt/bin/xla/jax_pjrt_demo/main_gpu-2.params)
# Configuration: 2e0f9574b87d96e39f67c831c7dc6a1546625e87b8c3e3ca3e42b8605ed3a5c8
# Execution platform: @local_execution_config_platform//:platform
/usr/bin/ld: bazel-out/k8-opt/bin/xla/stream_executor/cuda/libcuda_platform_cuda_only.lo(cuda_platform.o): in function `stream_executor::gpu::CudaPlatform::~CudaPlatform()':
cuda_platform.cc:(.text._ZN15stream_executor3gpu12CudaPlatformD2Ev+0x21): undefined reference to `stream_executor::ExecutorCache::~ExecutorCache()'
/usr/bin/ld: bazel-out/k8-opt/bin/xla/stream_executor/cuda/libcuda_platform_cuda_only.lo(cuda_platform.o): in function `stream_executor::gpu::CudaPlatform::GetExecutor(stream_executor::StreamExecutorConfig const&)':
cuda_platform.cc:(.text._ZN15stream_executor3gpu12CudaPlatform11GetExecutorERKNS_20StreamExecutorConfigE+0x30): undefined reference to `stream_executor::ExecutorCache::Get(stream_executor::StreamExecutorConfig const&)'
/usr/bin/ld: cuda_platform.cc:(.text._ZN15stream_executor3gpu12CudaPlatform11GetExecutorERKNS_20StreamExecutorConfigE+0x79): undefined reference to `stream_executor::ExecutorCache::GetOrCreate(stream_executor::StreamExecutorConfig const&, std::function<absl::lts_20230802::StatusOr<std::unique_ptr<stream_executor::StreamExecutor, std::default_delete<stream_executor::StreamExecutor> > > ()> const&)'
/usr/bin/ld: bazel-out/k8-opt/bin/xla/stream_executor/cuda/libcuda_platform_cuda_only.lo(cuda_platform.o): in function `stream_executor::gpu::CudaPlatform::CudaPlatform()':
cuda_platform.cc:(.text._ZN15stream_executor3gpu12CudaPlatformC2Ev+0x43): undefined reference to `stream_executor::ExecutorCache::ExecutorCache()'
collect2: error: ld returned 1 exit status
Target //xla/jax_pjrt_demo:main_gpu failed to build
