diff --git a/BUILD b/BUILD new file mode 100644 index 000000000000..8321be1e6ba2 --- /dev/null +++ b/BUILD @@ -0,0 +1,974 @@ +# This package imports OpenAI's Triton (https://github.com/openai/triton). +# +# There are two versions of Triton in google3 at the moment. The older version +# can be found at //third_party/py/triton. This is the MLIR-based version close +# to head. We expect to transition users to this version in the following +# weeks. +# +# There is no SLA associated with this package and it may get broken by LLVM +# imports at any time. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = [":license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # # Add your project here if you need to depend on Triton's C++ sources. + # # Add a point of contact we can reach out to when needed in the comment. + # # + # # If you need to use the Python fronted, add your project to + # # google3/third_party/py/triton/BUILD instead. + # # + # # By adding your project here, you agree to the Triton SLA: go/triton-google3-sla + # "//third_party/py/jax:__subpackages__", # cjfj@ + # "//third_party/tensorflow/compiler/xla:__subpackages__", # bchetioui@ + # # Triton-internal visibility + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end + # TODO(csigg): fix and remove + features = [ + "-parse_headers", + "-use_header_modules", + ], +) + +# copybara:uncomment_begin +# license(name = "license") +# +# licenses(["notice"]) +# +# exports_files(["LICENSE"]) +# copybara:uncomment_end + +config_setting( + name = "compiler_is_msvc", + flag_values = { + # copybara:comment_begin + "@bazel_tools" + + # copybara:comment_end + "//tools/cpp:compiler": "msvc-cl", + }, +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. + strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/triton/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_type_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonAnalysis", + srcs = [ + "lib/Analysis/Alias.cpp", + "lib/Analysis/Allocation.cpp", + "lib/Analysis/Membar.cpp", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "lib/Analysis/Utility.cpp", + # "lib/Analysis/AxisInfo.cpp", + ], + hdrs = [ + "include/triton/Analysis/Alias.h", + "include/triton/Analysis/Allocation.h", + "include/triton/Analysis/Membar.h", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "include/triton/Analysis/AxisInfo.h", + # "include/triton/Analysis/Utility.h", + "include/triton/Conversion/MLIRTypes.h", + "include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h", + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", + ], + copts = _no_unused_variable, + includes = ["include"], + deps = [ + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + ]) + [ + "lib/Analysis/AxisInfo.cpp", # Avoid circular dependency. + "lib/Analysis/Utility.cpp", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/NVGPU/IR/*.h", + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + ]) + [ + "include/triton/Analysis/AxisInfo.h", # Avoid circular dependency. + "include/triton/Analysis/Utility.h", # Avoid circular dependency. + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = ["include"], + deps = [ + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvgpu_attr_inc_gen", + ":triton_nvgpu_dialect_inc_gen", + ":triton_nvgpu_ops_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_types_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + ":triton_type_interfaces_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + includes = ["include"], + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). +) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]), + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + includes = ["include"], + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + # copybara:uncomment "//third_party/py/triton/google:find_cuda", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + includes = ["include"], + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + includes = ["include"], + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], + includes = ["include"], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/RegisterTritonDialects.h", + "bin/triton-opt.cpp", + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//third_party/triton/test:TritonTestAnalysis", + "//third_party/triton/third_party/nvidia:NVGPUToLLVM", + "//third_party/triton/third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + includes = [ + ".", # because it includes "lib/Target/LLVMIR/LLVMPasses.h" + "include", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = [ + "bin/RegisterTritonDialects.h", + "bin/triton-reduce.cpp", + ], + includes = [ + "include", + ], + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:ROCDLDialect", + "//third_party/triton/test:TritonTestAnalysis", + "//third_party/triton/third_party/nvidia:NVGPUToLLVM", + "//third_party/triton/third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 461e863d0b58..fcb66f04b6bc 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -4017f04e310454ccced4c404a23f7698eec735ca +6f44bb7717897191be25aa01161831c67cdf5b84 diff --git a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 3d74d6d8cf20..da4ff8177b8e 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -68,6 +68,18 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), axisAnalysisPass(axisAnalysisPass) {} + // True if elements allocated to a thread are contiguous within the axis. This + // is not the case in MMA-like encodings wherea thread might have elements + // (0,0),(0,1) and (8,0),(8,1) for example. The problem with this is that the + // deduplication mechanism assumes that for example constancy=4 and + // elements/thread=4 that if a thread has all elements constant. + bool contiguouslyMapped(Attribute encoding) const { + if (auto slice = encoding.dyn_cast()) { + return contiguouslyMapped(slice.getParent()); + } + return encoding.isa(); + } + // Try to deduplicate the resultVals based on the // constancy properties of the result discovered by // the axis analysis pass. If possible, redundant @@ -93,8 +105,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { if (!encoding) // encoding not available return resultVals; - if (!encoding.dyn_cast() && - !encoding.dyn_cast()) { + if (!contiguouslyMapped(encoding)) { // TODO: constraining the ecndoing type here is necessary for avoiding // crashes in the getElemsPerThread call below happening in the // test_core::test_fp8_dot_acc diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 2c3480655a50..9f2e6fb2ec7f 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1250,6 +1250,8 @@ loadSharedToDistributed(Value dst, ArrayRef> dstIndices, srcTy.getEncoding().cast(); auto srcElemTy = srcTy.getElementType(); auto dstElemTy = dstTy.getElementType(); + LDBG("loadSharedToDistributed elemTy " << elemTy << " srcElemTy " << srcElemTy + << " dstElemTy " << dstElemTy); auto inOrd = triton::gpu::getOrder(srcSharedLayout); auto outOrd = triton::gpu::getOrder(dstDistributedLayout); unsigned outVec = inOrd == outOrd @@ -1281,7 +1283,7 @@ loadSharedToDistributed(Value dst, ArrayRef> dstIndices, auto valVec = load(wordTy, smemAddr); valVec.setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); for (unsigned v = 0; v < minVec; ++v) { - Value currVal = extract_element(dstElemTy, valVec, i32_val(v)); + Value currVal = extract_element(elemTy, valVec, i32_val(v)); outVals[i * minVec + v] = currVal; } } @@ -1407,6 +1409,8 @@ static Value packLLElements(Location loc, << v.value(); } if (v.value().getType() != elementTypes[v.index()]) { + LDBG("type " << type << " structType " << structType); + LDBG("value " << v.value()); emitError(loc) << "invalid element type in packLLEElements. Expected " << elementTypes[v.index()] << " but got " << v.value().getType(); diff --git a/include/triton/Dialect/Triton/IR/Utility.h b/include/triton/Dialect/Triton/IR/Utility.h index 6a7143f9188d..c91d623fab24 100644 --- a/include/triton/Dialect/Triton/IR/Utility.h +++ b/include/triton/Dialect/Triton/IR/Utility.h @@ -184,6 +184,27 @@ template bool isConsecutive(const VecT &vec) { return isConsecutive(ArrayRef(vec)); } +// LLVM's STLExtras.h provides a bunch of functions that work over ranges, but +// it's missing min/max_element until +// https://github.com/llvm/llvm-project/commit/fab2bb8b makes it into Triton. +// TODO(jlebar): Remove this once we have the LLVM helpers. +template auto min_element(R &&Range) { + return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range)); +} +template +auto min_element(R &&Range, Compare &&C) { + return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range), + std::forward(C)); +} +template auto max_element(R &&Range) { + return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range)); +} +template +auto max_element(R &&Range, Compare &&C) { + return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range), + std::forward(C)); +} + } // namespace triton } // namespace mlir diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index e7c3fcd71ba7..d3a68d94695a 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -527,7 +527,8 @@ bool supportMMA(triton::DotOp op, int version) { auto aElemTy = op.getA().getType().getElementType(); auto bElemTy = op.getB().getType().getElementType(); if (version == 3) { - if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) + // TODO(b/311157761): enable mma_v3 + if (!triton::tools::getBoolEnv("ENABLE_MMA_V3")) return false; auto retType = op.getType(); auto retShapePerCTA = getShapePerCTA(retType); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index b3218f5b78ab..74b82f8648fb 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -301,6 +301,26 @@ struct ExternElementwiseOpConversion } }; +template +struct ElementwiseOpConversion + : public ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0], + adaptor.getAttributes().getValue())}; + } +}; + struct ElementwiseInlineAsmOpConversion : public ConvertOpToLLVMPattern { using Base = ConvertOpToLLVMPattern; @@ -720,6 +740,60 @@ void mlir::triton::populateClampFOpToLLVMPattern( void mlir::triton::populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { +#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) + POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) + POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) + POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp) + POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) + POPULATE_UNARY_OP(math::FloorOp, math::FloorOp) + POPULATE_UNARY_OP(math::LogOp, math::LogOp) + POPULATE_UNARY_OP(math::Log2Op, math::Log2Op) + POPULATE_UNARY_OP(math::CosOp, math::CosOp) + POPULATE_UNARY_OP(math::SinOp, math::SinOp) + POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) + POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) + POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op) + POPULATE_UNARY_OP(math::ErfOp, math::ErfOp) + POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) + POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) + POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) +#undef POPULATE_UNARY_OP + +#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - + POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + + POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * + POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp) + POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp) + POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % + POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) + POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + // fmin (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp) + // fmax (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp) + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax + POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin + POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax +#undef POPULATE_BINARY_OP + + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 4721eadaa4dc..0ed8a45c22e3 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -23,7 +23,7 @@ using ttg::SliceEncodingAttr; // Get the highest version supported for the hardware and the dot. static int getMMAVersionSafe(int computeCapability, tt::DotOp op) { int baseVersion = 0; - if (computeCapability < 75) { + if (computeCapability < 80) { baseVersion = 1; } else if (computeCapability < 90) { baseVersion = 2; @@ -307,8 +307,10 @@ class BlockedToMMA : public mlir::RewritePattern { } else { // convert operands - int minBitwidth = - std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + // TODO(b/296812125): Fix minBitwidth issue upstream and uncomment. + // int minBitwidth = + // std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + int minBitwidth = 0; Type minType = IntegerType::get(ctx, minBitwidth); // convert A operand auto newAEncoding = ttg::DotOperandEncodingAttr::get( diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 8d5162a390f0..976967105a1c 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -7,7 +7,19 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include +#include +#include #include +#include + +inline bool isPipeliningEnabled() { + const char *s = std::getenv("ENABLE_PIPELINING"); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return (str == "on" || str == "true" || str == "1"); +} namespace { @@ -329,7 +341,9 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); - if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) + // TODO(b/291216607): Fix crashes and enable by default. + if (isPipeliningEnabled() && + triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 3ae111a7a79a..855287245de7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -6,6 +6,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" @@ -15,6 +16,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" @@ -47,6 +49,14 @@ struct PipelinedOpInfo { } // namespace +static bool isMMAv3Dot(Operation *op) { + auto dot = dyn_cast(op); + if (!dot) + return false; + auto enc = dot.getType().getEncoding().dyn_cast(); + return enc && enc.isHopper(); +} + // Replace the ForOp's yield with a new one with the given operands appended. static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { // Fix up the yield op. @@ -132,23 +142,11 @@ createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, loadOp.erase(); } -/// Create an async load equivalent to the given load. -static void -createAsyncLoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, - Value insertIdx, Value extractIdx, Value phase, - llvm::MapVector &opToInfo) { - createAsyncCopy(forOp, loadOp, alloc, insertIdx, extractIdx, opToInfo); -} - // If all the transitive uses of the given value have are used by a convert to // the same dot operand encoding, return true and set the shared encoding that // needs to be used to be compatible with users' layouts. -// -// TODO: Rename, because the name only tells us half the story: We check for all -// users having a dot encoding, but then we return a shared encoding, which is -// surprising given the name. static std::optional -allTransitiveUsesHaveDotEncoding(Value val) { +getSharedEncIfAllUsersAreDotEnc(Value val) { ttg::SharedEncodingAttr attr; for (Operation *user : val.getUsers()) { ttg::SharedEncodingAttr tempAttr; @@ -160,7 +158,7 @@ allTransitiveUsesHaveDotEncoding(Value val) { // use it if it is compatible with the other users. if (!tempAttr) tempAttr = memDesc.getEncoding().cast(); - if (!allTransitiveUsesHaveDotEncoding(user->getResult(0)).has_value()) + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) return std::nullopt; } else { if (!isa(user)) @@ -190,42 +188,6 @@ allTransitiveUsesHaveDotEncoding(Value val) { return attr; } -// TODO: This returns true and *sometimes* sets enc? -bool loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3, - ttg::SharedEncodingAttr &enc) { - if (loadOp.getResult().hasOneUse()) { - Operation *use = *loadOp.getResult().getUsers().begin(); - if (auto alloc = llvm::dyn_cast(use)) { - auto sharedEnc = - alloc.getType().getEncoding().cast(); - if (sharedEnc.getHasLeadingOffset()) { - // MMA V3 case. - auto newOrder = sharedEnc.getOrder(); - auto ty = loadOp.getType().cast(); - auto oldOrder = ttg::getOrder(ty.getEncoding()); - if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) { - // The operand of MMAv3 is in SharedEncoding and it's order should - // not be changed after FuseTranspositions Pass. So we only pipeline - // the load if the order of the loaded BlockedEncoding is the same - // as the order of the SharedEncoding it is converted to. - // TODO: remove this constraint once the LoadOp supports transpose - // fusion - hasMMAV3 = true; - return true; - } - } - } - } - - std::optional sharedEnc = - allTransitiveUsesHaveDotEncoding(loadOp.getResult()); - if (!sharedEnc.has_value()) { - return false; - } - enc = *sharedEnc; - return true; -} - static ttg::BlockedEncodingAttr getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { Value src = loadOp.getPtr(); @@ -245,9 +207,8 @@ getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { threadsPerWarp, ctaLayout); } -static ttg::SharedEncodingAttr getSharedEncoding(tt::LoadOp loadOp, - Operation *use, bool isMMAV3) { - +static std::optional +getSharedEncoding(tt::LoadOp loadOp, bool isMMAV3) { auto ty = loadOp.getType().cast(); auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); auto blockedOrder = ttg::getOrder(ty.getEncoding()); @@ -262,19 +223,36 @@ static ttg::SharedEncodingAttr getSharedEncoding(tt::LoadOp loadOp, } else { order = blockedOrder; } - if (isa(use)) { - assert(isMMAV3 && - "Load used by dot op should be either MMAv3 or have a " - "shared encoding already picked based on users' layouts."); + if (isMMAV3) { return ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); - } else { - assert(!isMMAV3 && "Load used by non-dot op should not be MMAv3."); - // Use non-swizzled layout for loads that do not feed into dot ops. - // TODO: This won't be optimal for 2D tensors. - return ttg::SharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, - ctaLayout); } + + // If the load is used by a LocalAllocOp, use the same encoding as the allocs. + // If the allocs don't all have the same encoding, bail. + if (llvm::any_of(loadOp->getUsers(), [&](Operation *user) { + return isa(user); + })) { + ttg::SharedEncodingAttr localAllocEnc; + for (auto user : loadOp->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = + localAlloc.getType().getEncoding().cast(); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) + return std::nullopt; + } + return localAllocEnc; + } + + // Use non-swizzled layout for loads that do not feed into dot ops. + // TODO: This won't be optimal for 2D tensors. + return ttg::SharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, + ctaLayout); } // Create a map from load ops to their distance to the nearest dot op and the @@ -350,11 +328,34 @@ loadOpsToDistanceAndUse(scf::ForOp forOp) { return loadOpToDistAndUse; } -/// Collect loads to pipeline. Returns true if loads are found to pipeline. +static bool loadIsMMAv3(tt::LoadOp loadOp) { + if (!loadOp->hasOneUse()) + return false; + auto alloc = dyn_cast(*loadOp->getUsers().begin()); + if (!alloc) + return false; + auto sharedEnc = + alloc.getType().getEncoding().cast(); + if (!sharedEnc.getHasLeadingOffset()) + return false; + + // MMA V3 case. + auto newOrder = sharedEnc.getOrder(); + auto ty = loadOp.getType().cast(); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + + // The operand of MMAv3 is in SharedEncoding and its order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + return oldOrder == newOrder; +} + +/// Collect ops to pipeline. Returns true if any ops are found to pipeline. static bool collectOpsToPipeline(scf::ForOp forOp, llvm::MapVector &opInfo, - int numStages, bool &hasMMAV3) { + int numStages) { ModuleOp moduleOp = forOp->getParentOfType(); tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); @@ -362,51 +363,60 @@ collectOpsToPipeline(scf::ForOp forOp, llvm::MapVector> loadOpToDistAndUse = loadOpsToDistanceAndUse(forOp); LLVM_DEBUG({ - DBGS() << "Found " << loadOpToDistAndUse.size() << " loads to pipeline:\n"; + LDBG("Found " << loadOpToDistAndUse.size() << " loads to pipeline:"); for (const auto &[k, v] : loadOpToDistAndUse) { - DBGS() << " " << *k << " distance=" << v.first << " use=" << *v.second - << "\n"; + LDBG(" - distance: " << v.first); + LDBG(" op to pipeline: " << *k); + LDBG(" use: " << *v.second); } }); if (loadOpToDistAndUse.empty()) return false; - int maxDistance = -1; - for (auto &[op, distAndUse] : loadOpToDistAndUse) { - if (distAndUse.first > maxDistance) { - maxDistance = distAndUse.first; - } - } - assert(maxDistance >= 0); - // Start by initializing PipelinedOpInfo for users of the loads. for (auto &[loadOp, distAndUse] : loadOpToDistAndUse) opInfo[distAndUse.second] = PipelinedOpInfo(); + int maxDistance = *triton::max_element( + llvm::make_first_range(llvm::make_second_range(loadOpToDistAndUse))); unsigned stagesBetweenLoads = ceil(numStages - 2, maxDistance + 1); // Then consider the load ops that feed into the dot ops or are used by other // loads. for (auto &[loadOp, distAndUse] : loadOpToDistAndUse) { PipelinedOpInfo loadInfo; - bool loadIsMMAV3 = false; if (isa(distAndUse.second)) { - ttg::SharedEncodingAttr sharedEnc; - bool isLoadDotOperand = loadDotOperand(loadOp, loadIsMMAV3, sharedEnc); - hasMMAV3 |= loadIsMMAV3; - if (!isLoadDotOperand) + if (loadIsMMAv3(loadOp)) { + loadInfo.loadIsMMAV3 = true; + loadInfo.sharedEncoding = + getSharedEncoding(loadOp, /*loadIsMMAv3=*/true).value_or(nullptr); + } else { + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(loadOp.getResult()) + .value_or(nullptr); + } + // TODO(jlebar): Remove this if statement, which effectively rolls back + // back https://github.com/openai/triton/pull/3415, once internal bugs are + // fixed. + if (!loadInfo.sharedEncoding) continue; - loadInfo.sharedEncoding = sharedEnc; - } else { - loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); } - // If we haven't already assigned a layout do it now. - if (!loadInfo.sharedEncoding) + + // If we still don't have a shared encoding, try a "generic" shared + // encoding. + if (!loadInfo.sharedEncoding && !isMMAv3Dot(distAndUse.second)) { loadInfo.sharedEncoding = - getSharedEncoding(loadOp, distAndUse.second, loadIsMMAV3); - loadInfo.loadIsMMAV3 = loadIsMMAV3; - int stage = (maxDistance - distAndUse.first) * stagesBetweenLoads; - loadInfo.stage = stage; + getSharedEncoding(loadOp, /*isMMAV3=*/loadInfo.loadIsMMAV3) + .value_or(nullptr); + loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); + } + + // If that still didn't work, bail on pipelining this load. + if (!loadInfo.sharedEncoding) { + continue; + } + + loadInfo.stage = (maxDistance - distAndUse.first) * stagesBetweenLoads; loadInfo.use = distAndUse.second; opInfo[loadOp] = loadInfo; } @@ -444,9 +454,9 @@ static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, // Convert load ops into their asyn version and apply multi-buffering based on // the required number of buffers. static SmallVector -createAsynOps(scf::ForOp &forOp, - llvm::MapVector &opToInfo, - int numBuffers, bool hasMMAV3) { +createAsyncOps(scf::ForOp &forOp, + llvm::MapVector &opToInfo, + int numBuffers, bool hasMMAV3) { struct AsyncLoad { AsyncLoad(tt::LoadOp loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {} tt::LoadOp loadOp; @@ -503,8 +513,8 @@ createAsynOps(scf::ForOp &forOp, extractIdx = builder.create(loc, cndExt, extractIdx, zero); for (AsyncLoad &asyncLoad : asyncLoads) { - createAsyncLoad(forOp, asyncLoad.loadOp, asyncLoad.alloc, insertIdx, - extractIdx, phase, opToInfo); + createAsyncCopy(forOp, asyncLoad.loadOp, asyncLoad.alloc, insertIdx, + extractIdx, opToInfo); } SmallVector newYieldOperands = {insertIdx, extractIdx}; // Patch the yield with the updated counters. @@ -711,10 +721,12 @@ bool mlir::triton::preProcessLoopAndGetSchedule( // 1. First collect "interesting" operations with a stage where to schedule // them. This gives a coarse scheduling for the loop. llvm::MapVector opToInfo; - bool hasMMAV3 = false; - if (!collectOpsToPipeline(forOp, opToInfo, numStages, hasMMAV3)) + if (!collectOpsToPipeline(forOp, opToInfo, numStages)) return false; + bool hasMMAV3 = + llvm::any_of(opToInfo, [](auto &kv) { return kv.second.loadIsMMAV3; }); + // Calculate the number of buffers needed for each load. // TODO pawel: we could do more fine-grained allocation here and // allocate only the number of buffers that specific loads need. @@ -742,7 +754,7 @@ bool mlir::triton::preProcessLoopAndGetSchedule( // 2. Convert the loads into async loads and create the allocs. SmallVector allocs = - createAsynOps(forOp, opToInfo, maxNumBuffers, hasMMAV3); + createAsyncOps(forOp, opToInfo, maxNumBuffers, hasMMAV3); // 3. Create the final schedule for the kernel loop. This will dictate the // stages and order of operations to the pipeline expander. @@ -1157,9 +1169,7 @@ void triton::asyncLaunchDots(scf::ForOp forOp) { // "properly async", or sometimes just "async". IRRewriter builder(forOp.getContext()); for (auto dotOp : llvm::to_vector(forOp.getBody()->getOps())) { - auto resEnc = - dotOp.getType().getEncoding().dyn_cast(); - if (resEnc && resEnc.isHopper()) { + if (isMMAv3Dot(dotOp)) { builder.setInsertionPoint(dotOp); builder.replaceOpWithNewOp( dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(), diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp index 957c18830156..812c81fd600b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp @@ -55,8 +55,14 @@ createSchedule(scf::ForOp forOp, int numStages) { static void hoistAllocAndConst(scf::ForOp forOp) { SmallVector toHoist; for (Operation &op : forOp.getBody()->without_terminator()) { - if (isa(op)) + if (auto allocOp = dyn_cast(op)) { + // We hoist the allocOp only if it is created by the inner loop + // pipelining. + if (!allocOp.getInit()) + toHoist.push_back(&op); + } else if (isa(op)) { toHoist.push_back(&op); + } } for (Operation *op : toHoist) { op->moveBefore(forOp); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index 2187dc54e57d..55834b94007c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -40,7 +40,8 @@ struct FenceInsertionPass // Only insert fences for compute capability 9.0 if (computeCapability < 90) return; - if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) + // TODO(b/311157761): enable mma_v3 + if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 000000000000..9bc774e50099 --- /dev/null +++ b/python/BUILD @@ -0,0 +1,91 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "//third_party/triton/python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["//third_party/triton/third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "//third_party/triton/third_party/nvidia:triton_nvidia", + ], +) + +pybind_extension( + name = "triton_launcher", + srcs = [ + "triton/compiler/triton_launcher.c", + ], + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + deps = [ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/src/ir.cc b/python/src/ir.cc index 65bddee98cdf..f2865b0b6d54 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -206,7 +206,8 @@ void init_triton_ir(py::module &&m) { }); py::class_(m, "type", py::module_local()) - .def("is_integer", &Type::isInteger) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) .def("is_fp16", &Type::isF16) .def("__str__", [](Type &self) { std::string str; diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD new file mode 100644 index 000000000000..b6a3534474d1 --- /dev/null +++ b/python/test/regression/BUILD @@ -0,0 +1,27 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["test_*.py"], + + #TODO(b/321005767): fix failing test + exclude = [ + "test_performance.py", + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD new file mode 100644 index 000000000000..07ebc8809c33 --- /dev/null +++ b/python/test/unit/BUILD @@ -0,0 +1,107 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "hopper", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["hopper/**/test_*.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "language", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["language/**/test_*.py"], + exclude = [ + "language/test_subprocess.py", # TODO(b/320224484): fix failing test + "language/test_reproducer.py", # this is not an actual test, but a tool for running reproducers + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "operators", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + [ + "operators/**/test_*.py", + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "runtime", + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = + glob( + include = ["runtime/**/test_*.py"], + exclude = [ + "runtime/test_launch.py", #TODO(b/320226169): fix failing tests + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "tools", + size = "large", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = + glob( + include = ["tools/**/test_*.py"], + exclude = [ + "tools/test_aot.py", # TODO(b/320224484): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f6e62c2d62a8..7bf9f6376b0d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2587,14 +2587,11 @@ def test_store_op(M, src_layout, device): layouts = [ - BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + # TODO (lixun): Add MfmaLayout + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) -] if not is_hip() else [ - # TODO (lixun): Add MfmaLayout - BlockedLayout([1, 4], [1, 64], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [1, 64], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]) ] @@ -3256,7 +3253,7 @@ def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.co @pytest.mark.parametrize('in_dtype', ['float32']) -def test_dot_mulbroadcastred(in_dtype, device): +def test_dot_mulbroadcasted(in_dtype, device): if is_cuda(): capability = torch.cuda.get_device_capability() if capability[0] < 8: @@ -3300,10 +3297,9 @@ def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.co if not is_cuda(): return assert "tt.dot" in h.asm['ttir'] - # when using MMAv3, we will not pipeline the load op for Y - # as the loaded value is in rowmajor. But MMAv3 requires it's second - # operand is in colmajor because transpose is not supported for MMAv3 - # with float32 input. + # When using MMAv3, we will not pipeline the load op for Y, as the loaded + # value is in rowmajor. But MMAv3 requires its second operand is in colmajor + # because transpose is not supported for MMAv3 with float32 input. if capability[0] >= 9: assert re.search(r"triton_gpu.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None else: diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1..8a5dba6c4b56 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file +../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index fbf65d9e908f..5d8fb01b1191 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index aeda2dd680f4..b5009503f365 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -102,7 +102,9 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] - sorted_constants = [(k, v) for k, v in sorted(self.constants.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. + sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" return hashlib.sha256(key.encode("utf-8")).hexdigest() diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index ac8e5af53201..67c9bd7fec68 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1406,19 +1406,14 @@ def wrap_tensor(x, scalar_ty, ret_shape): def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: if axis is None: - new_inputs = [] - for i in range(len(inputs)): - new_shape = [inputs[i].numel.value] - new_inputs.append(view(inputs[i], new_shape, builder)) - inputs = tuple(new_inputs) + inputs = tuple(view(t, [t.numel.value], builder) for t in inputs) axis = 0 # get result shape shape = inputs[0].type.shape rank = len(shape) assert axis < rank, f"reduction axis must be < inputs rank ({rank})" ret_shape = [s for i, s in enumerate(shape) if i != axis] - for t in inputs: - assert t.type.shape == shape, "all reduction inputs must have the same shape" + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" reduce_op = builder.create_reduce([t.handle for t in inputs], axis) region_builder_fn(reduce_op) @@ -1502,9 +1497,7 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil if len(prefix) > 2 and not prefix.startswith(" "): prefix = " " + prefix - new_args = [] - for arg in args: - new_args.append(arg.handle) + new_args = [arg.handle for arg in args] return tl.tensor(builder.create_print(prefix, hex, new_args), tl.void) diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 000000000000..08fcae5cd788 --- /dev/null +++ b/test/BUILD @@ -0,0 +1,61 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:gce"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# "@llvm-project//llvm:FileCheck", +# "//:triton-llvm-opt", +# "//:triton-opt", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# # TODO(b/283035396): broken by cl536931041.patch +# "TritonGPU/dot-operands.mlir", +# "TritonGPU/optimize_epilogue.mlir", # AMD-specific +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonAnalysis", +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + ], +) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index a65bd652514c..058e4681edd3 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1561,3 +1561,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: test_local_load_bf16 + // CHECK: llvm.extractelement {{.*}} : vector<8xi16> + tt.func public @test_local_load_bf16() { + %c0_i32 = arith.constant 0 : i32 + %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, mutable> + %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x1x2048xbf16, #shared, mutable> -> !tt.memdesc<1x2048xbf16, #shared, mutable> + %39 = triton_gpu.local_load %22 : !tt.memdesc<1x2048xbf16, #shared, mutable> -> tensor<1x2048xbf16, #blocked> + %40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to tensor<1x2048xf32, #blocked> + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index f514cbc950d9..1cf141da52de 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --decompose-unsupported-conversions --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --decompose-unsupported-conversions --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> #shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index ed4f0d3088ca..a2d8886807b3 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -1,5 +1,5 @@ -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=80 | FileCheck %s --check-prefix=CHECK-80 +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s +// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=80 | FILECHECK_OPTS= FileCheck %s --check-prefix=CHECK-80 // CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> // CHECK: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index 3b2fe3633937..bcd56a34e5d1 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index dc52a7fb2bfb..34b18db00206 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -722,7 +722,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg3: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst_0 = arith.constant dense<320> : tensor<32x1xi32, #blocked> %c0_i32 = arith.constant 0 : i32 @@ -784,7 +784,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg3: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg4: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} loc(unknown)) attributes {noinline = false} { + tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> %c64_i32 = arith.constant 64 : i32 %c0_i32 = arith.constant 0 : i32 @@ -1024,3 +1024,149 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c tt.return } } + + +// ----- + +// CHECK-LABEL: @nested_loops +// CHECK: tt.addptr %{{.*}}, {{.*}} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: %[[BUFFER_1:.*]] = triton_gpu.local_alloc +// CHECK: %[[SUBVIEW_1:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_1:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] +// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_1]] +// CHECK: %[[SUBVIEW_2:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_2:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] +// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_2]] +// CHECK: scf.for +// CHECK: %[[LOAD_1:.*]] = tt.load %[[NEXT_BUFFER_1]] +// CHECK: %[[BUFFER_2:.*]] = triton_gpu.local_alloc %[[LOAD_1]] +// CHECK: %[[TRANS:.*]] = tt.trans %[[BUFFER_2]] +// CHECK: %[[LOCAL_LOAD_1:.*]] = triton_gpu.local_load %[[TRANS]] +// CHECK: triton_gpu.async_wait +// CHECK: triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: scf.for +// CHECK: %[[LOCAL_LOAD_2:.*]] = triton_gpu.local_load +// CHECK: %[[DOT:.*]] = tt.dot %[[LOCAL_LOAD_2]], %[[LOCAL_LOAD_1]] +// CHECK: %[[CONVERT_LAYOUT_3:.*]] = triton_gpu.convert_layout %[[DOT]] +// CHECK: %[[SUBVIEW_4:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_3:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_4]] +// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_3]] +// CHECK: triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[SUBVIEW_6:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_4:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_6]] mask +// CHECK: %[[COMMIT_1:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_4]] +// CHECK: %[[SUBVIEW_7:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_5:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_7]] mask +// CHECK: %[[COMMIT_2:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_5]] +// CHECK: scf.yield %[[COMMIT_1]], %[[COMMIT_2]] +// CHECK: triton_gpu.local_dealloc %[[BUFFER_1]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<16> : tensor<16x1xi32, #blocked> + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %2 = arith.muli %1, %cst_0 : tensor<16x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<16x1x!tt.ptr, #blocked>, tensor<16x1xi32, #blocked> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %7 = tt.broadcast %4 : tensor<16x1x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> + scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf32, #blocked> + %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared> + %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared> -> !tt.memdesc<16x16xf32, #shared1> + %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %14 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf32, #blocked> + %15 = triton_gpu.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %16 = tt.dot %15, %13, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma> + %17 = triton_gpu.convert_layout %16 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked> + tt.store %9, %17 {cache = 1 : i32, evict = 1 : i32} : tensor<16x16xf32, #blocked> + } + } + tt.return + } +} + +// ----- + + // CHECK-LABEL: @int4_matmul_ampere +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [16, 1, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [1, 8, 1], order = [2, 0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [16, 2, 1], threadsPerWarp = [4, 1, 8], warpsPerCTA = [1, 1, 8], order = [1, 0, 2]}> +#blocked5 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @int4_matmul_ampere( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) -> tensor<16x256xf32, #mma> attributes {noinline = false} { + %cst = arith.constant dense<64> : tensor<64x256xi32, #blocked> + %cst_0 = arith.constant dense<128> : tensor<16x128xi32, #blocked1> + %c256_i32 = arith.constant 256 : i32 + %c16_i32 = arith.constant 16 : i32 + %c128_i32 = arith.constant 128 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x128xf16, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c255_i32 = arith.constant 255 : i32 + %c15_i32 = arith.constant 15 : i32 + %cst_2 = arith.constant dense<4> : tensor<64x256xi8, #blocked> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<16x256xf32, #mma> + + %35 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %36 = tt.expand_dims %35 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %38 = tt.broadcast %36 : tensor<1x128xi32, #blocked1> -> tensor<16x128xi32, #blocked1> + %40 = tt.splat %arg0 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked1> + %41 = tt.addptr %40, %38 : tensor<16x128x!tt.ptr, #blocked1>, tensor<16x128xi32, #blocked1> + + %42 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %43 = tt.expand_dims %42 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %47 = tt.broadcast %43 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %50 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + + // Check that both loads in the loop are pipelined. + // TODO(jlebar): https://github.com/openai/triton/pull/3472 disables the + // relevant optimization. Once we've reenabled it, we can uncomment this test. + // CHECK: scf.for + // COM: CHECK-NOT: tt.load + // CHECK: triton_gpu.async_copy_global_to_local + // COM: CHECK-NOT: tt.load + // COM: CHECK: triton_gpu.async_copy_global_to_local + // COM: CHECK-NOT: tt.load + // CHECK: scf.yield + %54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>) : i32 { + %78 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x128xf16, #blocked1> + %79 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x256xi8, #blocked> + %80 = arith.shli %79, %cst_2 : tensor<64x256xi8, #blocked> + %81 = arith.shrsi %80, %cst_2 : tensor<64x256xi8, #blocked> + %82 = arith.shrsi %79, %cst_2 : tensor<64x256xi8, #blocked> + %83 = arith.sitofp %81 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked> + %84 = arith.sitofp %82 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked> + %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3> + %86 = tt.trans %85 {order = array} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4> + %87 = tt.reshape %86 {allow_reorder = false} : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> + %88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %90 = tt.dot %88, %89, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> + %91 = tt.addptr %arg11, %cst_0 : tensor<16x128x!tt.ptr, #blocked1>, tensor<16x128xi32, #blocked1> + %92 = tt.addptr %arg12, %cst : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + scf.yield %90, %91, %92 : tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked> + } + tt.return %54#0 : tensor<16x256xf32, #mma> + } +} diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD new file mode 100644 index 000000000000..624b43296367 --- /dev/null +++ b/third_party/amd/BUILD @@ -0,0 +1,47 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +cc_library( + name = "TritonAMDGPUToLLVM", + srcs = glob([ + "third_party/amd/lib/TritonAMDGPUToLLVM/*.cpp", + "third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/*.cpp", + "third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/*.cpp", + ]), + hdrs = glob([ + "third_party/amd/include/TritonAMDGPUToLLVM/*.h", + "third_party/amd/include/*.h", + "third_party/amd/lib/*.h", + "third_party/amd/lib/TritonAMDGPUToLLVM/*.h", + ]), + copts = _no_unused_variable, + includes = ["include", + "third_party/amd/include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":TritonGPUToLLVM", + ":triton_conversion_amdgpu_to_llvm_passes_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPUToLLVM", + ], + "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "third_party/amd/include/TritonAMDGPUToLLVM/Passes.td", + deps = ["td_files"], +) diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 000000000000..1ebf18326c3f --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,162 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + # copybara:uncomment_begin + # visibility = [ + # "//third_party/triton/python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "//third_party/triton/python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + ":TritonNVIDIAGPUToLLVM", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob([ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + ], + deps = [ + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:triton_gpu_attr_inc_gen", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNVIDIAGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 000000000000..9bd0230c572e --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,28 @@ +load("//third_party/bazel_rules/rules_python/python:py_extension.bzl", "py_extension") + +package( + default_applicable_licenses = ["//:license"], +) + +py_extension( + name = "cuda_utils", + srcs = ["driver.c"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "@local_config_cuda//cuda:cuda_headers", + "//third_party/python_runtime:headers", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["*.py"], + ), + visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index 591fbe48de44..e5eff624e732 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -396,26 +396,6 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, return converter; } -template -struct ElementwiseOpConversion - : public ElementwiseOpConversionBase< - SourceOp, ElementwiseOpConversion> { - using Base = - ElementwiseOpConversionBase>; - using Base::Base; - using OpAdaptor = typename Base::OpAdaptor; - - // An interface to support variant DestOp builder. - SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - return {rewriter.create(loc, elemTy, operands[0], - adaptor.getAttributes().getValue())}; - } -}; - // Attempts to use vectorized conversions via inline PTX when possible. struct FpToFpOpConversion : public ElementwiseOpConversionBase { @@ -1061,60 +1041,6 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns( const TargetInfo &targetInfo, PatternBenefit benefit) { using namespace mlir::triton::gpu; -#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ - patterns.add>( \ - typeConverter, axisInfoAnalysis, benefit); - POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - - POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + - POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * - POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp) - POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp) - POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % - POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) - POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) - POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & - POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | - POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ - POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << - POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> - POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> - POPULATE_BINARY_OP( - arith::MinNumFOp, - LLVM::MinNumOp) // fmin (return non-NaN if either op is non-NaN) - POPULATE_BINARY_OP( - arith::MaxNumFOp, - LLVM::MaxNumOp) // fmax (return non-NaN if either op is non-NaN) - POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin - POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax - POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin - POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax -#undef POPULATE_BINARY_OP - -#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ - patterns.add>( \ - typeConverter, axisInfoAnalysis, benefit); - POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) - POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) - POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) - POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp) - POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) - POPULATE_UNARY_OP(math::FloorOp, math::FloorOp) - POPULATE_UNARY_OP(math::LogOp, math::LogOp) - POPULATE_UNARY_OP(math::Log2Op, math::Log2Op) - POPULATE_UNARY_OP(math::CosOp, math::CosOp) - POPULATE_UNARY_OP(math::SinOp, math::SinOp) - POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) - POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) - POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op) - POPULATE_UNARY_OP(math::ErfOp, math::ErfOp) - POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) - POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) - POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) -#undef POPULATE_UNARY_OP - - patterns.add>( - typeConverter, axisInfoAnalysis, benefit); - patterns.add>( typeConverter, axisInfoAnalysis, "__nv_fsqrt_rn", benefit); patterns.add>( diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 269848cf26f1..78d83b719f69 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -1,4 +1,4 @@ -#include "NVGPUToLLVM/Passes.h" +#include "NVGPUToLLVM/Passes.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/unittest/BUILD b/unittest/BUILD new file mode 100644 index 000000000000..33c8c504d1e7 --- /dev/null +++ b/unittest/BUILD @@ -0,0 +1,102 @@ +load("//tools/build_defs/build_test:build_test.bzl", "build_test") + +package( + default_applicable_licenses = ["//:license"], + default_compatible_with = ["//buildenv/target:gce"], + default_visibility = ["//:__subpackages__"], +) + +cc_test( + name = "AnalysisTest", + srcs = glob(["Analysis/*.cpp"]), + deps = [ + "//testing/base/public:gunit_main", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTest", + srcs = glob([ + "Dialect/**/*.cpp", + ]), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//mlir:AsmParser", + "//:TritonDialects", + ], +) + +cc_test( + name = "ConversionTest", + srcs = glob( + [ + "Conversion/**/*.cpp", + "Conversion/**/*.h", + ], + exclude = [ + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "//:TritonDialects", + "//:TritonNvidiaGPUTransforms", + "//third_party/triton/third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +cc_test( + name = "EmitIndicesTest", + srcs = [ + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-private-header", + ], + }), + includes = [ + "Conversion/TritonGPUToLLVM", + ], + # We want this to be buildable to update LLVM, but it doesn't pass and never has, even in OSS: + # https://github.com/openai/triton/blob/ded624282e67e5f58db332380e6ff088f276d534/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp#L677 + tags = [ + "manual", + "notap", + ], + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "//third_party/triton/third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +build_test( + name = "build_test", + allow_empty_target = False, + targets = [ + ":ConversionTest", + ":AnalysisTest", + ":DialectTest", + ":EmitIndicesTest", + ], +)