3 changes: 2 additions & 1 deletion xla/service/gpu/BUILD
@@ -502,7 +502,8 @@ cc_library(
name = "ir_emitter_triton",
srcs = if_cuda_is_configured(["ir_emitter_triton.cc"]) + if_rocm_hipblaslt([
"ir_emitter_triton.cc",
]),
]) + if_cuda_is_configured(["ir_emitter_triton_cuda.cc"])
+ if_rocm_is_configured(["ir_emitter_triton_rocm.cc"]),
hdrs = if_gpu_is_configured(["ir_emitter_triton.h"]),
deps = [
":hlo_traversal",
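The BUILD change above keeps the shared ir_emitter_triton.cc for both backends and adds exactly one platform-specific source per configuration. A minimal sketch of what that split implies, assuming (this is not shown in the diff) that the two new files supply the backend hooks declared in ir_emitter_triton.h further down in this change:

```cpp
// Sketch only: the contents of ir_emitter_triton_cuda.cc and
// ir_emitter_triton_rocm.cc are not part of this diff. Each platform file is
// expected to provide its backend's definition of the hooks declared in the
// shared header (see the new declarations added to ir_emitter_triton.h below).

// Builds the backend-specific Triton pass pipeline (CUDA passes in the CUDA
// file, ROCm passes in the ROCm file).
absl::Status CreateTritonPipeline(mlir::OpPassManager& pm,
                                  const se::GpuComputeCapability& cc,
                                  const TritonGemmConfig& config,
                                  mt::nvidia_gpu::ClusterInfo& out_cluster_info);

// Resolves the backend's device math library (CUDA libdevice or the ROCm
// device libraries).
std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
                             const se::DeviceDescription& device_info);
```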
10 changes: 5 additions & 5 deletions xla/service/gpu/fusions/triton.cc
@@ -42,7 +42,7 @@ limitations under the License.
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "xla/service/gpu/ir_emitter_triton.h"
#else
#include "absl/status/status.h"
@@ -98,7 +98,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
llvm::IRBuilder builder(ir_emitter_context.llvm_module()->getContext());
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
VLOG(3) << fusion.ToString();
std::string suggested_kernel_name = std::string(fusion.name());
TF_ASSIGN_OR_RETURN(
@@ -137,7 +137,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
TF_ASSIGN_OR_RETURN(
triton_wrapper_result,
TritonWrapper(analysis, impl_fn_name, hlo_computation,
ir_emitter_context.cuda_compute_capability(),
ir_emitter_context.gpu_compute_capability(),
ir_emitter_context.gpu_device_info(), config,
ir_emitter_context.llvm_module(), &EmitSoftMax,
*ir_emitter_context.mlir_context()));
@@ -164,7 +164,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
TF_ASSIGN_OR_RETURN(
triton_wrapper_result,
TritonWrapper(analysis, impl_fn_name, hlo_computation,
ir_emitter_context.cuda_compute_capability(),
ir_emitter_context.gpu_compute_capability(),
ir_emitter_context.gpu_device_info(), config,
ir_emitter_context.llvm_module(), &EmitMatMul,
*ir_emitter_context.mlir_context()));
@@ -212,7 +212,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(

return result;
#else
return absl::UnimplementedError("Triton support requires CUDA");
return absl::UnimplementedError("Triton support requires CUDA or ROCm");
#endif
}
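For context on the cuda_compute_capability() to gpu_compute_capability() switch above: se::GpuComputeCapability is a std::variant over se::CudaComputeCapability and se::RocmComputeCapability, so backend-specific checks are gated on the active alternative instead of assuming CUDA. A minimal sketch of that dispatch pattern, mirroring the Ampere check that appears later in ir_emitter_triton.cc (the function name here is illustrative):

```cpp
#include <variant>

#include "absl/status/status.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Gate CUDA-only requirements on the CUDA alternative of the variant;
// ROCm devices simply skip the compute-capability check in this sketch.
absl::Status CheckTritonSupportsDevice(const se::GpuComputeCapability& cc) {
  if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
    const auto& cuda_cc = std::get<se::CudaComputeCapability>(cc);
    if (!cuda_cc.IsAtLeastAmpere()) {
      return absl::FailedPreconditionError(
          "Triton support is only enabled for Ampere GPUs and up.");
    }
  }
  return absl::OkStatus();
}

}  // namespace xla::gpu
```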

2 changes: 1 addition & 1 deletion xla/service/gpu/hlo_fusion_analysis.cc
@@ -202,7 +202,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind()
return EmitterFusionKind::kCustomFusion;
}

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (fusion_backend_config_.kind() == kTritonGemmFusionKind ||
fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) {
return EmitterFusionKind::kTriton;
135 changes: 33 additions & 102 deletions xla/service/gpu/ir_emitter_triton.cc
@@ -30,8 +30,6 @@ limitations under the License.
#include <variant>
#include <vector>

#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h"
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
@@ -54,6 +52,7 @@ limitations under the License.
#include "llvm/TargetParser/Triple.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
@@ -90,6 +89,7 @@ limitations under the License.
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "xla/autotuning.pb.h"
@@ -145,6 +145,7 @@ limitations under the License.
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"


namespace xla {
namespace gpu {

@@ -425,12 +426,16 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
if (dev_fn_id.ok()) {
return b.create<mt::ExternElementwiseOp>(
inputs[0].getType(), inputs, "libdevice", libdevice_path,
ObtainDeviceFunctionName(dev_fn_id.value(),
hlo.shape().element_type(),
llvm::Triple("nvptx64-unknown-unknown")),
/*pure=*/true);
llvm::Triple triple("nvptx64-unknown-unknown");
if (std::holds_alternative<se::RocmComputeCapability>
(device_info.gpu_compute_capability())) {
triple.setTriple("amdgcn-unknown-unknown");
}
return b.create<mt::ExternElementwiseOp>(
inputs[0].getType(), inputs, "libdevice", libdevice_path,
ObtainDeviceFunctionName(dev_fn_id.value(),
hlo.shape().element_type(), triple),
/*pure=*/true);
}
}
const bool is_integer =
@@ -934,81 +939,6 @@ absl::StatusOr<Value> EmitScope(
return values[instructions.back()];
}

// Create Triton pipeline.
//
// `out_cluster_info` must be kept alive at least until pm.run() is called.
// It should be read after that. We have to pass the cluster dims to
// LaunchDimensions. Triton currently uses this as an out-parameter to return
// the cluster dims determined based on `config.num_ctas` and a heuristic. There
// are some signs that show that this was intended to be used as an in-out
// parameter which would give a hint to Triton which cluster dims we prefer to
// use, but that's not the case currently.
absl::Status CreateTritonPipeline(
mlir::OpPassManager& pm, const se::CudaComputeCapability& cc,
const TritonGemmConfig& config,
mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
const int ccAsInt = cc.major * 10 + cc.minor;
const int threadsPerWarp = 32;

// Based on make_ttir() in
// @triton//:third_party/nvidia/backend/compiler.py
pm.addPass(mlir::createInlinerPass());
pm.addPass(mt::createRewriteTensorPointerPass());
pm.addPass(mt::createCombineOpsPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mt::createReorderBroadcastPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
pm.addPass(mlir::createSymbolDCEPass());

// Based on make_ttgir() in
// @triton//:third_party/nvidia/backend/compiler.py
pm.addPass(mt::createConvertTritonToTritonGPUPass(
config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
pm.addPass(mt::gpu::createCoalescePass());
pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info));
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mlir::createCSEPass());

pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
config.num_ctas, ccAsInt));

if (!cc.IsAtLeastHopper()) {
pm.addPass(mt::gpu::createPrefetchPass());
}

pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createReduceDataDuplicationPass());
pm.addPass(mt::gpu::createReorderInstructionsPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
if (cc.IsAtLeastHopper()) {
pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
}
pm.addPass(mlir::createCanonicalizerPass());

// Based on make_llir() in
// @triton//:third_party/nvidia/backend/compiler.py
pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
pm.addPass(mt::createConvertNVGPUToLLVMPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
// Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.

return absl::OkStatus();
}

// Extract additional attributes from an LLVM function that are not passed
// to the builder directly.
SmallVector<mlir::NamedAttribute> GetExtraAttrs(ml::LLVMFuncOp func) {
@@ -2666,6 +2596,7 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
mlir::registerBuiltinDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
mlir::registerROCDLDialectTranslation(registry);
module->getContext()->appendDialectRegistry(registry);

std::unique_ptr<llvm::Module> llvmModule =
@@ -2690,14 +2621,6 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
return llvmModule;
}

namespace {

std::string GetLibdevicePath(const HloModuleConfig& hlo_config) {
return nvptx::LibDevicePath(
hlo_config.debug_options().xla_gpu_cuda_data_dir());
}

} // namespace

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
@@ -2737,7 +2660,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
b.setInsertionPointToStart(&fn.front());

TF_RETURN_IF_ERROR(
ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config()),
ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config(),
device_info),
device_info, analysis, hlo_computation, fn, config));

b.create<mt::ReturnOp>(loc);
@@ -2759,13 +2683,16 @@

absl::StatusOr<TritonWrapperResult> TritonWrapper(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
const HloComputation* hlo_computation, const se::CudaComputeCapability& cc,
const HloComputation* hlo_computation, const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context) {
if (!cc.IsAtLeastAmpere()) {
return absl::FailedPreconditionError(
"Triton support is only enabled for Ampere GPUs and up.");
if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
auto ccCuda = std::get<se::CudaComputeCapability>(cc);
if (!ccCuda.IsAtLeastAmpere()) {
return absl::FailedPreconditionError(
"Triton support is only enabled for Ampere GPUs and up.");
}
}

auto debug_options = GetDebugOptionsFromFlags();
@@ -2793,13 +2720,16 @@ absl::StatusOr<TritonWrapperResult> TritonWrapper(
// TODO(b/325220878): Replace TritonGemmConfig with a more generic abstraction.
absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
mlir::ModuleOp triton_module, llvm::Module* llvm_module,
mlir::MLIRContext& mlir_context) {
if (!cc.IsAtLeastAmpere()) {
return absl::FailedPreconditionError(
"Triton support is only enabled for Ampere GPUs and up.");
if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
auto ccCuda = std::get<se::CudaComputeCapability>(cc);
if (!ccCuda.IsAtLeastAmpere()) {
return absl::FailedPreconditionError(
"Triton support is only enabled for Ampere GPUs and up.");
}
}

bool should_verify =
@@ -2873,7 +2803,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
triton_module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
.getInt();
VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B";
if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
if (std::holds_alternative<se::CudaComputeCapability>(cc)
&& shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
return absl::ResourceExhaustedError(absl::StrFormat(
"Shared memory size limit exceeded: requested %d, available: %d",
shared_mem_bytes, device_info.shared_memory_per_block_optin()));
@@ -2882,7 +2813,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<llvm::Module> ll_triton_module,
TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module,
GetLibdevicePath(hlo_config)));
GetLibdevicePath(hlo_config, device_info)));
VLogModule(5, *ll_triton_module);
if (should_verify) {
VerifyModule(*ll_triton_module);
25 changes: 23 additions & 2 deletions xla/service/gpu/ir_emitter_triton.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "xla/autotuning.pb.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/gpu/hlo_traversal.h"
@@ -40,10 +41,13 @@ limitations under the License.
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/launch_dim.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

namespace xla {
namespace gpu {

namespace mt = ::mlir::triton;

struct TritonWrapperResult {
int64_t shmem_bytes = 0;
std::optional<se::ClusterDim> cluster_dim;
@@ -89,7 +93,7 @@ using TritonIrEmitter = std::function<Status(
// MatMul and SoftMax above are some such IR generators.
absl::StatusOr<TritonWrapperResult> TritonWrapper(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
const HloComputation* hlo_computation, const se::CudaComputeCapability& cc,
const HloComputation* hlo_computation, const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context);
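An illustrative call site for the widened signature (the wrapper function and its local names are hypothetical; the argument order matches the declaration above, and EmitMatMul is one of the IR generators mentioned in the comment before it):

```cpp
// Hypothetical helper: emits a Triton GEMM kernel for either backend, since
// the compute capability is now passed as the se::GpuComputeCapability
// variant taken straight from the device description.
absl::StatusOr<TritonWrapperResult> EmitTritonGemmKernel(
    const TritonFusionAnalysis& analysis, const HloComputation* computation,
    const se::DeviceDescription& device_info, const TritonGemmConfig& config,
    llvm::Module* llvm_module, mlir::MLIRContext& mlir_context) {
  return TritonWrapper(analysis, /*fn_name=*/"triton_gemm_kernel", computation,
                       device_info.gpu_compute_capability(), device_info,
                       config, llvm_module, &EmitMatMul, mlir_context);
}
```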
@@ -105,11 +109,28 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
// Compiles a given Triton module to LLVM IR.
absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
mlir::ModuleOp triton_module, llvm::Module* llvm_module,
mlir::MLIRContext& mlir_context);

// Create Triton pipeline.
//
// `out_cluster_info` must be kept alive at least until pm.run() is called.
// It should be read after that. We have to pass the cluster dims to
// LaunchDimensions. Triton currently uses this as an out-parameter to return
// the cluster dims determined based on `config.num_ctas` and a heuristic. There
// are some signs that show that this was intended to be used as an in-out
// parameter which would give a hint to Triton which cluster dims we prefer to
// use, but that's not the case currently.
absl::Status CreateTritonPipeline(
mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
const TritonGemmConfig& config,
mt::nvidia_gpu::ClusterInfo& out_cluster_info);
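
A sketch of the intended usage implied by the comment above, in a hypothetical caller (the real callers are the per-backend TritonWrapper/CompileTritonToLLVM paths): the ClusterInfo out-parameter has to outlive pm.run() and is only meaningful afterwards.

```cpp
// Hypothetical driver showing the ClusterInfo lifetime contract described in
// the comment above: create it before building the pipeline, read it only
// after the pass manager has run.
absl::Status RunTritonPipeline(mlir::ModuleOp triton_module,
                               mlir::MLIRContext& mlir_context,
                               const se::GpuComputeCapability& cc,
                               const TritonGemmConfig& config) {
  mlir::PassManager pm(&mlir_context);
  mt::nvidia_gpu::ClusterInfo cluster_info;  // must stay alive across pm.run()
  TF_RETURN_IF_ERROR(CreateTritonPipeline(pm, cc, config, cluster_info));
  if (mlir::failed(pm.run(triton_module))) {
    return absl::InternalError("Triton pass pipeline failed.");
  }
  // Only now is cluster_info valid to read (e.g. to derive
  // TritonWrapperResult::cluster_dim / launch dimensions).
  return absl::OkStatus();
}
```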

std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
const se::DeviceDescription& device_info);
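
GetLibdevicePath now also takes the device description so the device math library can be resolved per backend. Whether the dispatch happens at build time (one definition per ir_emitter_triton_*.cc file) or at run time is not visible in this diff; the sketch below assumes one definition per backend source file, and the ROCm helper name is a placeholder, not an API shown here.

```cpp
// ir_emitter_triton_cuda.cc (sketch): can keep the logic that this change
// removes from ir_emitter_triton.cc, ignoring the extra parameter.
std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
                             const se::DeviceDescription& /*device_info*/) {
  return nvptx::LibDevicePath(
      hlo_config.debug_options().xla_gpu_cuda_data_dir());
}

// ir_emitter_triton_rocm.cc (sketch): would instead return the location of
// the ROCm device libraries; GetRocmDeviceLibraryPath is a hypothetical
// placeholder for whatever lookup the ROCm backend actually uses.
std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
                             const se::DeviceDescription& device_info) {
  return GetRocmDeviceLibraryPath(hlo_config, device_info);  // hypothetical
}
```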

} // namespace gpu
} // namespace xla
