Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,9 @@ class RFactorBlockCreator : public BaseBlockCreator {
write_regions_.reserve(old_block->writes.size());
for (const BufferRegion& write_region : old_block->writes) {
Array<Range> region = write_region->region;
region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, 1));
region.insert(region.begin() + factor_axis_,
Range::FromMinExtent(additional_iter_->var,
make_const(additional_iter_->var.dtype(), 1)));
Optional<Buffer> rf_buffer = buffer_map.Get(write_region->buffer);
ICHECK(rf_buffer.defined());
write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_)));
Expand Down Expand Up @@ -1005,7 +1007,7 @@ class WriteBackBlockCreator : public BaseBlockCreator {
Array<Range> region;
region.reserve(buf_load->indices.size());
for (const PrimExpr& index : buf_load->indices) {
region.push_back(Range::FromMinExtent(index, 1));
region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1)));
}
buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region)));
}
Expand Down
57 changes: 57 additions & 0 deletions tests/python/unittest/test_tir_schedule_rfactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import pytest

import tvm
import tvm.testing
from tvm import te, tir, topi
Expand Down Expand Up @@ -1643,5 +1644,61 @@ def test_reduction_rfactor_topi_argmin():
verify_trace_roundtrip(s, mod=argmin_topi)


def test_reduction_rfactor_int64():
# fmt: off
@T.prim_func
def before(
A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
):
for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)
):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i0, i1])
vk = T.axis.R(
T.int64(128),
i2_outer * T.int64(32) + i2_inner_outer * T.int64(4) + i2_inner_inner,
)
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])

@T.prim_func
def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
):
C_rf = T.alloc_buffer((T.int64(4), T.int64(128), T.int64(128)), "float32")

for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)):
with T.block("update_rf"):
vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer= T.axis.remap("SSSRR", [i2_inner_inner, i0, i1, i2_outer, i2_inner_outer])
with T.init():
C_rf[vi2_inner_inner, vi, vj] = 0.0
C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
A[vi, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)]
* B[vj, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)]
)

for i0_1, i1_1, i2_inner_inner_1 in T.grid(T.int64(128), T.int64(128), T.int64(4)):
with T.block("update"):
vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1])
with T.init():
C[vi_1, vj_1] = 0.0
C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
# fmt: on

s = tir.Schedule(before, debug_mask="all")
update = s.get_block("update")
_, _, _, _, kii = s.get_loops(update)
rf_block = s.rfactor(kii, 0)
assert_structural_equal_ignore_global_symbol(s.mod["main"], expected)
assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
assert s.get(update).same_as(s.get(s.get_block("update")))
verify_trace_roundtrip(s, mod=before)


if __name__ == "__main__":
tvm.testing.main()