Skip to content

Commit

Permalink
[AutoPGLE] Explicitly ignore host callback pointers
Browse files Browse the repository at this point in the history
Before this change users had to specify remove_custom_partitioning_ptr_from_cache_key config flag when using AutoPGLE.

PiperOrigin-RevId: 694551120
  • Loading branch information
Google-ML-Automation committed Nov 8, 2024
1 parent 2b55bd5 commit ae03462
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 42 deletions.
28 changes: 19 additions & 9 deletions jax/_src/cache_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def get(module: ir.Module,
devices: np.ndarray,
compile_options: xla_client.CompileOptions,
backend: xla_client.Client,
compression_algorithm: str = "zstandard") -> str:
compression_algorithm: str = "zstandard",
ignore_host_callbacks: bool = False) -> str:
"""Creates a hashed string to use as a key to the compilation cache.
Creates a cache key that is a hex-encoded string of a unique hash based on
Expand All @@ -78,13 +79,17 @@ def get(module: ir.Module,
backend: description of the platform (e.g., TPU version)
compression_algorithm: a string representing the compression algorithm used
for the executable before persisting in the cache
ignore_host_callbacks: whether to remove the host callback pointer from
the computation. This does the same as
jax_remove_custom_partitioning_ptr_from_cache_key, but explicitally.
Typical return value example:
'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
"""
entries = [
("computation",
lambda hash_obj: _hash_computation(hash_obj, module)),
lambda hash_obj: _hash_computation(hash_obj, module,
ignore_host_callbacks)),
("jax_lib version",
lambda hash_obj: hash_obj.update(
bytes(jaxlib_version_str.encode("utf-8")))),
Expand Down Expand Up @@ -145,30 +150,35 @@ def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult:
return m


def _serialize_ir(m: ir.Module) -> bytes:
def _serialize_ir(m: ir.Module, ignore_host_callbacks: bool) -> bytes:
output = io.BytesIO()
if config.remove_custom_partitioning_ptr_from_cache_key.value:
if (
ignore_host_callbacks
or config.remove_custom_partitioning_ptr_from_cache_key.value
):
m = _remove_custom_partitioning_ptr(type_cast(ir.Module,
m.operation.clone()))
m.operation.write_bytecode(file=output)
return output.getvalue()


def _canonicalize_ir(m_original: ir.Module) -> bytes:
def _canonicalize_ir(
m_original: ir.Module, ignore_host_callbacks: bool
) -> bytes:
with m_original.context:
m = type_cast(ir.Module, m_original.operation.clone())
passes = pm.PassManager.parse(
"builtin.module(strip-debuginfo)"
)
passes.run(m.operation)
return _serialize_ir(m)
return _serialize_ir(m, ignore_host_callbacks)


def _hash_computation(hash_obj, module):
def _hash_computation(hash_obj, module, ignore_host_callbacks):
if config.compilation_cache_include_metadata_in_key.value:
canonical_ir = _serialize_ir(module)
canonical_ir = _serialize_ir(module, ignore_host_callbacks)
else:
canonical_ir = _canonicalize_ir(module)
canonical_ir = _canonicalize_ir(module, ignore_host_callbacks)
hash_obj.update(canonical_ir)


Expand Down
6 changes: 4 additions & 2 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,11 @@ def put_executable_and_time(
def get_cache_key(module: ir.Module,
devices: np.ndarray,
compile_options,
backend) -> str:
backend,
ignore_host_callbacks: bool = False) -> str:
return cache_key.get(module, devices, compile_options, backend,
"zstandard" if zstandard is not None else "zlib")
"zstandard" if zstandard is not None else "zlib",
ignore_host_callbacks)


def is_initialized() -> bool:
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,8 @@ def _share_fdo_profiles(
compile_options.executable_build_options.fdo_profile = b""
profile_key = (
compilation_cache.get_cache_key(
computation, devices, compile_options, backend
computation, devices, compile_options, backend,
ignore_host_callbacks=True
)
+ "_fdo_sync"
)
Expand Down
13 changes: 13 additions & 0 deletions tests/cache_key_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,19 @@ def _cp_add(x, y):
for bc in bcs:
self.assertEqual(bc, "REMOVED")

with config.remove_custom_partitioning_ptr_from_cache_key(False):
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
hash_without_callback_ptrs = cache_key.get(computation, devices,
compile_options,
backend,
ignore_host_callbacks=True)
expected_hash = cache_key.get(updated_module, devices, compile_options,
backend)
self.assertEqual(expected_hash, hash_without_callback_ptrs)

def test_different_device_assignment(self):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
Expand Down
83 changes: 53 additions & 30 deletions tests/pgle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import ExitStack
from functools import partial
import glob
import logging
Expand All @@ -22,40 +23,69 @@

from absl.testing import absltest
import jax
from jax._src import api
from jax._src import compilation_cache as cc
from jax._src import config
from jax._src import profiler
from jax._src import pjit
from jax._src import monitoring
from jax._src import pjit
from jax._src import profiler
from jax._src import test_util as jtu
from jax._src import api
from jax.experimental import profiler as exp_profiler
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from jax._src import compilation_cache as cc
import numpy as np

from jax.experimental.serialize_executable import (
deserialize_and_load,
serialize,
)
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
import numpy as np

jax.config.parse_flags_with_absl()

dump_dir = tempfile.TemporaryDirectory().name
os.environ['XLA_FLAGS'] = (
f'--xla_dump_to={dump_dir}'
' --xla_gpu_experimental_dump_fdo_profiles=true'
' --xla_gpu_enable_latency_hiding_scheduler=true'
)

@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):
_dump_exit_stack: ExitStack | None = None

