Skip to content

Commit

Permalink
disable remat hlo pass by default
Browse files Browse the repository at this point in the history
  • Loading branch information
keshavb96 committed Sep 18, 2024
1 parent 3f2c58b commit 1b74cfd
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@
),
)

_DISABLE_COMPILER_REMAT_OPTIMIZATION_PASS = config.bool_flag(
"jax_compiler_enable_remat_pass",
config.bool_env('JAX_COMPILER_ENABLE_REMAT_PASS', True),
help=(
'Disable the rematerialization HLO pass by default '
'Avoids having to pass --xla_disable_hlo_passes=rematerialization. '
)
)

# The special XLA-AutoFDO profile version that indicates that a profile is not
# available and retrieval should not be attempted.
_NO_PROFILE_DONT_RETRIEVE = -1
Expand Down Expand Up @@ -199,6 +208,9 @@ def get_compile_options(
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False

if _DISABLE_COMPILER_REMAT_OPTIMIZATION_PASS.value:
debug_options.xla_disable_hlo_passes = "rematerialization"

# XLA-AutoFDO profile version: precedence order is:
# 1. Whatever --jax_xla_profile_version is set to.
Expand Down

0 comments on commit 1b74cfd

Please sign in to comment.