diff --git a/xla/service/gpu/gemm_fusion.cc b/xla/service/gpu/gemm_fusion.cc
index 999989208d73c..40fca88f55415 100644
--- a/xla/service/gpu/gemm_fusion.cc
+++ b/xla/service/gpu/gemm_fusion.cc
@@ -703,9 +703,14 @@ class GemmFusionVisitor : public DfsHloRewriteVisitor {
     // If a GEMM requiring padding for cuBLAS is encountered here this
     // happened because earlier ShouldTritonHandleGEMM() accepted it and padding
     // was skipped. Accept it ignoring profitability checks.
-    if (!CublasRequiresPadding(*Cast<HloDotInstruction>(dot), gpu_version_) &&
-        !should_fuse) {
-      return absl::OkStatus();
+    // TODO(rocm): check ROCM padding requirements.
+    if(std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
+      if (!CublasRequiresPadding(
+              *Cast<HloDotInstruction>(dot),
+              std::get<se::CudaComputeCapability>(gpu_version_)) &&
+          !should_fuse) {
+        return OkStatus();
+      }
     }
 
     HloComputation* computation =
@@ -753,15 +758,31 @@ absl::StatusOr<bool> RunOnComputation(
 
 bool IsSupportedByTriton(
     PrecisionConfig::Algorithm algorithm,
-    const se::CudaComputeCapability& cuda_compute_capability) {
+    const se::GpuComputeCapability& gpu_version) {
+  auto cuda_compute_capability =
+      std::get_if<se::CudaComputeCapability>(&gpu_version);
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);
   switch (algorithm) {
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
+      if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
       return true;
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
+      if(rocm_compute_capability) {
+        return false;
+      }
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
-      return cuda_compute_capability.IsAtLeastAmpere();
+      if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
+      else if (cuda_compute_capability) {
+        return cuda_compute_capability->IsAtLeastAmpere();
+      }
+      return false;
     // TODO(b/326579472): Fix the support of this algorithm and maybe allow it
     // here.
@@ -779,8 +800,12 @@ FusionDecision CanTritonHandleGEMM(
     const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) {
   auto cuda_compute_capability =
       std::get_if<se::CudaComputeCapability>(&gpu_version);
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);
 
-  if (!cuda_compute_capability) return "Non CUDA device.";
+  if (!cuda_compute_capability && !rocm_compute_capability) {
+    return "Non CUDA or ROCM device.";
+  }
 
   if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) {
     if (!tsl::tensor_float_32_execution_enabled() ||
@@ -801,8 +826,14 @@ FusionDecision CanTritonHandleGEMM(
     case F32:
       return true;
     case BF16:
-      return cuda_compute_capability->IsAtLeast(
+      if(cuda_compute_capability) {
+        return cuda_compute_capability->IsAtLeast(
           stream_executor::CudaComputeCapability::AMPERE);
+      }
+      else if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
+      return false;
     default:
       return false;
   }
diff --git a/xla/service/gpu/triton_tiling_propagation.cc b/xla/service/gpu/triton_tiling_propagation.cc
index 123032dd955cf..91424638b37e2 100644
--- a/xla/service/gpu/triton_tiling_propagation.cc
+++ b/xla/service/gpu/triton_tiling_propagation.cc
@@ -1037,7 +1037,9 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
       std::move(std::get<DimOrdersAndReqs>(result_or_error));
   int fusion_level =
       hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level();
-  if (!std::get<se::CudaComputeCapability>(gpu_version)
+  //TODO(ROCm) Check fusion level for ROCm.
+  if (std::holds_alternative<se::CudaComputeCapability>(gpu_version)
+      && !std::get<se::CudaComputeCapability>(gpu_version)
           .IsAtLeast(se::CudaComputeCapability::AMPERE)) {
     fusion_level = std::min(fusion_level, 1);
   }