Skip to content

Commit

Permalink
[AutoPGLE] Fix test after pjrt cache refactoring
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696156229
  • Loading branch information
Google-ML-Automation committed Nov 13, 2024
1 parent be3c8be commit 93a1f9d
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tests/pgle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,6 @@ def f(x):
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

profilers_dict = (
pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict)
with (config.enable_compilation_cache(True),
config.enable_pgle(True),
config.raise_persistent_cache_errors(True),
Expand Down Expand Up @@ -276,7 +274,7 @@ def f(x):
os.path.getsize(os.path.join(self.dump_dir, module)), 0
)

for pgle_profiler in profilers_dict.values():
for pgle_profiler in pjit._pgle_profiler_dict.values():
self.assertTrue(pgle_profiler.is_enabled())
self.assertTrue(pgle_profiler.is_fdo_consumed())

Expand All @@ -291,7 +289,7 @@ def f(x):
os.remove(os.path.join(cache_dir, non_pgle_file))

api.clear_caches()
profilers_dict.clear()
pjit._pgle_profiler_dict.clear()

# Run 4: Persistent compilation cache should be hit PGLE profiler should
# be disabled
Expand Down

0 comments on commit 93a1f9d

Please sign in to comment.