diff --git a/xla/service/gpu/ir_emitter_triton_test.cc b/xla/service/gpu/ir_emitter_triton_test.cc index a8fe6df7a8b09..c3e882d7c464f 100644 --- a/xla/service/gpu/ir_emitter_triton_test.cc +++ b/xla/service/gpu/ir_emitter_triton_test.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/gpu/variant_visitor.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/status_macros.h" @@ -69,6 +70,11 @@ namespace { namespace m = ::xla::match; class TritonTest : public GpuCodegenTest { + const auto& device_desc() { + return backend() + .default_stream_executor() + ->GetDeviceDescription(); + } public: se::CudaComputeCapability GetCudaComputeCapability() { return backend() @@ -76,6 +82,24 @@ class TritonTest : public GpuCodegenTest { ->GetDeviceDescription() .cuda_compute_capability(); } + const se::GpuComputeCapability& GpuComputeComp() { + return device_desc().gpu_compute_capability(); + } + bool SkipBF16Tests() { + if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) { + auto rcc = device_desc().rocm_compute_capability(); + return !rcc.has_bf16_dtype_support(); + } + return false; + } + se::GpuComputeCapability CudaAmpereOrRocm() { + if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) { + return se::GpuComputeCapability{device_desc().rocm_compute_capability()}; + } else { + return se::GpuComputeCapability{se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}}; + } + } }; class TritonGemmTest : public TritonTest { @@ -839,6 +863,9 @@ TEST_F(TritonFilecheckTest, NestedReducerFusionGetsCodegenedCorrectly) { se::CudaComputeCapability::AMPERE)) { GTEST_SKIP() << "Doesn't pass on pre-Ampere GPUs."; } + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule softmax @@ -1268,6 +1295,9 @@ CHECK: mma } 
TEST_F(TritonGemmTest, FailIfTooMuchShmem) { + if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) { + GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; + } const std::string kHloText = R"( HloModule module, is_scheduled=true @@ -1301,8 +1331,7 @@ ENTRY entry { EXPECT_THAT( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), "test_fn", triton_dot_computation, - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, - /*minor=*/0}, + CudaAmpereOrRocm(), dev_info, config, &llvm_module, &EmitMatMul, mlir_context), tsl::testing::StatusIs( tsl::error::RESOURCE_EXHAUSTED, @@ -1316,8 +1345,7 @@ ENTRY entry { const auto result, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), "test_fn", triton_dot_computation, - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, - /*minor=*/0}, + CudaAmpereOrRocm(), dev_info, config, &llvm_module, &EmitMatMul, mlir_context)); // Use optin shared memory which is > shared_memory_per_block. EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block()); @@ -1348,7 +1376,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); // Not doing a comparison here, because the input matrices are quite big. 
@@ -1374,7 +1402,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1399,7 +1427,7 @@ ENTRY e { ; CHECK-NEXT: ROOT ; CHECK-SAME: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": ; CHECK-NOT: pad ; CHECK-NOT: slice )"); @@ -1426,7 +1454,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/0, /*arel=*/0})); @@ -1453,7 +1481,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1530,7 +1558,7 @@ ENTRY e { ; CHECK: transpose ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1555,7 +1583,7 @@ ENTRY e { ; CHECK: transpose ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1580,7 +1608,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); @@ -1605,7 +1633,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1631,7 +1659,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion ; CHECK-SAME: kind=kCustom -; 
CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); @@ -1656,7 +1684,7 @@ ENTRY e { ; CHECK: f32[5,3,4]{2,1,0} bitcast ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); @@ -1734,6 +1762,9 @@ ENTRY e { } TEST_F(TritonGemmTestWithoutTritonGemmAny, SkipU8) { + if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) { + GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; + } const std::string hlo_text = R"( HloModule t @@ -1752,6 +1783,9 @@ ENTRY e { } TEST_F(TritonGemmTestWithoutTritonGemmAny, SkipF32F32) { + if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) { + GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; + } const std::string hlo_text = R"( HloModule t @@ -1804,8 +1838,7 @@ ENTRY entry { EXPECT_THAT( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), "test_fn", triton_dot_computation, - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, - /*minor=*/0}, + CudaAmpereOrRocm(), dev_info, config, &llvm_module, &EmitMatMul, mlir_context), tsl::testing::StatusIs( tsl::error::RESOURCE_EXHAUSTED, @@ -1818,8 +1851,7 @@ ENTRY entry { TF_CHECK_OK( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), "test_fn", triton_dot_computation, - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, - /*minor=*/0}, + CudaAmpereOrRocm(), dev_info, config, &llvm_module, &EmitMatMul, mlir_context) .status()); } @@ -1877,11 +1909,14 @@ ENTRY e { // multiple times and assign block sizes on success. 
R"( ; CHECK: f16[77,99,111]{2,1,0} transpose -; CHECK: block_m +; CHECK-PTX: block_m )"); } TEST_F(TritonGemmTest, SingleElementTileIsHandled) { + if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) { + GTEST_SKIP() << "Not using autotuner on ROCM yet."; + } MatchOptimizedHlo(R"( t { p0 = f32[2,7,3]{2,1,0} parameter(0) @@ -1935,13 +1970,16 @@ ENTRY e { MatchOptimizedHlo(hlo_text, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTestAny, DoAddConstantToScalarAndBroadcastThat) { + if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) { + GTEST_SKIP() << "Not using autotuner on ROCM yet."; + } const std::string hlo_text = R"( HloModule t @@ -1978,7 +2016,7 @@ ENTRY e { ; CHECK: ENTRY ; CHECK: %[[p0:.*]] = pred[5,5]{1,0} parameter(0) ; CHECK: fusion(%[[p0]], %[[p0]]), kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); @@ -1986,6 +2024,9 @@ ENTRY e { TEST_F(TritonGemmTestAny, DoNotFuseConcatenationOfSplitNonContractingDimension) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string hlo_text = R"( HloModule m @@ -2208,6 +2249,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, DoubleBroadcastOfScalarConstantIsHandled) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( ENTRY e { c = s32[] constant(1) @@ -2253,6 +2297,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, AlwaysFuseScalarConstantAtBroadcastInput) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( ENTRY e { p0 = bf16[2,3,3]{2,1,0} parameter(0) @@ -2306,6 +2353,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, FuseConcatenation) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( e { p0 = 
s8[153,1536] parameter(0) @@ -2348,7 +2398,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2371,7 +2421,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2394,7 +2444,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2417,7 +2467,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2441,7 +2491,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, @@ -2466,7 +2516,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, @@ -2491,7 +2541,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, @@ -2516,7 +2566,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, @@ -2684,6 +2734,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, ParameterAfterDotIsFused) { + if 
(SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -2713,6 +2766,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, OutputFusionExecutesCorrectly) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -2746,6 +2802,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SplitLHSOutputTransposeAloneIsNotFused) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -2770,6 +2829,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( ENTRY e { p0t = (s8[5,18,20,150]) parameter(0) @@ -3045,6 +3107,9 @@ ENTRY e { } TEST_F(CompareTest, BF16TransposedLHS) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const char* hlo_text_ref = R"( HloModule r @@ -3134,7 +3199,7 @@ ENTRY e { const auto result, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), "test_fn", triton_dot_computation, - GetCudaComputeCapability(), dev_info, triton_gemm_config, + GpuComputeComp(), dev_info, triton_gemm_config, &llvm_module, &EmitMatMul, mlir_context)); // The config is chosen so that the used memory size is slightly above the // 48 kB boundary of standard / optin shared memory so that any GPU that @@ -3250,6 +3315,9 @@ ENTRY e { } TEST_F(CompareTest, S8BF16) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const char* hlo_text_ref = R"( HloModule r @@ -3297,6 +3365,9 @@ ENTRY e { } TEST_F(CompareTest, SplitK) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string hlo_text_ref = R"( HloModule t, is_scheduled=true @@ -3370,6 +3441,9 @@ ENTRY e { } TEST_F(CompareTest, SplitKBatch) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloTextRef = R"( HloModule m, 
is_scheduled=true @@ -3432,6 +3506,9 @@ ENTRY e { } TEST_F(CompareTest, SplitKNontrivialBitcast) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloTextRef = R"( HloModule module, is_scheduled=true @@ -4006,6 +4083,9 @@ ENTRY e { } TEST_F(CompareTest, PredToBF16ConversionWorks) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloTextTest = R"( HloModule m, is_scheduled=true @@ -4126,6 +4206,9 @@ class TritonGemmContractionDims : public TritonGemmTest { }; TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -4148,6 +4231,9 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -4170,6 +4256,9 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -4193,6 +4282,9 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -4232,6 +4324,12 @@ class Triton6xBF16GemmTest : public TritonFilecheckTest { debug_options.set_xla_gpu_enable_split_k_autotuning(false); return debug_options; } + protected: + void SetUp() override { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } + } }; // In these tests, we depend on debug option flags for selecting the 6XBF16 @@ -4577,6 +4675,12 @@ class Triton3xBF16GemmTestWithFlag : public TritonFilecheckTest { debug_options.set_xla_gpu_enable_bf16_3way_gemm(true); return debug_options; } + protected: + void SetUp() override { + if (SkipBF16Tests()) { + 
GTEST_SKIP() << "BF16 not supported."; + } + } }; TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) {