diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 65a6b8fad9769..c2482733c3a55 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -337,6 +337,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_hlo_pass_fix_detect_cycles(false); opts.set_xla_gpu_experimental_enable_sync_collective_combining(false); opts.set_xla_allow_get_default_platform(true); + opts.set_xla_unsupported_crash_on_hlo_pass_silent_hlo_change(false); + opts.set_xla_unsupported_crash_on_hlo_pass_noop_change(false); return opts; } @@ -2311,7 +2313,22 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "If non empty will interpret this variable as a path for performance " "tables for collectives. Expects `xla.gpu.DeviceHloInstructionProfiles` " "proto.")); -} // NOLINT(readability/fn_size)1 + flag_list->push_back(tsl::Flag( + "xla_unsupported_crash_on_hlo_pass_silent_hlo_change", + bool_setter_for( + &DebugOptions:: + set_xla_unsupported_crash_on_hlo_pass_silent_hlo_change), + debug_options->xla_unsupported_crash_on_hlo_pass_silent_hlo_change(), + "Crash if a pass reports that it did not change the HLO but in fact it " + "did.")); + flag_list->push_back(tsl::Flag( + "xla_unsupported_crash_on_hlo_pass_noop_change", + bool_setter_for( + &DebugOptions::set_xla_unsupported_crash_on_hlo_pass_noop_change), + debug_options->xla_unsupported_crash_on_hlo_pass_noop_change(), + "Crash if a pass reports that it did change the HLO but in fact it " + "did not.")); +} // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. 
diff --git a/xla/hlo/pass/BUILD b/xla/hlo/pass/BUILD index 3e4badd16a98c..cfd6dc662eb08 100644 --- a/xla/hlo/pass/BUILD +++ b/xla/hlo/pass/BUILD @@ -89,6 +89,7 @@ cc_library( "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", "//xla/service:compilation_stats", "//xla/service:dump", "//xla/service:hlo_graph_dumper", @@ -102,6 +103,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@tsl//tsl/profiler/lib:scoped_annotation", "@tsl//tsl/profiler/lib:traceme", ], diff --git a/xla/hlo/pass/hlo_pass_pipeline.cc b/xla/hlo/pass/hlo_pass_pipeline.cc index 2c15e6721cf4e..09e5d26dc0e99 100644 --- a/xla/hlo/pass/hlo_pass_pipeline.cc +++ b/xla/hlo/pass/hlo_pass_pipeline.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include +#include #include #include @@ -24,6 +26,8 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/dump.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_proto_util.h" @@ -169,6 +173,9 @@ absl::StatusOr HloPassPipeline::RunPassesInternal( /*module_changed=*/false); bool changed = false; + bool verify_pass_changed_report = + debug_options.xla_unsupported_crash_on_hlo_pass_silent_hlo_change() || + debug_options.xla_unsupported_crash_on_hlo_pass_noop_change(); for (int i = 0; i < passes.size(); i++) { HloPassInterface* pass = passes[i]; std::string pass_name = std::string(pass->name()); @@ -178,7 +185,11 @@ absl::StatusOr HloPassPipeline::RunPassesInternal( pass_name, hlo->name(), UniqueId(*hlo)); }}; VLOG(1) << " HLO pass " << pass_name; - VLOG(2) << " Module hash " << absl::HashOf(*hlo); + std::optional hash_before = std::nullopt; + if (verify_pass_changed_report || VLOG_IS_ON(2)) { + hash_before = absl::HashOf(*hlo); + VLOG(2) << " Module hash " << hash_before.value(); + } tsl::profiler::TraceMe traceme(pass->name()); if (!pass->IsPassPipeline()) { compilation_stats_->StartPass(pass_name); @@ -190,6 +201,27 @@ absl::StatusOr HloPassPipeline::RunPassesInternal( pass_name, absl::StatusCodeToString(status.code())); } TF_ASSIGN_OR_RETURN(bool pass_changed, status_or_changed); + if (verify_pass_changed_report) { + size_t hash_after = absl::HashOf(*hlo); + // Fail if pass changed HLO but has reported that it didn't. + if (!pass_changed && hash_after != hash_before && + debug_options.xla_unsupported_crash_on_hlo_pass_silent_hlo_change()) { + LOG(FATAL) << absl::StrFormat( + "Pass '%s' in pipeline '%s' reported that it did not change the " + "HLO but the hash of HLO was changed from %d to %d. 
HLO text " + "after:\n%s", + pass_name, pipeline_name, hash_before.value(), hash_after, + hlo->ToString()); + } + // Fail if pass did not change HLO but has reported that it did. + if (pass_changed && hash_after == hash_before && + debug_options.xla_unsupported_crash_on_hlo_pass_noop_change()) { + LOG(FATAL) << absl::StrFormat( + "Pass '%s' in pipeline '%s' reported that it changed the HLO but " + "the hash of HLO was not updated. HLO text after:\n%s", + pass_name, pipeline_name, hlo->ToString()); + } + } if (!dump_regex.empty() && (pass_changed || dump_regex != ".*")) { MaybeDumpHloAndSaveFilenames(*hlo, /*after_pass_name=*/pass_name, diff --git a/xla/xla.proto b/xla/xla.proto index 5be375f10dcf1..690175d0146c7 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -125,6 +125,10 @@ message DebugOptions { bool xla_hlo_pass_fix_detect_cycles = 370; // Crash if HloPassFix can not converge after a fixed number of iterations. bool xla_unsupported_crash_on_hlo_pass_fix_max_iterations = 363; + // Crash if a pass reports that it changed the HLO but in fact it did not. + bool xla_unsupported_crash_on_hlo_pass_noop_change = 379; + // Crash if a pass reports that it did not change the HLO but in fact it did. + bool xla_unsupported_crash_on_hlo_pass_silent_hlo_change = 380; // go/keep-sorted end reserved 346; // xla_experimental_exec_time_optimization_effort @@ -1189,7 +1193,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 379 + // Next id: 381 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.