Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically fetch and register PjRT Metal plugin #99

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,32 @@ extern "C" PjRtClient* MakeTPUClient(const char* tpu_path , const char** error)
return GetCApiClient("TPU");
}

const char* const kEnvMetalLibraryPath = "METAL_LIBRARY_PATH";

extern "C" PjRtClient* MakeMetalClient(const char* libpath, const char** error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid making a separate api change atm since the oher loadpjrt plugins have an api presently?

Just for the sake of setting it up without a jll bump

// Prefer $METAL_LIBRARY_PATH if set
std::string metal_library_path;
if (auto path = llvm::sys::Process::GetEnv(kEnvMetalLibraryPath)) {
metal_library_path = *path;
} else if (libpath) {
metal_library_path = std::string(libpath);
} else {
*error = "Could not find Metal path";
return nullptr;
}

const PJRT_Api* pluginLoad = LoadPjrtPlugin("metal", metal_library_path.c_str(), error);
if (pluginLoad == nullptr)
return nullptr;


auto metal_status = InitializePjrtPlugin("metal", error);
if (metal_status)
return nullptr;

return GetCApiClient("METAL");
}

extern "C" int ClientNumDevices(PjRtClient* client) {
return client->device_count();
}
Expand Down
39 changes: 38 additions & 1 deletion src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ function TPUClient(tpu_path::String)
return Client(client)
end

function MetalClient(libpath::String)
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeMetalClient")
refstr = Ref{Cstring}()
client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), libpath, refstr)
if client == C_NULL
throw(AssertionError(unsafe_string(refstr[])))
end
return Client(client)
end

const backends = Dict{String,Client}()
const default_backend = Ref{Client}()
const default_device_idx = Ref{Int}(0)
Expand All @@ -100,7 +110,34 @@ function __init__()
backends["cpu"] = cpu
default_backend[] = cpu

@static if !Sys.isapple()
@static if Sys.isapple()
metaldir = @get_scratch!("pjrt-plugin-metal")
if !isfile(metaldir * "/pjrt_plugin_metal_14.dylib")
Downloads.download(
if Sys.ARCH === :aarch64
"https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl"
else
"https://files.pythonhosted.org/packages/51/6a/1c0e2d07d92c6583e874ef2bbf4382662a3469bbb661d885eeaaddca426f/jax_metal-0.1.0-py3-none-macosx_10_14_x86_64.whl"
end,
joinpath(metaldir, "pjrt-plugin-metal.zip"),
)
run(`unzip -qq $(metaldir*"/pjrt-plugin-metal.zip") -d $(metaldir)/tmp`)
run(
`mv $(metaldir)/tmp/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib $(metaldir)/pjrt_plugin_metal_14.dylib`,
)
rm(metaldir * "/tmp"; recursive=true)
rm(metaldir * "/pjrt-plugin-metal.zip"; recursive=true)
end

try
metal = MetalClient(metaldir * "/pjrt_plugin_metal_14.dylib")
backends["metal"] = metal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might as well just label this as gpu

# NOTE Float64, ComplexF64, ComplexF128 not yet supported, so don't make it default
# default_backend[] = metal
catch e
println(stdout, e)
end
else
if isfile("/usr/lib/libtpu.so")
dataset_dir = @get_scratch!("libtpu")
if !isfile(dataset_dir * "/libtpu.so")
Expand Down
Loading