From 098019337b01a7427d8393554c42751af9814123 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 4 Sep 2024 21:58:20 +0200 Subject: [PATCH] Automatically fetch and register PjRT Metal plugin --- deps/ReactantExtra/API.cpp | 26 +++++++++++++++++++++++++ src/XLA.jl | 39 +++++++++++++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 31915c71..221e4493 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -183,6 +183,32 @@ extern "C" PjRtClient* MakeTPUClient(const char* tpu_path , const char** error) return GetCApiClient("TPU"); } +const char* const kEnvMetalLibraryPath = "METAL_LIBRARY_PATH"; + +extern "C" PjRtClient* MakeMetalClient(const char* libpath, const char** error) { + // Prefer $METAL_LIBRARY_PATH if set + std::string metal_library_path; + if (auto path = llvm::sys::Process::GetEnv(kEnvMetalLibraryPath)) { + metal_library_path = *path; + } else if (libpath) { + metal_library_path = std::string(libpath); + } else { + *error = "Could not find Metal path"; + return nullptr; + } + + const PJRT_Api* pluginLoad = LoadPjrtPlugin("metal", metal_library_path.c_str(), error); + if (pluginLoad == nullptr) + return nullptr; + + + auto metal_status = InitializePjrtPlugin("metal", error); + if (metal_status) + return nullptr; + + return GetCApiClient("METAL"); +} + extern "C" int ClientNumDevices(PjRtClient* client) { return client->device_count(); } diff --git a/src/XLA.jl b/src/XLA.jl index 684511e6..02f08446 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -85,6 +85,16 @@ function TPUClient(tpu_path::String) return Client(client) end +function MetalClient(libpath::String) + f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeMetalClient") + refstr = Ref{Cstring}() + client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), libpath, refstr) + if client == C_NULL + throw(AssertionError(unsafe_string(refstr[]))) + end + return Client(client) +end + const backends = Dict{String,Client}() const default_backend = Ref{Client}() const default_device_idx = Ref{Int}(0) @@ -100,7 +110,34 @@ function __init__() backends["cpu"] = cpu default_backend[] = cpu - @static if !Sys.isapple() + @static if Sys.isapple() + metaldir = @get_scratch!("pjrt-plugin-metal") + if !isfile(metaldir * "/pjrt_plugin_metal_14.dylib") + Downloads.download( + if Sys.ARCH === :aarch64 + "https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl" + else + "https://files.pythonhosted.org/packages/51/6a/1c0e2d07d92c6583e874ef2bbf4382662a3469bbb661d885eeaaddca426f/jax_metal-0.1.0-py3-none-macosx_10_14_x86_64.whl" + end, + joinpath(metaldir, "pjrt-plugin-metal.zip"), + ) + run(`unzip -qq $(metaldir*"/pjrt-plugin-metal.zip") -d $(metaldir)/tmp`) + run( + `mv $(metaldir)/tmp/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib $(metaldir)/pjrt_plugin_metal_14.dylib`, + ) + rm(metaldir * "/tmp"; recursive=true) + rm(metaldir * "/pjrt-plugin-metal.zip"; recursive=true) + end + + try + metal = MetalClient(metaldir * "/pjrt_plugin_metal_14.dylib") + backends["metal"] = metal + # NOTE Float64, ComplexF64, ComplexF128 not yet supported, so don't make it default + # default_backend[] = metal + catch e + println(stdout, e) + end + else if isfile("/usr/lib/libtpu.so") dataset_dir = @get_scratch!("libtpu") if !isfile(dataset_dir * "/libtpu.so")