How to use the PJRT C API? #7038

Open
hayden-donnelly opened this issue Nov 15, 2023 · 25 comments
Labels
question Further information is requested

Comments

@hayden-donnelly

hayden-donnelly commented Nov 15, 2023

I'm trying to use the PJRT C API so I can execute an HLO module, but I don't understand the instructions. The integration guide says I can either implement the C API or implement the C++ API and use it to wrap the C API. The former seems less complicated, so that's what I'm trying to do. So far I've built the pjrt_c_api_cpu library with bazel build --strip=never //xla/pjrt/c:pjrt_c_api_cpu and then linked it with the following CMake file:

cmake_minimum_required(VERSION 3.10)
project(execute_hlo CXX)

set(PJRT_C_API_LIB xla/bazel-bin/xla/pjrt/c/libpjrt_c_api_cpu.a)
set(XLA_INCLUDE xla)

add_executable(execute_hlo execute_hlo.cpp)
target_link_libraries(execute_hlo "${CMAKE_SOURCE_DIR}/${PJRT_C_API_LIB}")
target_include_directories(execute_hlo PRIVATE "${CMAKE_SOURCE_DIR}/${XLA_INCLUDE}")

set_target_properties(execute_hlo PROPERTIES RUNTIME_OUTPUT_DIRECTORY /build/)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
set(CMAKE_BUILD_TYPE Release)

add_custom_command(TARGET execute_hlo POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE:execute_hlo> ${CMAKE_SOURCE_DIR}/build/)

After that I created execute_hlo.cpp to load my HLO module into a PJRT_Program:

#include <memory>
#include <string>
#include <vector>
#include <fstream>
#include <streambuf>
#include <sstream> 
#include <iostream>
#include "xla/pjrt/c/pjrt_c_api.h"

int main(int argc, char** argv) 
{
    std::ifstream t("hlo_comp_proto.txt");
    std::stringstream buffer;
    buffer << t.rdbuf();
    std::string hlo_code = buffer.str();
    std::string format = "hlo";
    
    PJRT_Program pjrt_program;
    pjrt_program.code = (char*)hlo_code.c_str();
    pjrt_program.code_size = (size_t)hlo_code.size();
    pjrt_program.format = format.c_str();
    pjrt_program.format_size = (size_t)format.size();

    std::cout << "HLO Code:\n\n" << pjrt_program.code << "\n\n";
    std::cout << "Code size: " << pjrt_program.code_size << "\n";
    std::cout << "Format: " << pjrt_program.format << "\n";
    std::cout << "Format size: " << pjrt_program.format_size << "\n";
    return 0;
}

This works, but I think I've hit a dead end since I don't see any functions that will let me do anything with this struct. After looking around a bit it seems like I'm supposed to use some implementation of GetPjrtApi() that returns a PJRT_Api pointer, but I only see C++ implementations. Maybe that means the C -> C++ wrapper is the simpler approach? If it is, which targets would I have to build to use it? Or if it's not, what am I doing wrong with my current approach?

@hayden-donnelly
Author

So I was able to get something running by following this example and building my whole project with bazel.

My directory looks like this:

BUILD
WORKSPACE
execute_hlo.cpp
hlo_comp_text.txt
xla/

Where BUILD is:

cc_binary(
    name = "execute_hlo",
    srcs = ["execute_hlo.cpp"],
    deps = [
        "@xla//xla:literal",
        "@xla//xla:literal_util",
        "@xla//xla:shape_util",
        "@xla//xla:status",
        "@xla//xla:statusor",
        "@xla//xla/pjrt:pjrt_client",
        "@xla//xla/pjrt:tfrt_cpu_pjrt_client",
        "@xla//xla/service:hlo_proto_cc",
        "@xla//xla/tools:hlo_module_loader",
    ],
)

WORKSPACE is:

local_repository(
    name = "xla", 
    path = "./xla",
)

load("@xla//:workspace4.bzl", "xla_workspace4")
xla_workspace4()

load("@xla//:workspace3.bzl", "xla_workspace3")
xla_workspace3()

load("@xla//:workspace2.bzl", "xla_workspace2")
xla_workspace2()

load("@xla//:workspace1.bzl", "xla_workspace1")
xla_workspace1()

load("@xla//:workspace0.bzl", "xla_workspace0")
xla_workspace0()

execute_hlo.cpp is:

#include <iostream>
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/tools/hlo_module_loader.h"

int main(int argc, char** argv) 
{
    // Load HloModule from file.
    std::string hlo_filename = "hlo_comp_text.txt";
    std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
        [](xla::HloModuleConfig* config) { config->set_seed(42); };
    std::unique_ptr<xla::HloModule> test_module = 
        LoadModuleFromFile(
            hlo_filename, xla::hlo_module_loader_details::Config(),
            "txt", config_modifier_hook
        ).value();
    const xla::HloModuleProto test_module_proto = test_module->ToProto();

    // Run it using JAX C++ Runtime (PJRT).

    // Get a CPU client.
    std::unique_ptr<xla::PjRtClient> client = xla::GetTfrtCpuClient(true).value();

    // Compile XlaComputation to PjRtExecutable.
    xla::XlaComputation xla_computation(test_module_proto);
    xla::CompileOptions compile_options;
    std::unique_ptr<xla::PjRtLoadedExecutable> executable =
        client->Compile(xla_computation, compile_options).value();

    // Input.
    xla::Literal literal_x;
    std::unique_ptr<xla::PjRtBuffer> param_x;
    // Output.
    std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results;
    std::shared_ptr<xla::Literal> result_literal;
    xla::ExecuteOptions execute_options;

    for(int i = 0; i < 1000; i++)
    {
        literal_x = xla::LiteralUtil::CreateR2<float>({{0.1f * (float)i}});
        param_x = client->BufferFromHostLiteral(
            literal_x, client->addressable_devices()[0]
        ).value();

        results = executable->Execute({{param_x.get()}}, execute_options).value();
        result_literal = results[0][0]->ToLiteralSync().value();
        std::cout << "Result " << i << " = " << *result_literal << "\n";
    }
    
    return 0;
}

hlo_comp_text.txt is:

HloModule xla_computation_target_fn, entry_computation_layout={(f32[1]{0})->(f32[1]{0})}

ENTRY main.4 {
  Arg_0.1 = f32[1]{0} parameter(0)
  multiply.2 = f32[1]{0} multiply(Arg_0.1, Arg_0.1)
  ROOT tuple.3 = (f32[1]{0}) tuple(multiply.2)
}

And xla/ is a clone of this repo.

Now I'm trying to figure out how I can accomplish the same thing, but with cmake instead of bazel. Ideally I'd be able to build all the dependency targets with bazel, then link them to my build of execute_hlo with cmake, but the includes were all messed up when I tried that.

@cheshire
Contributor

@hawkinsp @skye any pointers?

@cheshire
Contributor

but with cmake instead of bazel

I don't think we support the cmake build system yet, sadly.

@skye
Contributor

skye commented Nov 17, 2023

The integration guide is for people implementing the PJRT C API, e.g. if you're a hardware vendor and want to create a plugin that can be used by JAX or PyTorch:XLA. It sounds like you would like to load + use an existing implementation? It also sounds like you're trying to run on CPU (vs. GPU, TPU, or some other platform).

We unfortunately do not have a prepackaged CPU PJRT plugin. We do have the beginnings of a bazel rule for building one, as you've found:

xla/xla/pjrt/c/BUILD

Lines 143 to 152 in 875b014

cc_library(
    name = "pjrt_c_api_cpu",
    srcs = ["pjrt_c_api_cpu.cc"],
    hdrs = ["pjrt_c_api_cpu.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":pjrt_c_api_cpu_internal",
        ":pjrt_c_api_hdrs",
    ],
)
You would have to use bazel to build a .so shared library containing the plugin implementation, and then your execute_hlo.cpp binary can dynamically link or dlopen the plugin and call GetPjrtApi. Then execute_hlo.cpp can be built separately however you want, e.g. with cmake. Like @cheshire says, we're not currently planning to support a cmake build of the CPU plugin implementation itself.
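
For reference, here is a minimal sketch of that dlopen path, assuming the plugin exports the GetPjrtApi symbol declared in pjrt_c_api.h; the plugin filename and path are placeholders, and error handling is kept to a bare minimum (link with -ldl on older toolchains):

#include <dlfcn.h>
#include <iostream>
#include "xla/pjrt/c/pjrt_c_api.h"

int main() {
    // Load the plugin shared library built with bazel (path is illustrative).
    void* handle = dlopen("./pjrt_c_api_cpu_plugin.so", RTLD_NOW | RTLD_LOCAL);
    if (handle == nullptr) {
        std::cerr << "dlopen failed: " << dlerror() << "\n";
        return 1;
    }
    // The plugin exports one entry point that returns the PJRT function table.
    typedef const PJRT_Api* (*GetPjrtApiFn)();
    auto get_api = reinterpret_cast<GetPjrtApiFn>(dlsym(handle, "GetPjrtApi"));
    if (get_api == nullptr) {
        std::cerr << "dlsym failed: " << dlerror() << "\n";
        return 1;
    }
    const PJRT_Api* api = get_api();
    std::cout << "Loaded PJRT_Api, struct_size = " << api->struct_size << "\n";
    return 0;
}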

We are working on making a fully-packaged GPU plugin, including building the shared library.

xla/xla/pjrt/c/BUILD

Lines 194 to 215 in 875b014

# PJRT GPU plugin. Can be configured to be built for CUDA or ROCM.
xla_cc_binary(
    name = "pjrt_c_api_gpu_plugin.so",
    linkopts = [
        "-Wl,--version-script,$(location :pjrt_c_api_gpu_version_script.lds)",
        "-Wl,--no-undefined",
    ],
    linkshared = True,
    tags = [
        "noasan",
        "nomac",  # No GPU on mac.
        "nomsan",
        "notsan",
    ],
    deps = [
        ":pjrt_c_api_gpu",
        ":pjrt_c_api_gpu_version_script.lds",
        "//xla/service:gpu_plugin",
    ] + if_cuda_is_configured([
        "//xla/stream_executor:cuda_platform",
    ]),
)

You should be able to use this as a starting point for how to build a CPU shared library. If you end up doing this, consider upstreaming a PR :) Otherwise we'll eventually create a packaged CPU plugin you can use (via dynamic linking), but we don't have a timeline for this yet.

Alternatively, you could build your entire application with bazel, and let it statically link everything together instead of using a separate plugin shared library.

Hope this helps! I'm out for the next week, but let me know if you have more questions on any of this and I can try to help when I'm back.

@hayden-donnelly
Author

The integration guide is for people implementing the PJRT C API, e.g. if you're a hardware vendor and want to create a plugin that can be used by JAX or PyTorch:XLA.

Ah that makes sense.

Yeah I started off on CPU to reduce complexity, but I plan to move to GPU eventually. I'll take a look at the GPU shared library target, and when I have some time I'll look at building a CPU shared library. Thanks for the pointers :)

Also for anyone who comes across this and is in a similar position that I was, I put my work into this repo.

@skye
Contributor

skye commented Nov 28, 2023

@jyingl3 do you have a bazel command handy for building the CUDA PJRT .so file? I'm guessing something like:
bazel build -c opt --config=cuda //xla/pjrt/c:pjrt_c_api_gpu_plugin.so, and then the resulting .so file should be in bazel-bin/xla/pjrt/c

@penpornk added the question (Further information is requested) label on Feb 29, 2024
@Rian-Jo

Rian-Jo commented Apr 2, 2024

Same here. I want to create a package of the PJRT library and its headers to use in a C++ project. Do you have any updates, or can you give me any help putting together the library/headers package? It's too difficult to integrate into my CMake-based C++ project...

@mooskagh
Member

mooskagh commented Apr 2, 2024

To use the C API, you only need the pjrt_c_api.h header and the pjrt_c_api_gpu_plugin.so file, so it's possible to use it in a project that doesn't use bazel.

For C++, as far as I understand, it's more complicated. It's possible to create a PjRtCApiClient from the PJRT_Api received from the PjRT plugin, but that header has lots of dependencies inside XLA, so without including the entire XLA repository as a bazel subproject it's not very easy. It's doable though: the elixir-nx project managed to extract all the required headers and has pre-built libraries.

You can see an example of how the PjRT API is used in the LeelaChessZero project. There, they implemented their own minimal C++ wrappers over the C API. There are also instructions on how to build the PjRT plugin.
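
As a rough sketch of the pure C API route (struct and function names as in pjrt_c_api.h; zero-initialization leaves the optional create options empty, and error handling is omitted), creating a client from the loaded PJRT_Api might look something like this:

#include "xla/pjrt/c/pjrt_c_api.h"

// `api` is the PJRT_Api* returned by the plugin's GetPjrtApi().
PJRT_Client* CreateClient(const PJRT_Api* api) {
    PJRT_Client_Create_Args args{};  // zero-init: no named create options
    args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
    PJRT_Error* error = api->PJRT_Client_Create(&args);
    if (error != nullptr) {
        // A real program would fetch the message with PJRT_Error_Message
        // and release the error with PJRT_Error_Destroy.
        return nullptr;
    }
    return args.client;
}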

@Rian-Jo

Rian-Jo commented Apr 4, 2024

@mooskagh Thank you so much for your guidance. I built pjrt_c_api_gpu_plugin.so and can access the PJRT API through pjrt_c_api.h.

Then I have one more question about using an HLO module. The LeelaChessZero project converts an ONNX model to an HLO module, but I used Flax to generate my model, and the HLO text was obtained through xla_computation. Unlike the lc0 project, I tried to load the HLO module from my HLO file, but pjrt_c_api.h provides no compatible APIs for that. Do I need hlo_module_loader in order to use LoadModuleFromFile and the other functions shown in

// Load HloModule from file.
    std::string hlo_filename = "hlo_comp_text.txt";
    std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
        [](xla::HloModuleConfig* config) { config->set_seed(42); };
    std::unique_ptr<xla::HloModule> test_module = 
        LoadModuleFromFile(
            hlo_filename, xla::hlo_module_loader_details::Config(),
            "txt", config_modifier_hook
        ).value();
    const xla::HloModuleProto test_module_proto = test_module->ToProto();

Please give me some guidance on this.

Thank you.

@mooskagh
Member

mooskagh commented Apr 4, 2024

@mooskagh Thank you so much for your guidance. I built pjrt_c_api_gpu_plugin.so and can access the PJRT API through pjrt_c_api.h.

Then I have one more question about using an HLO module. The LeelaChessZero project converts an ONNX model to an HLO module, but I used Flax to generate my model, and the HLO text was obtained through xla_computation.

How do you export the module from flax?
If you have jax.xla_computation, just call as_serialized_hlo_module_proto() instead of as_hlo_text(), and it will be in the format that you can pass to PjRT.

@Rian-Jo

Rian-Jo commented Apr 5, 2024

@mooskagh Thank you for your help. I produced the proto object with jax.xla_computation(fn).as_serialized_hlo_module_proto(), then saved it as 'module.pb'. I then tried to compile the saved module with 'PJRT_Error* error = api_->PJRT_Client_Compile(&args);' but got 'failed to deserialize HloModuleProto'. My module is structured as follows:

HloModule xla_computation_policy, entry_computation_layout={(f32[36]{0}, f32[2]{0})->(f32[1]{0})}

silu.3 {
  Arg_0.4 = f32[32]{0} parameter(0)
  constant.5 = f32[] constant(1)
  broadcast.6 = f32[32]{0} broadcast(constant.5), dimensions={}
  negate.7 = f32[32]{0} negate(Arg_0.4)
  exponential.8 = f32[32]{0} exponential(negate.7)
  add.9 = f32[32]{0} add(exponential.8, broadcast.6)
  divide.10 = f32[32]{0} divide(broadcast.6, add.9)
  ROOT multiply.11 = f32[32]{0} multiply(Arg_0.4, divide.10)
}
...

ENTRY main.97 {
  Arg_0.1 = f32[36]{0} parameter(0)
  Arg_1.2 = f32[2]{0} parameter(1)
  call.95 = f32[1]{0} call(Arg_0.1, Arg_1.2), to_apply=policy.60
  ROOT tuple.96 = (f32[1]{0}) tuple(call.95)
}

How can I resolve the deserialization failure, or how else can I load my HLO into PJRT?

Thank you.

@mooskagh
Member

mooskagh commented Apr 5, 2024

@mooskagh Thank you for your help. I produced the proto object with jax.xla_computation(fn).as_serialized_hlo_module_proto(), then saved it as 'module.pb'. I then tried to compile the saved module with 'PJRT_Error* error = api_->PJRT_Client_Compile(&args);' but got 'failed to deserialize HloModuleProto'. My module is structured as follows:

Everything looks correct, so it should work. Maybe you saved the module in text mode (although that shouldn't matter on Linux), or didn't fill in its size in the PJRT_Program structure?

You can check whether the protobuf file has the correct format by trying to convert it to a text dump by running:

cat <yourfile.pb> | protoc -I <path_to_xla_repository> <path_to_xla_repository>/xla/service/hlo.proto --decode=xla.HloModuleProto

UPD: Another guess is that you are passing a filename to the PjRT function rather than its contents. You have to load the file into a buffer yourself and then pass the buffer to PjRT.
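
To make the "contents, not the filename" point concrete, here is a hedged sketch of the compile call (field names follow pjrt_c_api.h; whether empty compile options are accepted may vary between XLA versions, so treat this as illustrative rather than definitive):

#include <fstream>
#include <sstream>
#include <string>
#include "xla/pjrt/c/pjrt_c_api.h"

// `api` and `client` obtained as in the earlier sketches.
PJRT_LoadedExecutable* CompileSerializedHlo(const PJRT_Api* api,
                                            PJRT_Client* client,
                                            const std::string& path) {
    // Read the serialized HloModuleProto bytes; PJRT expects the bytes,
    // not the filename itself.
    std::ifstream in(path, std::ios::binary);
    std::stringstream ss;
    ss << in.rdbuf();
    std::string proto_bytes = ss.str();

    std::string format = "hlo";  // "hlo" = serialized HloModuleProto
    PJRT_Program program{};
    program.struct_size = PJRT_Program_STRUCT_SIZE;
    program.code = proto_bytes.data();
    program.code_size = proto_bytes.size();
    program.format = format.c_str();
    program.format_size = format.size();

    PJRT_Client_Compile_Args args{};
    args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
    args.client = client;
    args.program = &program;
    // compile_options holds a serialized xla::CompileOptionsProto; left empty here.
    args.compile_options = nullptr;
    args.compile_options_size = 0;

    PJRT_Error* error = api->PJRT_Client_Compile(&args);
    return error == nullptr ? args.executable : nullptr;
}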

@Rian-Jo

Rian-Jo commented Apr 8, 2024

@mooskagh Thank you so much. I fixed it to pass the file contents instead of the filename, and verified with protoc decoding that the text dump is shown. I can now also call the APIs and check the version info and so on.

Then, if I parse the hlo_comp data using the ParseFromString function, are the PJRT_Program struct and PJRT_Client_Compile_Args filled in automatically, or do I need to set them myself by reading the text? (It is difficult to know the shapes, ids, and layouts...)

In addition, would you please give me a small example or procedure for setting up the HLO computation and the input/output handling? I have attached the hlo_comp.pb file just in case.
hlo_comp.zip

Thank you.

@Rian-Jo

Rian-Jo commented Apr 19, 2024

@mooskagh With your comments above, after some trials, I could run my HLO module through the execute function and got a result of type std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>.

If I get the result PjRtBuffer with a tuple type such as 'f32[1] {-1}', how can I extract the value -1 from the buffer?

The result from the execute function above is extracted as below, as shown in 'se_gpu_pjrt_client_test.cc':

// Given the result of a PjrtExecutable::Execute call (TF-status of vectors of
// vectors), extract the zeroth result from the zeroth device.
absl::StatusOr<std::shared_ptr<xla::Literal>> ExtractSingleResult(
    absl::StatusOr<std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>> &
        result)
{
    TF_RETURN_IF_ERROR(result.status());
    TF_RET_CHECK(result->size() == 1);
    std::vector<std::unique_ptr<xla::PjRtBuffer>> &result_buffers = (*result)[0];
    TF_RET_CHECK(result_buffers.size() == 1);
    auto literal_or = result_buffers[0]->ToLiteralSync();
    if (!literal_or.status().ok()) return literal_or.status();
    return *literal_or;
}

Thank you always.

@skye
Contributor

skye commented Apr 21, 2024

@Rian-Jo did you end up using the PJRT C++ or C API?

If C++, then you can use PjRtBuffer::ToLiteralSync like in the test file.

If you're using the C API directly, use PJRT_Buffer_ToHostBuffer.

Let me know which one you're doing and I can provide more detail if needed.
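
For the C API path, a rough sketch of the usual two-call pattern for PJRT_Buffer_ToHostBuffer (first call with dst == nullptr to learn the required size, second call to copy, then wait on the returned event; names per pjrt_c_api.h, error handling omitted):

#include <vector>
#include "xla/pjrt/c/pjrt_c_api.h"

// Copies a (non-tuple) device buffer back to host memory as raw bytes.
std::vector<char> BufferToHost(const PJRT_Api* api, PJRT_Buffer* buffer) {
    PJRT_Buffer_ToHostBuffer_Args args{};
    args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE;
    args.src = buffer;
    args.dst = nullptr;  // first call only fills in the required dst_size
    api->PJRT_Buffer_ToHostBuffer(&args);

    std::vector<char> host(args.dst_size);
    args.dst = host.data();
    args.dst_size = host.size();
    api->PJRT_Buffer_ToHostBuffer(&args);

    // The copy is asynchronous; block until the returned event completes.
    PJRT_Event_Await_Args await_args{};
    await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
    await_args.event = args.event;
    api->PJRT_Event_Await(&await_args);

    PJRT_Event_Destroy_Args destroy_args{};
    destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
    destroy_args.event = args.event;
    api->PJRT_Event_Destroy(&destroy_args);

    return host;  // reinterpret according to the buffer's element type, e.g. f32
}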

@Rian-Jo

Rian-Jo commented Apr 22, 2024

@skye Thank you for your reply.

In the beginning, I had trouble using the XLA functions and libraries, so I built the PJRT C API plugin with some changes from the target below.

xla/xla/pjrt/c/BUILD

Lines 194 to 215 in 875b014

# PJRT GPU plugin. Can be configured to be built for CUDA or ROCM.
xla_cc_binary(
    name = "pjrt_c_api_gpu_plugin.so",
    linkopts = [
        "-Wl,--version-script,$(location :pjrt_c_api_gpu_version_script.lds)",
        "-Wl,--no-undefined",
    ],
    linkshared = True,
    tags = [
        "noasan",
        "nomac",  # No GPU on mac.
        "nomsan",
        "notsan",
    ],
    deps = [
        ":pjrt_c_api_gpu",
        ":pjrt_c_api_gpu_version_script.lds",
        "//xla/service:gpu_plugin",
    ] + if_cuda_is_configured([
        "//xla/stream_executor:cuda_platform",
    ]),
)

Then I tried to use the C API based on the LeelaChessZero project, but using the input/output buffers is still a problem.

Currently, I have verified executing HLO with a combination of the PJRT libraries in tensorflow/xla, but I still need and want to use only the C API because of hardware implementation issues.

Now I have been trying to use PJRT_Client_BufferFromHostBuffer and PJRT_Buffer_ToHostBuffer for input and output. If you can provide more details, I would appreciate your advice.

Thank you.
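
For the input side, a hedged sketch of PJRT_Client_BufferFromHostBuffer for a 1-D f32 array placed on the first addressable device; the field names and enum values are taken from pjrt_c_api.h but may differ slightly between XLA versions, and error handling is omitted:

#include <cstdint>
#include <vector>
#include "xla/pjrt/c/pjrt_c_api.h"

// Uploads a 1-D float array to the first addressable device and returns the
// resulting device buffer.
PJRT_Buffer* HostToDevice(const PJRT_Api* api, PJRT_Client* client,
                          const std::vector<float>& data) {
    // Pick a device to place the buffer on.
    PJRT_Client_AddressableDevices_Args dev_args{};
    dev_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE;
    dev_args.client = client;
    api->PJRT_Client_AddressableDevices(&dev_args);
    PJRT_Device* device = dev_args.addressable_devices[0];

    const int64_t dims[] = {static_cast<int64_t>(data.size())};

    PJRT_Client_BufferFromHostBuffer_Args args{};
    args.struct_size = PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE;
    args.client = client;
    args.data = data.data();
    args.type = PJRT_Buffer_Type_F32;
    args.dims = dims;
    args.num_dims = 1;
    args.byte_strides = nullptr;  // dense, row-major layout
    args.num_byte_strides = 0;
    args.host_buffer_semantics =
        PJRT_HostBufferSemantics_kImmutableOnlyDuringCall;
    args.device = device;
    api->PJRT_Client_BufferFromHostBuffer(&args);

    // Wait until PJRT is done reading the host memory before reusing it.
    PJRT_Event_Await_Args await_args{};
    await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
    await_args.event = args.done_with_host_buffer;
    api->PJRT_Event_Await(&await_args);

    PJRT_Event_Destroy_Args destroy_args{};
    destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
    destroy_args.event = args.done_with_host_buffer;
    api->PJRT_Event_Destroy(&destroy_args);

    return args.buffer;
}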

@Rian-Jo

Rian-Jo commented Apr 25, 2024

@skye in lc0 project below,

https://github.com/LeelaChessZero/lc0/blob/7fce117e4ef8f1c7f4e94838cd11db25769e3fb9/src/neural/xla/xla_runner.cc#L164-L168

they mentioned "Only one input is kinda supported". Is it true?

Thank you.

@mooskagh
Member

@skye in lc0 project below,

LeelaChessZero/lc0@7fce117/src/neural/xla/xla_runner.cc#L164-L168

they mentioned "Only one input is kinda supported". Is it true?

No, PjRT supports multiple arguments; it's just that Lc0 only needs to pass one, so their wrapper doesn't support more. The Lc0 XLA wrapper is quite specific to the project, so I'm not sure it makes sense to look there much. The pjrt.h wrapper is quite isolated though and may be a good example of PjRT C API usage.

Actually Lc0 does call PjRT with many arguments; it's just that most of them are weights that don't change from call to call, and one tensor is the real input that is different for every call (i.e. a single element of the input_buffers vector is updated here before the call).

@janpfeifer
Contributor

I was trying to reproduce @hayden-donnelly's example by creating the same file structure (WORKSPACE, BUILD, execute_hlo.cpp, and hlo_comp_text.txt files, with the xla repository cloned under xla/), but it consistently fails to build with random Bazel issues.

After doing bazel build --experimental_repo_remote_exec :execute_hlo, I'm getting:

ERROR: /home/janpf/Projects/gopjrt/c/BUILD:1:10: While resolving toolchains for target //:execute_hlo (5c86453): invalid registered toolchain '@local_config_python//:py_toolchain': error loading package '@@local_config_python//': Unable to find package for @@python//:defs.bzl: The repository '@@python' could not be resolved: Repository '@@python' is not defined.
ERROR: Analysis of target '//:execute_hlo' failed; build aborted

What do you use to compile a bazel project that depends on XLA? Maybe there is something missing that I have to add to a bazelrc?

thanks!!

P.S.: If I create a much longer WORKSPACE file, using http_archive pinned to a fixed version (as opposed to manually cloning the repo locally and using local_repository), it works.

@joaospinto

I'm running into the same issue as @janpfeifer, even when loading xla as an http_archive.

@janpfeifer
Contributor

janpfeifer commented Jul 2, 2024

Hi @joaospinto, I'm just finishing gopjrt, a Go wrapper for PJRT and XlaBuilder.

For XlaBuilder it took me some effort, but I think I managed to build the minimal BUILD file needed to link XlaBuilder (and related libraries) to be able to build an HLO program to use with PJRT. See the WORKSPACE and minimal BUILD files under the c/ and c/gomlx/xlabuilder sub-directories. For my project I created a minimal C-to-C++ wrapper, needed for Go, but you can use it for linking some other C/C++ program.

For PJRT I just copied over the pjrt_c_api.h file and didn't statically link anything; I just loaded the plugins directly (dlopen, etc.). It took some work figuring out how it works, but let me know if you have any questions on that.

(edit)

I don't recall exactly, but I think the issue went away once I updated my WORKSPACE with more of the XLA internal rules. In particular the following (taken from gopjrt's WORKSPACE file):

...
# Initialize hermetic Python
load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")

python_init_rules()

load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")

python_init_repositories(
    requirements = {
        "3.11": "//:requirements_lock_3_11.txt",
    },
)

load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")

python_init_toolchains()

load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip")

python_init_pip()

load("@pypi//:requirements.bzl", "install_deps")

install_deps()
...

Bazel doesn't make these things easy 😞

@joaospinto

Thanks @janpfeifer, adding all that did work. I ended up using essentially what the JAX repository is doing to depend on XLA, which is very similar to what you wrote.

@joaospinto

Created a simple repository going over this, to save other folks from fighting with build systems in the future: https://github.com/joaospinto/call_jax_from_cpp/blob/main/cpp/main_hlo.cpp

Planning on also adding an example where an AOT-compiled binary is used instead.

@joaospinto

(quoting @janpfeifer's reply above)

Created a GitHub Issue on the XLA repo about this: #15024

@janpfeifer
Contributor

Hi @joaospinto, by the way, let us know if you manage to get AOT compilation to work (for CPU, I'm assuming?). I'd love to add support for that as well.
