
[RemoveLayoutConversions] Fix reduce failed infer type error#377

Merged
zhanglx13 merged 3 commits into ROCm:triton-mlir from binarman:redice_result_type_fix2
Nov 1, 2023
Conversation

@binarman

@binarman binarman commented Oct 26, 2023

This PR fixes the layout propagation algorithm in the RemoveLayoutConversions pass. In some cases during the rewriteSlice process, a reduce operation with multiple outputs had only one of its output layouts rewritten, which breaks the assumption that both outputs have the same layout.

This change is a minimal part of the triton-lang#2331 change, plus a small lit test for regression testing.

Fixes #364

@binarman binarman marked this pull request as ready for review October 26, 2023 10:50
Comment on lines +1016 to +1025
#ifndef USE_ROCM
auto isExtOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op);
};
#else
auto isExtOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
triton::BroadcastOp, triton::ExpandDimsOp>(op);
};
#endif

This part has changed multiple times between our current and previous IFU.
I wrapped it with USE_ROCM to highlight it, in case this PR is merged before the ongoing IFU.

  auto newType =
      RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(),
-                           *srcEncoding);
+                           layout[extOp->getResult(0)]);

This part changed upstream only once (and looks exactly the same), so it should not introduce any problems on merge or rebase on top of upstream changes.

@alefimov-amd alefimov-amd force-pushed the redice_result_type_fix2 branch from 08903f2 to 086fbcb on October 26, 2023 10:58
@alefimov-amd

For the record, a description of the problem here:
The error was in an optimization that moves a layout conversion above an "extension" operation to reduce LDS usage (the smaller the type, the smaller the scratch buffers we need to convert layouts).

Consider this example as input:

```mlir
%0 = ...
%1 = ...
%2 = arith.extf %1
%3 = tt.reduce %0, %1       // %3 has two outputs
%4 = tt.expand_dims %3#1    // expand_dims takes the second output of the reduce op
%5 = triton_gpu.convert_layout %4
```

Before this change, the optimization moved the layout conversion above the %2 extf operation, rewriting the layouts of all values along the data-flow path. Unfortunately, it did not rewrite the layout of the first output of the reduce operation, which broke the assumption that the output layouts of this reduce operation are equal.

With this change, the layout conversion is moved only above the %4 expand_dims operation.
This is achieved by changing the algorithm's stop condition and fixing the target layout computation.
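The restored invariant can be pictured with a small Python model (purely illustrative: the actual pass is C++, and the helper name `rewrite_slice` here is hypothetical): when the slice rewriter changes the layout of one result of a multi-output op, every sibling result must receive the same layout.

```python
# Hypothetical Python model of the invariant this PR restores; the real
# pass lives in RemoveLayoutConversions.cpp, so all names are illustrative.

def rewrite_slice(layouts, results_of, value, new_layout):
    """Propagate new_layout to `value`; if `value` is one result of a
    multi-output op (e.g. tt.reduce), rewrite ALL sibling results too,
    so their layouts can never diverge."""
    for sibling in results_of.get(value, [value]):
        layouts[sibling] = new_layout
    return layouts

# tt.reduce with two results, which must always share one layout.
results_of = {"reduce#0": ["reduce#0", "reduce#1"],
              "reduce#1": ["reduce#0", "reduce#1"]}
layouts = {"reduce#0": "#blocked", "reduce#1": "#blocked"}

# Rewriting the layout of the second result also rewrites the first:
layouts = rewrite_slice(layouts, results_of, "reduce#1", "#slice")
assert layouts["reduce#0"] == layouts["reduce#1"] == "#slice"
```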

@alefimov-amd alefimov-amd requested a review from jataylo October 26, 2023 11:12
@alefimov-amd alefimov-amd self-assigned this Oct 26, 2023
Comment on lines +1175 to +1180
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: tt.reduce
// CHECK-SAME: axis = 1
-// CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>>
// CHECK: triton_gpu.convert_layout
// CHECK: tt.expand_dims

This change is also part of the upstream changes.

@jataylo

jataylo commented Oct 26, 2023

With this PR I can confirm the argmax unit tests and previous reproducer now pass.

But this change seems to have caused a failure for us in another unit test:

test_torchinductor.py::CudaTests::test_batch_norm_2d_2_cuda loc("/tmp/torchinductor_root/5m/c5mbigghp7pzp26f4zncmescmhh3tgebh2yix7l2u3fog5rzai4j.py":16:23): error: ExpandDimsOp operand encoding must be SliceEncodingAttr
loc("/tmp/torchinductor_root/5m/c5mbigghp7pzp26f4zncmescmhh3tgebh2yix7l2u3fog5rzai4j.py":16:23): error: failed to infer layout for ExpandDimsOp
loc("/tmp/torchinductor_root/5m/c5mbigghp7pzp26f4zncmescmhh3tgebh2yix7l2u3fog5rzai4j.py":16:23): error: 'tt.expand_dims' op failed to infer returned types
FAILED [1.2648s]       

@alefimov-amd let me know if you want me to get together a reproducer for this new issue.

@jataylo

jataylo commented Oct 26, 2023

@alefimov-amd Reproducer for the new issue

import torch
import triton
import triton.language as tl


@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )

@triton.jit
def welford(mean, m2, weight, dim):
    return tl.reduce((mean, m2, weight), dim, welford_combine)

@triton.jit
def triton_kernel(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, XBLOCK : tl.constexpr):
    xnumel = 128
    rnumel = 3
    RBLOCK: tl.constexpr = 4
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (128*r1)), rmask & xmask, other=0)
    tmp1 = tl.load(in_ptr1 + (x0 + (128*r1)), rmask & xmask, other=0)
    tmp2 = tl.load(in_ptr2 + (x0 + (128*r1)), rmask & xmask, other=0)
    tmp25 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
    tmp30 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
    tmp3 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp4 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp5 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    tmp7 = tl.where(rmask & xmask, tmp3, 0)
    tmp8 = tl.where(rmask & xmask, tmp4, 0)
    tmp9 = tl.where(rmask & xmask, tmp5, 0)
    tmp10, tmp11, tmp12 = welford(tmp7, tmp8, tmp9, 1)
    tmp13 = tmp10[:, None]
    tmp14 = tmp11[:, None]
    tmp15 = tmp12[:, None]
    tmp16 = 49152.0
    tmp17 = tmp14 / tmp16
    tmp18 = 0.0001
    tmp19 = tmp17 + tmp18
    tmp20 = tl.math.rsqrt(tmp19)
    tmp21 = 1.0000203454660128
    tmp22 = tmp17 * tmp21
    tmp23 = 0.03
    tmp24 = tmp22 * tmp23
    tmp26 = 0.97
    tmp27 = tmp25 * tmp26
    tmp28 = tmp24 + tmp27
    tmp29 = tmp13 * tmp23
    tmp31 = tmp30 * tmp26
    tmp32 = tmp29 + tmp31
    tl.store(out_ptr2 + (x0), tmp20, xmask)
    tl.store(out_ptr3 + (x0), tmp28, xmask)
    tl.store(out_ptr4 + (x0), tmp32, xmask)
    tl.store(out_ptr0 + (x0), tmp13, xmask)
    tl.store(out_ptr1 + (x0), tmp14, xmask)

from torch._dynamo.testing import rand_strided
from torch import empty_strided
primals_4 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_5 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
buf7 = empty_strided((1, 128, 1, 1, 3), (384, 1, 384, 384, 128), device='cuda', dtype=torch.float32)
buf8 = empty_strided((1, 128, 1, 1, 3), (384, 1, 384, 384, 128), device='cuda', dtype=torch.float32)
buf9 = empty_strided((1, 128, 1, 1, 3), (384, 1, 384, 384, 128), device='cuda', dtype=torch.float32)
buf10 = empty_strided((1, 128, 1, 1), (128, 1, 128, 128), device='cuda', dtype=torch.float32)
buf11 = empty_strided((1, 128, 1, 1), (128, 1, 128, 128), device='cuda', dtype=torch.float32)
buf13 = empty_strided((128, ), (1, ), device='cuda', dtype=torch.float32)
buf15 = empty_strided((128, ), (1, ), device='cuda', dtype=torch.float32)
buf14 = empty_strided((128, ), (1, ), device='cuda', dtype=torch.float32)
triton_kernel[(128,)](buf7, buf8, buf9, primals_5, primals_4, buf10, buf11, buf13, buf15, buf14, 128)

Error log:

root@ctr-ubbsmc13:~/pytorch/test/inductor# python bnorm_simple.py
loc("bnorm_simple.py":27:23): error: ExpandDimsOp operand encoding must be SliceEncodingAttr
loc("bnorm_simple.py":27:23): error: failed to infer layout for ExpandDimsOp
loc("bnorm_simple.py":27:23): error: 'tt.expand_dims' op failed to infer returned types
Traceback (most recent call last):
  File "bnorm_simple.py", line 81, in <module>
    triton_kernel[(128,)](buf7, buf8, buf9, primals_5, primals_4, buf10, buf11, buf13, buf15, buf14, 128)
  File "<string>", line 74, in triton_kernel
  File "/root/triton/python/triton/compiler/compiler.py", line 611, in compile
    next_module = compile_kernel(module)
  File "/root/triton/python/triton/compiler/compiler.py", line 510, in <lambda>
    lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim))
  File "/root/triton/python/triton/compiler/compiler.py", line 155, in optimize_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed
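For reference, the `welford_combine` in the reproducer is the standard parallel combination of running mean/variance statistics (Chan et al.). A plain-Python sketch of the same formula, outside Triton:

```python
# Plain-Python version of welford_combine from the reproducer above;
# it mirrors the kernel's formula but uses no Triton APIs.

def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = 0.0 if new_weight == 0.0 else weight_2 / new_weight
    return (
        mean_1 + delta * w2_over_w,                          # combined mean
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,  # combined M2
        new_weight,                                          # combined count
    )

# Combining stats of [1, 2] (mean 1.5, M2 0.5) and [3, 4] (mean 3.5, M2 0.5)
# reproduces the full-array stats of [1, 2, 3, 4]: mean 2.5, M2 5.0.
mean, m2, weight = welford_combine(1.5, 0.5, 2.0, 3.5, 0.5, 2.0)
```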

@binarman binarman force-pushed the redice_result_type_fix2 branch from 4728d39 to 9da8508 on October 30, 2023 19:28

@jataylo jataylo left a comment


LGTM. Our core UTs and noted failing cases are passing after this change.

@zhanglx13 zhanglx13 merged commit 74c5fd4 into ROCm:triton-mlir Nov 1, 2023
jataylo pushed a commit that referenced this pull request Nov 3, 2023
* [RemoveLayoutConversions] Fix reduce failed infer type error

This PR fixes layout propagation algorithm in RemoveLayoutConversions pass.
In some cases during rewriteSlice process, reduce operation with multiple outputs
rewrites only one output layout, which breaks assumption that both outputs should have same layout.

This change is a minimal part of triton-lang#2331 change and
small lit test for regression testing.

* fix combine test

* Fix issue with incorrect inference layout of make_range output result
guacamoleo pushed a commit that referenced this pull request Dec 10, 2025


Development

Successfully merging this pull request may close these issues.

argmax tl.reduce gives error: 'tt.reduce' op failed to infer returned types

4 participants