Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/tvm/tir/data_type_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class DataTypeLegalizer : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) override;
Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Stmt VisitStmt_(const LetStmtNode* op) override;
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const SelectNode* op) override;
PrimExpr VisitExpr_(const RampNode* op) override;
PrimExpr VisitExpr_(const AddNode* op) override;
Expand All @@ -79,6 +81,8 @@ class DataTypeLegalizer : public StmtExprMutator {
// a map from IterVar before rewrite to that after rewrite,
// ensures one old IterVar maps to exactly one new IterVar
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
// a map from original vars to ones with new dtype
std::unordered_map<const VarNode*, Var> var_remap_;
};

/*!
Expand Down Expand Up @@ -123,7 +127,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
// indicator of condition
bool is_condition_{false};

Map<Var, Var> var_remap_;
Map<Buffer, Buffer> buffer_remap_;
};

Expand Down
43 changes: 35 additions & 8 deletions src/tir/ir/data_type_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,35 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
return StmtExprMutator::VisitStmt_(op);
}

/*!
 * \brief Legalize a let statement, re-typing the bound variable if its value changes dtype.
 *
 * The bound value is visited first. When the dtype of the rewritten value differs from the
 * dtype of the original let variable, a copy of the variable carrying the new dtype is
 * recorded in var_remap_ *before* the body is visited, so that uses of the variable inside
 * the body can be replaced by the VarNode visitor.
 *
 * \param op The let statement to rewrite.
 * \return The original statement when neither the value nor the body changed; otherwise a
 *         statement with the rewritten value and body, bound to the re-typed variable when
 *         the value's dtype changed.
 */
Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) {
  PrimExpr value = this->VisitExpr(op->value);
  const bool dtype_changed = value.dtype() != op->var->dtype;

  // Only allocate a replacement variable when it is actually needed; in the common case
  // (dtype unchanged) the original variable is kept as-is.
  auto new_var = op->var;
  if (dtype_changed) {
    new_var = op->var.copy_with_dtype(value.dtype());
    // Must be registered before the body is visited so references in the body are remapped.
    var_remap_[op->var.get()] = new_var;
  }

  Stmt new_body = this->VisitStmt(op->body);

  // Nothing changed: hand back the original node.
  if (value.same_as(op->value) && new_body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  }

  // Same dtype: the let variable is preserved, only value/body are updated in place.
  if (!dtype_changed) {
    auto n = CopyOnWrite(op);
    n->value = std::move(value);
    n->body = std::move(new_body);
    return Stmt(n);
  }

  // The dtype changed: rebuild the statement around the re-typed variable.
  return LetStmt(new_var, value, new_body, op->span);
}

/*!
 * \brief Replace a variable by its entry in var_remap_, if one has been recorded.
 * \param op The variable being visited.
 * \return The remapped variable when var_remap_ contains \p op, otherwise the variable itself.
 */
PrimExpr DataTypeLegalizer::VisitExpr_(const VarNode* op) {
  const auto remapped = var_remap_.find(op);
  if (remapped == var_remap_.end()) {
    // No replacement recorded: the variable is returned unchanged.
    return GetRef<Var>(op);
  }
  return remapped->second;
}

PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr true_value = this->VisitExpr(op->true_value);
Expand Down Expand Up @@ -397,6 +426,9 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {

Buffer new_buffer = GetRemappedBuffer(op->buffer);
auto value = this->VisitExpr(op->value);
if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
value = cast(new_buffer->dtype, value);
}
auto indices = VisitIndices(op->indices);

if (!new_buffer.same_as(op->buffer) || !value.same_as(op->value) ||
Expand Down Expand Up @@ -535,15 +567,10 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) {
}

PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
return (*it).second;
}
if (is_enabled_ && op->dtype != target_data_type_) {
Var new_var = GetRef<Var>(op).copy_with_dtype(target_data_type_);
var_remap_.Set(GetRef<Var>(op), new_var);
return std::move(new_var);
if (is_enabled_ && op->dtype != target_data_type_ && !var_remap_.count(op)) {
var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
}
return GetRef<PrimExpr>(op);
return DataTypeLegalizer::VisitExpr_(op);
}

PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
Expand Down
11 changes: 2 additions & 9 deletions src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,8 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
}

PrimExpr VisitExpr_(const VarNode* op) final {
if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
return (*it).second;
} else if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
Var v = Var(op->name_hint, visitor_.vmap[op]);
var_remap_.Set(GetRef<Var>(op), v);
return v;
if (auto it = visitor_.vmap.find(op); !var_remap_.count(op) && it != visitor_.vmap.end()) {
var_remap_[op] = Var(op->name_hint, it->second);
}
return Parent::VisitExpr_(op);
}
Expand Down Expand Up @@ -266,9 +262,6 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
private:
// the internal visitor to deduce the narrowed dtype
DataTypeVisitor visitor_;
// a map from Var before rewrite to that after rewrite,
// ensures one old Var maps to exactly one new Var
std::unordered_map<const VarNode*, Var> vmap_;
};

