How to use the PJRT C API? #7038
So I was able to get something running by following this example and building my whole project with bazel. My directory contains WORKSPACE, BUILD, execute_hlo.cpp, and hlo_comp_text.txt files, plus the xla repository cloned under ./xla. The BUILD file defines:

cc_binary(
    name = "execute_hlo",
    srcs = ["execute_hlo.cpp"],
    deps = [
        "@xla//xla:literal",
        "@xla//xla:literal_util",
        "@xla//xla:shape_util",
        "@xla//xla:status",
        "@xla//xla:statusor",
        "@xla//xla/pjrt:pjrt_client",
        "@xla//xla/pjrt:tfrt_cpu_pjrt_client",
        "@xla//xla/service:hlo_proto_cc",
        "@xla//xla/tools:hlo_module_loader",
    ],
)
The WORKSPACE file pulls in XLA and its transitive workspace rules:

local_repository(
    name = "xla",
    path = "./xla",
)

load("@xla//:workspace4.bzl", "xla_workspace4")
xla_workspace4()

load("@xla//:workspace3.bzl", "xla_workspace3")
xla_workspace3()

load("@xla//:workspace2.bzl", "xla_workspace2")
xla_workspace2()

load("@xla//:workspace1.bzl", "xla_workspace1")
xla_workspace1()

load("@xla//:workspace0.bzl", "xla_workspace0")
xla_workspace0()
And execute_hlo.cpp loads the HLO module from text, compiles it with a TFRT CPU client, and runs it in a loop:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/tools/hlo_module_loader.h"

int main(int argc, char** argv) {
  // Load the HloModule from file.
  std::string hlo_filename = "hlo_comp_text.txt";
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };
  std::unique_ptr<xla::HloModule> test_module =
      xla::LoadModuleFromFile(hlo_filename,
                              xla::hlo_module_loader_details::Config(), "txt",
                              config_modifier_hook)
          .value();
  const xla::HloModuleProto test_module_proto = test_module->ToProto();

  // Run it using the JAX C++ runtime (PJRT). Get a CPU client.
  std::unique_ptr<xla::PjRtClient> client =
      xla::GetTfrtCpuClient(/*asynchronous=*/true).value();

  // Compile the XlaComputation to a PjRtLoadedExecutable.
  xla::XlaComputation xla_computation(test_module_proto);
  xla::CompileOptions compile_options;
  std::unique_ptr<xla::PjRtLoadedExecutable> executable =
      client->Compile(xla_computation, compile_options).value();

  // Input.
  xla::Literal literal_x;
  std::unique_ptr<xla::PjRtBuffer> param_x;
  // Output.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results;
  std::shared_ptr<xla::Literal> result_literal;
  xla::ExecuteOptions execute_options;

  for (int i = 0; i < 1000; i++) {
    // Upload a fresh scalar input to the first addressable device...
    literal_x = xla::LiteralUtil::CreateR2<float>({{0.1f * (float)i}});
    param_x = client->BufferFromHostLiteral(literal_x,
                                            client->addressable_devices()[0])
                  .value();
    // ...execute, and copy the result back to the host.
    results = executable->Execute({{param_x.get()}}, execute_options).value();
    result_literal = results[0][0]->ToLiteralSync().value();
    std::cout << "Result " << i << " = " << *result_literal << "\n";
  }
  return 0;
}
Now I'm trying to figure out how I can accomplish the same thing, but with cmake instead of bazel. Ideally I'd be able to build all the dependency targets with bazel, then link them into my cmake build of execute_hlo.cpp.
I don't think we support the cmake build system yet, sadly.
The integration guide is for people implementing the PJRT C API, e.g. if you're a hardware vendor and want to create a plugin that can be used by JAX or PyTorch/XLA. It sounds like you would like to load and use an existing implementation? It also sounds like you're trying to run on CPU (vs. GPU, TPU, or some other platform). We unfortunately do not have a prepackaged CPU PJRT plugin. We do have the beginnings of a bazel rule for building one, as you've found (lines 143 to 152 at commit 875b014).
Your execute_hlo.cpp binary can dynamically link or dlopen the plugin and call GetPjRtApi. Then execute_hlo.cpp can be built separately however you want, e.g. with cmake. Like @cheshire says, we're not currently planning to support a cmake build of the CPU plugin implementation itself.
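As an illustration of that approach, here is a minimal sketch of dlopen-ing a plugin and grabbing the API table. The plugin path is hypothetical, and the exported symbol is assumed to be GetPjrtApi as declared in xla/pjrt/c/pjrt_c_api.h:

#include <dlfcn.h>

#include <cstdio>

#include "xla/pjrt/c/pjrt_c_api.h"  // defines PJRT_Api

int main() {
  // Hypothetical plugin path; build the .so with the bazel rule mentioned above.
  void* handle = dlopen("./pjrt_c_api_cpu_plugin.so", RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // A PJRT plugin exports a single entry point returning the C function table.
  using GetPjrtApiFn = const PJRT_Api* (*)();
  auto get_api = reinterpret_cast<GetPjrtApiFn>(dlsym(handle, "GetPjrtApi"));
  if (get_api == nullptr) {
    std::fprintf(stderr, "GetPjrtApi not found: %s\n", dlerror());
    return 1;
  }
  const PJRT_Api* api = get_api();
  std::printf("PJRT C API version %d.%d\n",
              api->pjrt_api_version.major_version,
              api->pjrt_api_version.minor_version);
  dlclose(handle);
  return 0;
}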
We are working on making a fully-packaged GPU plugin, including building the shared library (see lines 194 to 215 at commit 875b014).
You should be able to use this as a starting point for how to build a CPU shared library. If you end up doing this, consider upstreaming a PR :) Otherwise we'll eventually create a packaged CPU plugin you can use (via dynamic linking), but we don't have a timeline for this yet. Alternatively, you could build your entire application with bazel and let it statically link everything together instead of using a separate plugin shared library. Hope this helps! I'm out for the next week, but let me know if you have more questions on any of this and I can try to help when I'm back.
Ah, that makes sense. Yeah, I started off on CPU to reduce complexity, but I plan to move to GPU eventually. I'll take a look at the GPU shared library target, and when I have some time I'll look at building a CPU shared library. Thanks for the pointers :) Also, for anyone who comes across this and is in a similar position to the one I was in, I put my work into this repo.
@jyingl3 do you have a bazel command handy for building the CUDA PJRT .so file? I'm guessing something like:
Same here. I want to create a package of the PJRT library and its headers to use in a C++ project. Do you have any updates, or can you give me any help putting together the library/headers package? It's too difficult to integrate into my cmake-based C++ project...
To use the C API, you only need the pjrt_c_api.h header and the plugin shared library (.so). For C++, as far as I understand, it's more complicated. It's possible to create a PjRtCApiClient from the C API. You can see an example of how the PjRT API is used in the LeelaChessZero project. There, they implemented their own minimal C++ wrappers over the C API. There are also instructions on how to build the PjRT plugin.
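For the C-only route, client creation looks roughly like the following. This is a sketch from my reading of pjrt_c_api.h, not authoritative; `api` is assumed to come from GetPjrtApi as shown earlier in the thread:

#include <cstring>

#include "xla/pjrt/c/pjrt_c_api.h"

// Hedged sketch: create a client through the raw C API function table.
PJRT_Client* CreateClient(const PJRT_Api* api) {
  PJRT_Client_Create_Args args;
  std::memset(&args, 0, sizeof(args));  // zero the optional fields (options, callbacks)
  args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
  PJRT_Error* error = api->PJRT_Client_Create(&args);
  if (error != nullptr) {
    // Errors are owned by the caller and must be freed via PJRT_Error_Destroy.
    PJRT_Error_Destroy_Args destroy;
    std::memset(&destroy, 0, sizeof(destroy));
    destroy.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
    destroy.error = error;
    api->PJRT_Error_Destroy(&destroy);
    return nullptr;
  }
  return args.client;  // out-param filled on success
}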
@mooskagh Thank you so much for your guide. I built pjrt_c_api_gpu_plugin.so and can access the PJRT API through pjrt_c_api.h. Now I have one more question, about using an HLO module. The LeelaChessZero project converted an onnx model to an HLO module; I instead used flax to generate the model, and the HLO text was obtained through xla_computation. Unlike the lc0 project, I tried to load the HLO module from my HLO file, but no compatible APIs are provided by pjrt_c_api.h. Do I need hlo_module_loader, to use LoadModuleFromFile and the other functions shown in the example above?
Please give me some guidance on this. Thank you.
How do you export the module from flax?
@mooskagh Thank you for your help. I produced a proto object with jax.xla_computation(fn).as_serialized_hlo_module_proto(), then saved it as 'module.pb'. I then tried to compile the saved module with 'PJRT_Error* error = api_->PJRT_Client_Compile(&args);' but got 'failed to deserialize HloModuleProto'. My module is structured as follows:
How can I solve the deserialization failure? Or how can I load my HLO into PJRT? Thank you.
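For what it's worth, here is a hedged sketch of what the compile call expects, per my reading of the pjrt_c_api.h structs. The helper and its names are illustrative: hlo_bytes must hold the raw contents of module.pb (not the filename), and compile_options_bytes a serialized xla::CompileOptionsProto:

#include <cstring>
#include <string>

#include "xla/pjrt/c/pjrt_c_api.h"

// Sketch: compile a serialized HloModuleProto with an existing PJRT_Client.
PJRT_LoadedExecutable* CompileHlo(const PJRT_Api* api, PJRT_Client* client,
                                  std::string hlo_bytes,
                                  std::string compile_options_bytes) {
  PJRT_Program program;
  std::memset(&program, 0, sizeof(program));
  program.struct_size = PJRT_Program_STRUCT_SIZE;
  program.code = hlo_bytes.data();
  program.code_size = hlo_bytes.size();
  program.format = "hlo";  // "hlo" = serialized HloModuleProto; "mlir" is the other format
  program.format_size = 3;

  PJRT_Client_Compile_Args args;
  std::memset(&args, 0, sizeof(args));
  args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
  args.client = client;
  args.program = &program;
  args.compile_options = compile_options_bytes.data();
  args.compile_options_size = compile_options_bytes.size();

  PJRT_Error* error = api->PJRT_Client_Compile(&args);
  if (error != nullptr) {
    PJRT_Error_Destroy_Args destroy;
    std::memset(&destroy, 0, sizeof(destroy));
    destroy.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
    destroy.error = error;
    api->PJRT_Error_Destroy(&destroy);
    return nullptr;
  }
  return args.executable;
}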
Everything looks correct, so it should work. Maybe you saved the module in text mode (although it shouldn't matter on Linux), or didn't fill in its size in the PJRT_Program struct. You can check whether the protobuf file has the correct format by trying to convert it to a text dump by running:
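(The exact command wasn't captured above; presumably something along these lines, with module.pb standing in for your file — --decode_raw needs no schema:)

protoc --decode_raw < module.pb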
UPD: Another guess is that you're passing a filename to the PjRT function rather than its contents. You have to load the file into a buffer yourself and then pass the buffer to PjRT.
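Something like this minimal sketch (the path is illustrative):

#include <fstream>
#include <sstream>
#include <string>

// Read the whole serialized proto into memory: PJRT wants the bytes, not a path.
std::string ReadFileToString(const std::string& path) {
  std::ifstream file(path, std::ios::binary);
  std::ostringstream contents;
  contents << file.rdbuf();
  return contents.str();
}

// Usage: std::string hlo_bytes = ReadFileToString("module.pb");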
@mooskagh Thank you so much. I fixed it to pass the file contents rather than the filename, and checked with protoc decoding that the text dump is shown. In addition, I can call the APIs and have checked the version info and so on. Then, if I parse the hlo_comp data by using ... In addition, would you please give me a small example or procedure for setting up the HLO computation and the input/output handling? I attach the hlo_comp.pb file just in case. Thank you.
@mooskagh With your comments above, after some trials, I could run my HLO module via the execute function and got a return value of type std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>. If I get the result PjRtBuffer with a tuple type such as 'f32[1] {-1}', how can I extract the value -1 from the buffer? The result from the execute function above is extracted as shown in 'se_gpu_pjrt_client_test.cc'.
Thank you always.
@Rian-Jo did you end up using the PJRT C++ or C API? If C++, then you can use PjRtBuffer::ToLiteralSync, as in the example near the top of this thread. If you're using the C API directly, use PJRT_Buffer_ToHostBuffer. Let me know which one you're doing and I can provide more detail if needed.
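In case it helps, a hedged sketch of the C API path, using the two-call pattern I believe PJRT_Buffer_ToHostBuffer follows (first call with dst == nullptr to query the size, second call to copy); field names are per my reading of pjrt_c_api.h, and error handling is elided:

#include <cstring>
#include <vector>

#include "xla/pjrt/c/pjrt_c_api.h"

std::vector<char> BufferToHost(const PJRT_Api* api, PJRT_Buffer* buffer) {
  PJRT_Buffer_ToHostBuffer_Args args;
  std::memset(&args, 0, sizeof(args));
  args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE;
  args.src = buffer;
  args.dst = nullptr;  // first call: ask for the required size in args.dst_size
  api->PJRT_Buffer_ToHostBuffer(&args);

  std::vector<char> host(args.dst_size);
  args.dst = host.data();  // second call: perform the copy
  api->PJRT_Buffer_ToHostBuffer(&args);

  // The copy completes asynchronously; block on the returned event.
  PJRT_Event_Await_Args await_args;
  std::memset(&await_args, 0, sizeof(await_args));
  await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
  await_args.event = args.event;
  api->PJRT_Event_Await(&await_args);

  PJRT_Event_Destroy_Args destroy_args;
  std::memset(&destroy_args, 0, sizeof(destroy_args));
  destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
  destroy_args.event = args.event;
  api->PJRT_Event_Destroy(&destroy_args);

  // For a result like f32[1] {-1}, reinterpret host.data() as const float*
  // to read back the single value -1.
  return host;
}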
@skye Thank you for your reply. In the beginning, I had trouble using the XLA functions and libraries, so I created the PJRT C API plugin with some changes from the below.
Then I tried to use the C API based on the LeelaChessZero project, but using the input/output buffers is still a problem. Currently, I have checked executing HLO with a combination of the PJRT library in tensorflow/xla, but I still need and want to use only the C API because of hardware implementation issues. Now I have been trying to use PJRT_Client_BufferFromHostBuffer and PJRT_Buffer_ToHostBuffer for input and output. If you can provide the details, I'd appreciate your advice. Thank you.
@skye In the lc0 project below, they mentioned "Only one input is kinda supported". Is that true? Thank you.
No, PjRT supports multiple arguments; it's just that Lc0 only needs to pass one, so their wrapper doesn't support more. The Lc0 XLA wrapper is quite specific to that project, so I'm not sure it makes sense to look there much. Actually, Lc0 does call PjRT with many arguments; it's just that most of them are weights that don't change from call to call, and one tensor is the real input that is different for every call (i.e. a single element of the argument list).
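To illustrate with the C++ API from the example earlier in the thread, passing multiple arguments is just more entries in the inner argument list (param_a and param_b are hypothetical buffers, created the same way as param_x above):

// One inner vector per addressable device; each holds the executable's
// parameters in order. Here: a single device with two parameters.
std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
    executable->Execute({{param_a.get(), param_b.get()}}, execute_options)
        .value();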
I was trying to reproduce @hayden-donnelly's example by creating the same file structure (WORKSPACE, BUILD, execute_hlo.cpp, and hlo_comp_text.txt files, and the xla repository cloned under ./xla). After doing that, I hit errors.
What do you use to compile a bazel project that depends on XLA? Maybe there's something missing that I have to add to a bazelrc? Thanks!! P.S.: If I create a much longer WORKSPACE file, using ...
I'm running into the same issue as @janpfeifer, even when loading ...
hi @joaospinto, I'm just finishing gopjrt, a Go wrapper for PJRT and XlaBuilder. For XlaBuilder it took me some effort, but I think I managed to build the minimal BUILD file needed to link XlaBuilder (and related libraries) so as to be able to build an HLO program to use with PJRT. See the WORKSPACE and minimal BUILD files under the ... For PJRT I just copied over the pjrt_c_api.h file and didn't directly statically link anything; I just used dlopen. (edit) I don't recall exactly, but I think the issue went away once I updated my WORKSPACE with more of the XLA internal rules, in particular the following (taken from ...).
Bazel doesn't make these things easy 😞
Thanks @janpfeifer, adding all that did work. I ended up using essentially what the JAX repository is doing to depend on XLA, which is very similar to what you wrote.
Created a simple repository going over this, to save other folks from fighting with build systems in the future: https://github.com/joaospinto/call_jax_from_cpp/blob/main/cpp/main_hlo.cpp Planning on also adding an example where an AOT-compiled binary is used instead.
Created a GitHub issue on the XLA repo about this: #15024
hi @joaospinto, btw, let us know if you manage to get AOT compilation to work (for CPU, I'm assuming?). I'd love to add support for that as well.
I'm trying to use the PJRT C API so I can execute an HLO module, but I don't understand the instructions. The integration guide says I can either implement the C API or implement the C++ API and use it to wrap the C API. The former seems less complicated, so that's what I'm trying to do. So far I've built the pjrt_c_api_cpu library with

bazel build --strip=never //xla/pjrt/c:pjrt_c_api_cpu

and then linked it with the following cmake file:

...

After that I created execute_hlo.cpp to load my HLO module into a PJRT_Program:

...

This works, but I think I've hit a dead end, since I don't see any functions that will let me do anything with this struct. After looking around a bit, it seems like I'm supposed to use some implementation of GetPjrtApi() that returns a PJRT_Api pointer, but I only see C++ implementations. Maybe that means the C -> C++ wrapper is the simpler approach? If it is, which targets would I have to build to use it? Or if it's not, what am I doing wrong with my current approach?