Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/triton/common/assert_fail.patch
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
--- a/python/triton/experimental/gsan/src/GSanLibrary.cu
+++ b/python/triton/experimental/gsan/src/GSanLibrary.cu
@@ -6,11 +6,12 @@
@@ -6,11 +6,12 @@ extern "C" GSAN_DEVICE void __assertfail(const char *assertion,
const char *function,
__SIZE_TYPE__ charSize);

Expand Down
475 changes: 0 additions & 475 deletions third_party/triton/common/cherry-pick-f43dff6f.patch

This file was deleted.

4 changes: 2 additions & 2 deletions third_party/triton/common/construction_order.patch
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

--- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp
+++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp
@@ -65,8 +65,8 @@ struct TmemAccessDag {
@@ -66,8 +66,8 @@ struct TmemAccessDag {
SmallVector<std::unique_ptr<Node>> subDags;
Node(Operation *op, OpOperand *tokOperand,
std::optional<PartitionId> partitionId, Node *parent)
Expand All @@ -12,7 +12,7 @@

// ------------------------------------------------------------------------

@@ -390,7 +390,7 @@ struct TMEMAref {
@@ -391,7 +391,7 @@ struct TMEMAref {
enum Kind { PUT, GET };

TMEMAref(Value aref, Value origBuffer, Value replToken)
Expand Down
12 changes: 0 additions & 12 deletions third_party/triton/common/llvm_cl887809531.patch

This file was deleted.

11 changes: 0 additions & 11 deletions third_party/triton/common/llvm_cl893899241.patch

This file was deleted.

135 changes: 0 additions & 135 deletions third_party/triton/common/llvm_cl895542516.patch

This file was deleted.

10 changes: 5 additions & 5 deletions third_party/triton/common/llvm_cl900404532.patch
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -87,9 +87,9 @@
@@ -87,9 +87,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: store_with_cache_attr
tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
Expand All @@ -12,7 +12,7 @@
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "@$3 st.global.L1::evict_last.L2::cache_hint.b32 [ $1 + 0 ], { $0 }, $2;"
tt.store %a_ptr_init, %cst_0, %cst evictionPolicy = evict_last cacheModifier = ca : tensor<256x!tt.ptr<f32>, #blocked0>
tt.return
@@ -102,10 +102,10 @@
@@ -102,10 +102,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: load_with_l2_cache_hint
tt.func @load_with_l2_cache_hint(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
Expand All @@ -27,7 +27,7 @@
%1 = tt.load %a_ptr_init, %cst, %cst_0 evictionPolicy = evict_first : tensor<256x!tt.ptr<f32>, #blocked0>
tt.return
}
@@ -116,9 +116,9 @@
@@ -116,9 +116,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: store_with_l2_cache_hint
tt.func @store_with_l2_cache_hint(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
Expand All @@ -42,7 +42,7 @@

--- a/test/Conversion/tritoninstrument_to_llvm.mlir
+++ b/test/Conversion/tritoninstrument_to_llvm.mlir
@@ -37,7 +37,7 @@
@@ -37,7 +37,7 @@ tt.func private @experimental_buffer_descriptors_shared() {
#smem = #ttg.shared_memory
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} {
// CHECK-LABEL: @experimental_lock_acquire
Expand All @@ -65,7 +65,7 @@

--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp
@@ -144,7 +144,7 @@
@@ -144,7 +144,7 @@ std::string PTXBuilder::dump() const {
lines.push_back(exec->dump());
}

Expand Down
2 changes: 1 addition & 1 deletion third_party/triton/common/mixed_precision_fix.patch
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
--- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -434,6 +434,15 @@ LogicalResult Prefetcher::initialize() {
@@ -433,6 +433,15 @@ LogicalResult Prefetcher::initialize() {
while (op) {
if (!op->getResult(0).hasOneUse())
break;
Expand Down
2 changes: 1 addition & 1 deletion third_party/triton/common/mma_limit_pred.patch
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ fusions reported.
+
return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
isView(op) ||
isa<Fp4ToFpOp, LoadOp, DescriptorLoadOp, BroadcastOp, ConvertLayoutOp>(
isa<Fp4ToFpOp, LoadOp, DescriptorLoadLikeOpInterface, BroadcastOp,

11 changes: 4 additions & 7 deletions third_party/triton/common/no_accelerate_through_broadcast.patch
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Fix for b/405045790.

--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -272,10 +272,13 @@ static bool bwdFilter(Operation *op) {
@@ -272,9 +272,12 @@ static bool bwdFilter(Operation *op) {
return false;
}

Expand All @@ -11,11 +11,8 @@ Fix for b/405045790.
+ // small. This is just a heuristic to avoid a regression.
return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
isView(op) ||
- isa<Fp4ToFpOp, LoadOp, DescriptorLoadOp, BroadcastOp, ConvertLayoutOp>(
- op);
+ isa<Fp4ToFpOp, LoadOp, DescriptorLoadOp, /*BroadcastOp,*/
+ ConvertLayoutOp>(op);
- isa<Fp4ToFpOp, LoadOp, DescriptorLoadLikeOpInterface, BroadcastOp,
+ isa<Fp4ToFpOp, LoadOp, DescriptorLoadLikeOpInterface,/*BroadcastOp,*/
ConvertLayoutOp>(op);
}

// Finds the bitwidth with which the value x is loaded

4 changes: 0 additions & 4 deletions third_party/triton/common/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,9 @@ common_patch_list = [
"//third_party/triton:common/avoid-0fc-mid-ptwas-128.patch",
"//third_party/triton:common/wgmma_pipeline_fix.patch",
"//third_party/triton:common/nvdisasm_bin_path.patch",
"//third_party/triton:common/llvm_cl887809531.patch",
"//third_party/triton:common/llvm_cl893899241.patch",
"//third_party/triton:common/stage_and_cluster_map.patch",
"//third_party/triton:common/llvm_cl895542516.patch",
"//third_party/triton:common/assert_fail.patch",
"//third_party/triton:common/llvm_cl900404532.patch",
"//third_party/triton:common/cherry-pick-f43dff6f.patch",
"//third_party/triton:common/llvm_cl902211192.patch",
"//third_party/triton:common/silence_matchAndRewrite_failures.patch",
# Add new patches just above this line
Expand Down
6 changes: 3 additions & 3 deletions third_party/triton/common/speed_up_int4_unpacking.patch
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ an early stage on the XLA side with further optimizations from nVidia #24 and

--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -232,6 +232,19 @@ struct JoinOpConversion : public ConvertOpToLLVMPattern<JoinOp> {
@@ -224,6 +224,19 @@ struct JoinOpConversion : public ConvertOpToLLVMPattern<JoinOp> {
assert(lhsVals.size() == rhsVals.size());
SmallVector<Value> joinedVals;
joinedVals.resize(lhsVals.size() * 2);
Expand All @@ -27,7 +27,7 @@ an early stage on the XLA side with further optimizations from nVidia #24 and

--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -1116,6 +1116,39 @@ swizzleDotOperandLike(RankedTensorType type, ttg::CGAEncodingAttr cgaLayout) {
@@ -1096,6 +1096,39 @@ swizzleDotOperandLike(RankedTensorType type, ttg::CGAEncodingAttr cgaLayout) {
type.getElementTypeBitWidth(), false);
}

Expand Down Expand Up @@ -67,7 +67,7 @@ an early stage on the XLA side with further optimizations from nVidia #24 and
// If all the transitive uses of the given value have are used by a convert to
// the same dot operand encoding, return the shared encoding that needs to be
// used to be compatible with users' layouts. If there are incompatible shared
@@ -1150,14 +1183,21 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
@@ -1130,14 +1163,21 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
auto CGALayout = isa<ttg::LinearEncodingAttr>(dstTy.getEncoding())
? ttg::getCGALayout(srcTy.getEncoding())
: ttg::getCGALayout(dstTy.getEncoding());
Expand Down
6 changes: 3 additions & 3 deletions third_party/triton/common/verify_nvmma_encoding.patch
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ reshape outcomes.

// Provide custom directive handlers for declarative assemblyFormat.
// They must be visible before including the generated op classes.
@@ -584,7 +585,45 @@ LogicalResult MemDescReshapeOp::verify() {
@@ -583,7 +584,45 @@ LogicalResult MemDescReshapeOp::verify() {
return OpTrait::impl::verifyEquivalentMemDescType(expectedTy, dstType);
}

Expand Down Expand Up @@ -72,7 +72,7 @@ reshape outcomes.
Attribute srcEnc,
ArrayRef<int64_t> dstShape,
Attribute &dstEnc) {
@@ -605,6 +644,11 @@ static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
@@ -604,6 +643,11 @@ static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
ctx, mmaEncoding.getSwizzlingByteWidth(),
mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(),
mmaEncoding.getFp4Padded(), CGALayout);
Expand All @@ -84,7 +84,7 @@ reshape outcomes.
auto srcLL = toLinearLayout(srcShape, srcEnc);
auto dstLL = toLinearLayout(dstShape, candidateEncoding);
if (reshapeLayout(ctx, srcLL, dstShape) == dstLL) {
@@ -644,8 +688,8 @@ LogicalResult MemDescReshapeOp::inferReturnTypes(
@@ -643,8 +687,8 @@ LogicalResult MemDescReshapeOp::inferReturnTypes(

Attribute dstEncoding;
if (Attribute srcEnc = srcTy.getEncoding()) {
Expand Down
6 changes: 3 additions & 3 deletions third_party/triton/common/wgmma_pipeline_fix.patch
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ That likely makes some kernels slower but correctness > speed.

--- a/test/TritonGPU/loop-pipeline-hopper.mlir
+++ b/test/TritonGPU/loop-pipeline-hopper.mlir
@@ -786,9 +786,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
@@ -787,9 +787,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK: ttg.async_wait {{.*}} {num = 2 : i32}
// CHECK: ttg.local_load
// CHECK: ttng.warp_group_dot
Expand All @@ -61,7 +61,7 @@ That likely makes some kernels slower but correctness > speed.
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.async_commit_group
// CHECK: ttg.async_copy_global_to_local
@@ -850,9 +849,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
@@ -851,9 +850,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK: scf.for
// CHECK: ttg.async_wait {{.*}} {num = 2 : i32}
// CHECK: ttng.warp_group_dot
Expand All @@ -73,7 +73,7 @@ That likely makes some kernels slower but correctness > speed.
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.async_commit_group
@@ -940,10 +938,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
@@ -941,10 +939,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
// CHECK: ttg.local_load
// CHECK: ttg.local_load
// CHECK: ttng.warp_group_dot
Expand Down
Loading
Loading