Skip to content

Commit

Permalink
[Mosaic GPU] Avoid failing when importing profiler.py even if lib.mos…
Browse files Browse the repository at this point in the history
…aic_gpu is unavailable.

PiperOrigin-RevId: 647626620
  • Loading branch information
Google-ML-Automation authored and jax authors committed Jun 28, 2024
1 parent 3a21c81 commit ad4c9ab
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions jax/experimental/mosaic/gpu/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import jax
from jax._src.interpreters import mlir
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
from jax._src.lib import xla_client
import jax.numpy as jnp
from jaxlib.mlir import ir
Expand All @@ -33,14 +32,21 @@

from .utils import * # noqa: F403


try:
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib

xla_client.register_custom_call_target(
"mosaic_gpu_record_event",
mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(),
platform="CUDA",
)
except ImportError:
pass

# ruff: noqa: F405
# mypy: ignore-errors

xla_client.register_custom_call_target(
"mosaic_gpu_record_event",
mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(),
platform="CUDA",
)

record_event_p = jax.core.Primitive("record_event")
record_event_p.multiple_results = True
Expand Down

0 comments on commit ad4c9ab

Please sign in to comment.