From 1b0c261492982faffd643228c8051d5125f59f38 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 11 Nov 2024 15:48:46 +0100 Subject: [PATCH 1/2] Fix barrier insertion after 'assert' op Signed-off-by: Anatoly Myachev --- lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp index 508eb25f99d4..7a3c8ce27abd 100644 --- a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -50,7 +50,6 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern { // know about the op to split the block. void llAssert(Operation *op, Value condition, StringRef message, ConversionPatternRewriter &rewriter) const { - ConversionPatternRewriter::InsertionGuard guard(rewriter); auto ctx = rewriter.getContext(); auto loc = op->getLoc(); @@ -87,6 +86,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern { rewriter.create(loc, thenBlock); rewriter.setInsertionPointToEnd(prevBlock); rewriter.create(loc, condition, ifBlock, thenBlock); + rewriter.setInsertionPointToStart(thenBlock); } protected: From 88094052e0f9521352652bcc5b10f1440f11927c Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 12 Nov 2024 13:46:17 +0100 Subject: [PATCH 2/2] add test Signed-off-by: Anatoly Myachev --- python/test/unit/test_debug.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py index 405f04bfa8f5..8ea6212020ec 100644 --- a/python/test/unit/test_debug.py +++ b/python/test/unit/test_debug.py @@ -1,4 +1,3 @@ -import os import pytest import torch import triton.language as tl @@ -10,8 +9,8 @@ @pytest.mark.parametrize('env_var', [True, False]) @pytest.mark.parametrize('jit_flag', [True, False]) @pytest.mark.forked -def test_device_assert(cond, opt_flag, env_var, jit_flag, device): - os.environ['TRITON_DEBUG'] = str(int(env_var)) +def test_device_assert(monkeypatch, cond, opt_flag, env_var, jit_flag, device): + monkeypatch.setenv("TRITON_DEBUG", str(int(env_var))) torch.zeros([1], dtype=torch.int32, device=device) @triton.jit(debug=jit_flag) @@ -34,6 +33,20 @@ def _kernel(COND: tl.constexpr): getattr(torch, device).synchronize() +def test_device_assert_barrier(monkeypatch, device): + monkeypatch.setenv("TRITON_DEBUG", "1") + tensor = torch.zeros([16], dtype=torch.int32, device=device) + + @triton.jit + def _kernel(in_ptr0): + xindex = tl.arange(0, 8) + tmp0 = tl.load(in_ptr0 + xindex) + tl.device_assert(tmp0 < 1) + + _kernel[(1, )](tensor) + getattr(torch, device).synchronize() + + @pytest.mark.parametrize("cond", [False, True]) def test_static_assert(cond):