From 8ff94d3cf1e5edc5586a98bd0cdab1601ff6924b Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 3 Oct 2024 09:21:00 -0700 Subject: [PATCH] [XLA:GPU][Cleanup] Remove pre-Ampere paths in GEMM fusion autotuner. These paths are dead, given that GEMM fusions are gated on the compute capability being at least Ampere. Fix includes as a side cleanup. PiperOrigin-RevId: 681909633 --- .../xla/xla/service/gpu/autotuning/BUILD | 3 ++ .../gpu/autotuning/gemm_fusion_autotuner.cc | 49 +++++++++---------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 947e8bed38864a..406b6bc628fcd9 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -60,6 +60,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:gpu_float_support", + "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:split_k_gemm_rewriter", @@ -77,9 +78,11 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:bits", "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 321b312e4c34e0..37b549a2aee4d7 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -56,12 +58,14 @@ limitations under the License. #include "xla/service/algorithm_util.h" #include "xla/service/call_inliner.h" #include "xla/service/dump.h" +#include "xla/service/executable.h" #include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/kernels/custom_kernel_fusion.h" @@ -84,6 +88,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -717,7 +722,7 @@ std::vector GenerateCustomKernelFusionConfigs( std::vector match = patterns->Match(device_description, dot_instruction); - // For Cutlass we expect only one match for a gemm fusion. + // For Cutlass we expect only one match for a GEMM fusion. if (match.size() == 1) { CustomKernelFusionRegistry* registry = CustomKernelFusionRegistry::Default(); @@ -1195,10 +1200,6 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const { debug_options_.xla_gpu_exhaustive_tiling_search() && cc.IsAtLeastHopper(); for (int num_stages : kNumStages) { - // Volta doesn't support num_stages > 2. - if (!cc.IsAtLeastAmpere() && num_stages > 2) { - break; - } for (int tile_m : kBlockSizes) { for (int tile_n : kBlockSizes) { for (int tile_k : kBlockSizes) { @@ -1242,28 +1243,22 @@ std::vector GemmFusionAutotunerImpl::GetDefaultTritonConfigs() const { using Config = TritonGemmConfig; std::vector configs = { - Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), - Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), - Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), - Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), - Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), - Config(64, 32, 64, 1, 2, 8)}; - if (GetComputeCapability().IsAtLeastAmpere()) { - absl::c_copy( - std::vector{ - Config(128, 256, 32, 1, 3, 8), Config(256, 128, 32, 1, 3, 8), - Config(256, 64, 32, 1, 4, 4), Config(64, 256, 32, 1, 4, 4), - Config(128, 64, 32, 1, 4, 4), Config(64, 128, 32, 1, 4, 4), - Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4), - Config(64, 256, 128, 1, 4, 4), Config(128, 128, 128, 1, 4, 4), - Config(128, 64, 64, 1, 4, 4), Config(64, 128, 64, 1, 4, 4), - Config(128, 32, 64, 1, 4, 4), Config(64, 32, 64, 1, 4, 4), - Config(32, 128, 32, 1, 4, 4), Config(128, 128, 32, 1, 4, 4), - Config(16, 16, 256, 1, 3, 4), Config(128, 128, 64, 2, 1, 8), - Config(64, 64, 64, 1, 2, 4), Config(16, 64, 256, 8, 1, 4), - Config(256, 256, 128, 1, 3, 8)}, - std::back_inserter(configs)); - } + Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), + Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), + Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), + Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), + Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), + Config(64, 32, 64, 1, 2, 8), Config(128, 256, 32, 1, 3, 8), + Config(256, 128, 32, 1, 3, 8), Config(256, 64, 32, 1, 4, 4), + Config(64, 256, 32, 1, 4, 4), Config(128, 64, 32, 1, 4, 4), + Config(64, 128, 32, 1, 4, 4), Config(256, 128, 128, 1, 3, 8), + Config(256, 64, 128, 1, 4, 4), Config(64, 256, 128, 1, 4, 4), + Config(128, 128, 128, 1, 4, 4), Config(128, 64, 64, 1, 4, 4), + Config(64, 128, 64, 1, 4, 4), Config(128, 32, 64, 1, 4, 4), + Config(64, 32, 64, 1, 4, 4), Config(32, 128, 32, 1, 4, 4), + Config(128, 128, 32, 1, 4, 4), Config(16, 16, 256, 1, 3, 4), + Config(128, 128, 64, 2, 1, 8), Config(64, 64, 64, 1, 2, 4), + Config(16, 64, 256, 8, 1, 4), Config(256, 256, 128, 1, 3, 8)}; if (GetComputeCapability().IsAtLeastHopper()) { absl::c_copy( std::vector{