Skip to content

Commit

Permalink
Merge pull request #22703 from Rifur13:plugin-fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657283607
  • Loading branch information
jax authors committed Jul 29, 2024
2 parents 9beb4f1 + 0224235 commit f070c06
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
11 changes: 11 additions & 0 deletions jax/_src/hardware_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
'0x005e',
]

_NVIDIA_GPU_DEVICES = [
'/dev/nvidia0',
'/dev/dxg', # WSL2
]

def num_available_tpu_chips_and_device_id():
"""Returns the device id and number of TPU chips attached through PCI."""
num_chips = 0
Expand All @@ -57,3 +62,9 @@ def tpu_enhanced_barrier_supported() -> bool:
"""Returns if tpu_enhanced_barrier flag is supported on this TPU version."""
_, device_id = num_available_tpu_chips_and_device_id()
return device_id in _TPU_ENHANCED_BARRIER_SUPPORTED


def has_visible_nvidia_gpu() -> bool:
"""True if there's a visible nvidia gpu available on device, False otherwise."""

return any(os.path.exists(d) for d in _NVIDIA_GPU_DEVICES)
10 changes: 4 additions & 6 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,9 @@ def backends() -> dict[str, xla_client.Client]:
default_priority = -1000
for platform, priority, fail_quietly in platform_registrations:
try:
if platform == "cuda" and not hardware_utils.has_visible_nvidia_gpu():
continue

backend = _init_backend(platform)
_backends[platform] = backend

Expand Down Expand Up @@ -918,12 +921,7 @@ def _suggest_missing_backends():

assert _default_backend is not None
default_platform = _default_backend.platform
nvidia_gpu_devices = [
"/dev/nvidia0",
"/dev/dxg", # WSL2
]
if ("cuda" not in _backends and
any(os.path.exists(d) for d in nvidia_gpu_devices)):
if "cuda" not in _backends and hardware_utils.has_visible_nvidia_gpu():
if hasattr(xla_extension, "GpuAllocatorConfig") and "cuda" in _backend_errors:
err = _backend_errors["cuda"]
warning_msg = f"CUDA backend failed to initialize: {err}."
Expand Down

0 comments on commit f070c06

Please sign in to comment.