@classmethod
def setUpClass(cls):
super().setUpClass()
cls._dump_exit_stack = ExitStack()

cls.dump_dir = cls._dump_exit_stack.enter_context(tempfile.TemporaryDirectory())
if 'XLA_FLAGS' in os.environ:
cls.old_xla_flags = os.environ['XLA_FLAGS']
else:
cls.old_xla_flags = None

os.environ['XLA_FLAGS'] = (
f'--xla_dump_to={cls.dump_dir}'
' --xla_gpu_experimental_dump_fdo_profiles=true'
' --xla_gpu_enable_latency_hiding_scheduler=true'
# TODO(patrios): Remove this flag once b/376647494 is fixed.
' --xla_gpu_graph_level=0'
)
if cls.old_xla_flags:
os.environ['XLA_FLAGS'] += ' ' + cls.old_xla_flags

@classmethod
def tearDownClass(cls):
if cls.old_xla_flags:
os.environ['XLA_FLAGS'] = cls.old_xla_flags
cls._dump_exit_stack.close()
super().tearDownClass()

def setUp(self):
super().setUp()
cc.set_cache_dir(None)
cc.reset_cache()

def tearDown(self):
# Cleanup dump directory
for file in os.listdir(self.dump_dir):
file_path = os.path.join(self.dump_dir, file)
if os.path.isfile(file_path):
os.remove(file_path)

cc.set_cache_dir(None)
super().tearDown()

Expand Down Expand Up @@ -87,7 +117,6 @@ def f(x, y):
self.assertIsNotNone(fdo_profile)
self.assertIn(b'custom', fdo_profile)

@unittest.skip("Test failing in CI")
def testPGLEProfilerGetFDOProfileLarge(self):
mesh = jtu.create_mesh((2,), ('x',))
its = 500
Expand All @@ -106,14 +135,10 @@ def f(x):
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

with config.pgle_profiling_runs(0):
f_lowered = f.lower(x)
f_compiled = f_lowered.compile()

pgle_profiler = profiler.PGLEProfiler(1, 90)
with config.enable_pgle(False):
with profiler.PGLEProfiler.trace(pgle_profiler):
f_compiled(x)
f(x)
fdo_profile = pgle_profiler.consume_fdo_profile()
self.assertEqual(fdo_profile.count(b'custom'), its)

Expand Down Expand Up @@ -177,7 +202,6 @@ def f(x):
self.assertArraysEqual(compiled(x), expected)
self.assertEqual(cache_miss_count[0], 0)

@unittest.skip("Test failing in CI")
def testAutoPgleWithPersistentCache(self):
its = 50
mesh = jtu.create_mesh((2,), ('x',))
Expand Down Expand Up @@ -206,11 +230,12 @@ def f(x):
config.persistent_cache_min_compile_time_secs(0),
config.pgle_profiling_runs(2),
tempfile.TemporaryDirectory() as cache_dir):
cc.reset_cache()
cc.set_cache_dir(cache_dir)
# Run 1: Module should be compiled without FDO
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertEqual(cache_miss_count[0], 1)
self.assertGreater(cache_miss_count[0], 0)

# Non-pgle profiled version of module should be saved
non_pgle_profiled_files = os.listdir(cache_dir)
Expand All @@ -221,26 +246,24 @@ def f(x):
f(x)
self.assertEqual(cache_miss_count[0], 0)

module_before_pgle = os.listdir(dump_dir)
print(module_before_pgle)
module_before_pgle = os.listdir(self.dump_dir)
self.assertNotEmpty(module_before_pgle)
# Run 3: Module should be compiled with FDO and stored to persistent cache
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
# Add xla_dump_to to env flags
f(x)
self.assertEqual(cache_miss_count[0], 1)
self.assertGreater(cache_miss_count[0], 0)

# Check if FDO profile file of the biggest module is not empty
module_after_pgle = [
x
for x in os.listdir(dump_dir)
for x in os.listdir(self.dump_dir)
if x not in module_before_pgle
]
self.assertNotEmpty(module_after_pgle)
biggest_module_after_pgle = max(
module_after_pgle,
key=lambda x: os.path.getsize(
os.path.join(dump_dir, x)
os.path.join(self.dump_dir, x)
),
)
base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1])
Expand All @@ -251,7 +274,7 @@ def f(x):
'.fdo_profile'
):
self.assertGreater(
os.path.getsize(os.path.join(dump_dir, module)), 0
os.path.getsize(os.path.join(self.dump_dir, module)), 0
)

for pgle_profiler in profilers_dict.values():
Expand Down Expand Up @@ -283,7 +306,7 @@ def check_if_cache_hit(event):
f(x)
monitoring._unregister_event_listener_by_callback(check_if_cache_hit)

self.assertEqual(cache_hit, 1)
self.assertGreater(cache_hit, 0)

def testPassingFDOProfile(self):
mesh = jtu.create_mesh((2,), ('x',))
Expand Down

0 comments on commit ae03462

Please sign in to comment.