diff --git a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
index 508eb25f99d4..7a3c8ce27abd 100644
--- a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
@@ -50,7 +50,6 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
   // know about the op to split the block.
   void llAssert(Operation *op, Value condition, StringRef message,
                 ConversionPatternRewriter &rewriter) const {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
 
     auto ctx = rewriter.getContext();
     auto loc = op->getLoc();
@@ -87,6 +86,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
     rewriter.create<LLVM::BrOp>(loc, thenBlock);
     rewriter.setInsertionPointToEnd(prevBlock);
     rewriter.create<LLVM::CondBrOp>(loc, condition, ifBlock, thenBlock);
+    rewriter.setInsertionPointToStart(thenBlock);
   }
 
 protected:
diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py
index 405f04bfa8f5..8ea6212020ec 100644
--- a/python/test/unit/test_debug.py
+++ b/python/test/unit/test_debug.py
@@ -1,4 +1,3 @@
-import os
 import pytest
 import torch
 import triton.language as tl
@@ -10,8 +9,8 @@
 @pytest.mark.parametrize('env_var', [True, False])
 @pytest.mark.parametrize('jit_flag', [True, False])
 @pytest.mark.forked
-def test_device_assert(cond, opt_flag, env_var, jit_flag, device):
-    os.environ['TRITON_DEBUG'] = str(int(env_var))
+def test_device_assert(monkeypatch, cond, opt_flag, env_var, jit_flag, device):
+    monkeypatch.setenv("TRITON_DEBUG", str(int(env_var)))
     torch.zeros([1], dtype=torch.int32, device=device)
 
     @triton.jit(debug=jit_flag)
@@ -34,6 +33,20 @@ def _kernel(COND: tl.constexpr):
     getattr(torch, device).synchronize()
 
 
+def test_device_assert_barrier(monkeypatch, device):
+    monkeypatch.setenv("TRITON_DEBUG", "1")
+    tensor = torch.zeros([16], dtype=torch.int32, device=device)
+
+    @triton.jit
+    def _kernel(in_ptr0):
+        xindex = tl.arange(0, 8)
+        tmp0 = tl.load(in_ptr0 + xindex)
+        tl.device_assert(tmp0 < 1)
+
+    _kernel[(1, )](tensor)
+    getattr(torch, device).synchronize()
+
+
 @pytest.mark.parametrize("cond", [False, True])
 def test_static_assert(cond):
 