Skip to content

Commit

Permalink
[XLA:GPU][Cleanup] Remove pre-Ampere paths in GEMM fusion autotuner.
Browse files (browse the repository at this point in the history)
These paths are dead code: GEMM fusions are only enabled when the compute
capability is at least Ampere, so the pre-Ampere branches can never execute.

Also fix the includes (and the corresponding BUILD dependencies) as a side cleanup.

PiperOrigin-RevId: 681909633
  • Loading branch information
bchetioui authored and tensorflower-gardener committed Oct 3, 2024
1 parent a22c6d4 commit 8ff94d3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 27 deletions.
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/autotuning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ cc_library(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:buffer_comparator",
"//xla/service/gpu:gpu_float_support",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu:split_k_gemm_rewriter",
Expand All @@ -77,9 +78,11 @@ cc_library(
"//xla/stream_executor:device_memory",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/stream_executor/gpu:redzone_allocator",
"//xla/tools:hlo_decomposer_lib",
"//xla/tsl/lib/core:bits",
"//xla/tsl/util/proto:proto_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
Expand Down
49 changes: 22 additions & 27 deletions third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ limitations under the License.
#include <array>
#include <atomic>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
Expand Down Expand Up @@ -56,12 +58,14 @@ limitations under the License.
#include "xla/service/algorithm_util.h"
#include "xla/service/call_inliner.h"
#include "xla/service/dump.h"
#include "xla/service/executable.h"
#include "xla/service/float_normalization.h"
#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/buffer_comparator.h"
#include "xla/service/gpu/gpu_float_support.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/kernels/custom_kernel.h"
#include "xla/service/gpu/kernels/custom_kernel_fusion.h"
Expand All @@ -84,6 +88,7 @@ limitations under the License.
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/gpu/redzone_allocator.h"
#include "xla/stream_executor/semantic_version.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor_memory_allocator.h"
Expand Down Expand Up @@ -717,7 +722,7 @@ std::vector<BackendConfig> GenerateCustomKernelFusionConfigs(
std::vector<CustomKernelFusionPattern::Match> match =
patterns->Match(device_description, dot_instruction);

// For Cutlass we expect only one match for a gemm fusion.
// For Cutlass we expect only one match for a GEMM fusion.
if (match.size() == 1) {
CustomKernelFusionRegistry* registry =
CustomKernelFusionRegistry::Default();
Expand Down Expand Up @@ -1195,10 +1200,6 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
debug_options_.xla_gpu_exhaustive_tiling_search() && cc.IsAtLeastHopper();

for (int num_stages : kNumStages) {
// Volta doesn't support num_stages > 2.
if (!cc.IsAtLeastAmpere() && num_stages > 2) {
break;
}
for (int tile_m : kBlockSizes) {
for (int tile_n : kBlockSizes) {
for (int tile_k : kBlockSizes) {
Expand Down Expand Up @@ -1242,28 +1243,22 @@ std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
const {
using Config = TritonGemmConfig;
std::vector<Config> configs = {
Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4),
Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
Config(64, 32, 64, 1, 2, 8)};
if (GetComputeCapability().IsAtLeastAmpere()) {
absl::c_copy(
std::vector<Config>{
Config(128, 256, 32, 1, 3, 8), Config(256, 128, 32, 1, 3, 8),
Config(256, 64, 32, 1, 4, 4), Config(64, 256, 32, 1, 4, 4),
Config(128, 64, 32, 1, 4, 4), Config(64, 128, 32, 1, 4, 4),
Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4),
Config(64, 256, 128, 1, 4, 4), Config(128, 128, 128, 1, 4, 4),
Config(128, 64, 64, 1, 4, 4), Config(64, 128, 64, 1, 4, 4),
Config(128, 32, 64, 1, 4, 4), Config(64, 32, 64, 1, 4, 4),
Config(32, 128, 32, 1, 4, 4), Config(128, 128, 32, 1, 4, 4),
Config(16, 16, 256, 1, 3, 4), Config(128, 128, 64, 2, 1, 8),
Config(64, 64, 64, 1, 2, 4), Config(16, 64, 256, 8, 1, 4),
Config(256, 256, 128, 1, 3, 8)},
std::back_inserter(configs));
}
Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4),
Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
Config(64, 32, 64, 1, 2, 8), Config(128, 256, 32, 1, 3, 8),
Config(256, 128, 32, 1, 3, 8), Config(256, 64, 32, 1, 4, 4),
Config(64, 256, 32, 1, 4, 4), Config(128, 64, 32, 1, 4, 4),
Config(64, 128, 32, 1, 4, 4), Config(256, 128, 128, 1, 3, 8),
Config(256, 64, 128, 1, 4, 4), Config(64, 256, 128, 1, 4, 4),
Config(128, 128, 128, 1, 4, 4), Config(128, 64, 64, 1, 4, 4),
Config(64, 128, 64, 1, 4, 4), Config(128, 32, 64, 1, 4, 4),
Config(64, 32, 64, 1, 4, 4), Config(32, 128, 32, 1, 4, 4),
Config(128, 128, 32, 1, 4, 4), Config(16, 16, 256, 1, 3, 4),
Config(128, 128, 64, 2, 1, 8), Config(64, 64, 64, 1, 2, 4),
Config(16, 64, 256, 8, 1, 4), Config(256, 256, 128, 1, 3, 8)};
if (GetComputeCapability().IsAtLeastHopper()) {
absl::c_copy(
std::vector<Config>{
Expand Down

0 comments on commit 8ff94d3

Please sign in to comment.