From e3276a9c97600179cc70c560fb82e655969a4dce Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 24 Apr 2025 23:02:09 +0000 Subject: [PATCH] Allow contract by default Signed-off-by: Whitney Tsang --- third_party/intel/backend/compiler.py | 3 +-- third_party/intel/triton_xpu.cc | 14 +++++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 4ef2f7cbde..9331d9eb9b 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -358,8 +358,7 @@ def make_llir(src, metadata, options): context = llvm.context() llvm_mod = llvm.to_module(mod, context) intel.set_spv_target_triple(llvm_mod) - if os.getenv("TRITON_INTEL_FAST_MATH", "0") == "1": - intel.set_fast_math(llvm_mod) + intel.set_fast_math(llvm_mod) if options.extern_libs: paths = [path for (name, path) in options.extern_libs] llvm.link_extern_libs(llvm_mod, paths) diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 6ce8477337..e80c59e27c 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -273,12 +273,24 @@ void init_triton_intel(py::module &&m) { // fast math semantics on all arithmetic operations. // https://github.com/intel/intel-xpu-backend-for-triton/issues/3862 m.def("set_fast_math", [](llvm::Module *mod) { + std::optional fastMath = mlir::triton::tools::isEnvValueBool( + mlir::triton::tools::getStrEnv("TRITON_INTEL_FAST_MATH")); + std::optional enableFpFusion = mlir::triton::tools::isEnvValueBool( + mlir::triton::tools::getStrEnv("TRITON_DEFAULT_FP_FUSION")); + if (fastMath.has_value() && !fastMath.value()) + return; + using namespace llvm; for (Function &func : *mod) { for (Instruction &inst : instructions(func)) { if (auto *op = dyn_cast(&inst)) { FastMathFlags FMF; - FMF.setFast(true); + // Default to allow contract when default fp fusion is not disabled. + if ((!enableFpFusion.has_value() || enableFpFusion.value()) && + !fastMath.has_value()) + FMF.setAllowContract(true); + else if (fastMath.has_value() && fastMath.value()) + FMF.setFast(true); inst.setFastMathFlags(FMF); } }