Stmt NarrowDataType(Stmt stmt, int target_bits) {
Expand Down
55 changes: 54 additions & 1 deletion tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import tvm
import tvm.testing
from tvm import te, tir, topi
from tvm import te, tir, topi, relay
from tvm.script import tir as T
import pytest

Expand Down Expand Up @@ -636,5 +636,58 @@ def test_reshape():
_check_workload(te_reshape, tir_reshape, index_dtype_override="int64")


# Expected PrimFunc for test_argmax: it is compared structurally against the result of
# lowering relay.argmax(axis=1) on a (1, 64, 56, 56) uint8 input.
@T.prim_func
def argmax_expected(
p0: T.Buffer[(T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "uint8"],
p0_red: T.Buffer[(T.int64(1), T.int64(56), T.int64(56)), "int32"],
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# Two intermediate buffers hold the reduction state: an int32 one (v0) and a uint8 one (v1).
p0_red_temp_v0 = T.alloc_buffer([T.int64(1), T.int64(56), T.int64(56)], dtype="int32")
p0_red_temp_v1 = T.alloc_buffer([T.int64(1), T.int64(56), T.int64(56)], dtype="uint8")
for ax0, ax1, ax2, k1 in T.grid(T.int64(1), T.int64(56), T.int64(56), T.int64(64)):
with T.block("p0_red_temp"):
# Three spatial axes and one reduction axis (k1).
v_ax0, v_ax1, v_ax2, v_k1 = T.axis.remap("SSSR", [ax0, ax1, ax2, k1])
T.reads(p0[v_ax0, v_k1, v_ax1, v_ax2])
T.writes(p0_red_temp_v0[v_ax0, v_ax1, v_ax2], p0_red_temp_v1[v_ax0, v_ax1, v_ax2])
with T.init():
p0_red_temp_v0[v_ax0, v_ax1, v_ax2] = -1
p0_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.uint8(0)
# The let-bound variable is int64: the stored int32 value is cast up to int64 so it can
# be compared with, and selected against, v_k1.
v_p0_red_temp_v0: T.int64 = T.Select(
p0_red_temp_v1[v_ax0, v_ax1, v_ax2] > p0[v_ax0, v_k1, v_ax1, v_ax2]
or (
p0_red_temp_v1[v_ax0, v_ax1, v_ax2] == p0[v_ax0, v_k1, v_ax1, v_ax2]
and T.Cast("int64", p0_red_temp_v0[v_ax0, v_ax1, v_ax2]) < v_k1
),
T.Cast("int64", p0_red_temp_v0[v_ax0, v_ax1, v_ax2]),
v_k1,
)
v_p0_red_temp_v1: T.uint8 = T.Select(
p0_red_temp_v1[v_ax0, v_ax1, v_ax2] > p0[v_ax0, v_k1, v_ax1, v_ax2],
p0_red_temp_v1[v_ax0, v_ax1, v_ax2],
p0[v_ax0, v_k1, v_ax1, v_ax2],
)
# The int64 let variable is cast back to int32 to match the dtype of p0_red_temp_v0.
p0_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.Cast("int32", v_p0_red_temp_v0)
p0_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_p0_red_temp_v1
# Second stage: copy the int32 intermediate buffer into the output buffer.
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(56), T.int64(56)):
with T.block("p0_red"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(p0_red_temp_v0[v_ax0, v_ax1, v_ax2])
T.writes(p0_red[v_ax0, v_ax1, v_ax2])
p0_red[v_ax0, v_ax1, v_ax2] = p0_red_temp_v0[v_ax0, v_ax1, v_ax2]


def test_argmax():
    """Lower relay.argmax over axis 1 of a uint8 tensor and compare with argmax_expected."""
    llvm_target = tvm.target.Target("llvm")

    # Build a Relay module consisting of a single argmax over a (1, 64, 56, 56) uint8 input.
    inp = relay.var("data", shape=(1, 64, 56, 56), dtype="uint8")
    argmax_mod = tvm.IRModule.from_expr(relay.argmax(inp, axis=1))

    optimized, _ = relay.optimize(argmax_mod, params={}, target=llvm_target)

    # The body of the optimized main function is a call whose op is lowered to a PrimFunc.
    fused_op = optimized["main"].body.op
    lowered = relay.backend.te_compiler.lower_to_primfunc(fused_op, llvm_target)

    tvm.ir.assert_structural_equal(lowered, argmax_expected)


if __name__ == "__main__":
tvm.testing.main()