Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
`tt.assert` takes a condition tensor and a message string.
If the condition is false, the message is printed, and the program is aborted.
}];
let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message);
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
}

Expand Down
8 changes: 8 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
}
}
llAssert(op, condition, adaptor.getMessage(), rewriter);
if (isa<RankedTensorType>(op.getCondition().getType())) {
// Add a barrier to avoid a race condition in case an assert is followed
// by an op that may trap if the assert condition is true. Since the
// tensors in those two operations may have different layouts, we need to
// make sure all the threads are done executing the assert before going to
// the next op.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, this highlights an interesting thing I noticed a while back: when there are assert statements, we sometimes do the same computation in multiple different layouts, because the assertion path doesn't have a data-flow path to the next layout anchor operation.

It's not really the end of the world, but an interesting quirk of the remove layout conversion code.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, yes, this is definitely a limitation, as we haven't tried very hard to make assert efficient.

barrier();
}
rewriter.eraseOp(op);
return success();
}
Expand Down
4 changes: 0 additions & 4 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1724,10 +1724,6 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
if not builder.options.debug:
return
cond_ty = cond.type
if not cond_ty.is_block():
cond_ty = tl.block_type(cond_ty.scalar, (1, ))
cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)


Expand Down
2 changes: 2 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1906,6 +1906,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32}
// CHECK: llvm.call @__assertfail
// CHECK: nvvm.barrier0
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) {
tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5)
Expand Down