Skip to content

Commit 8c0b009

Browse files
committed
[TIR] Unify index data type when creating prim func
1 parent fbe174b commit 8c0b009

File tree

14 files changed

+647
-192
lines changed

14 files changed

+647
-192
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file data_type_rewriter.h
22+
* \brief Rewrite the data type of expressions.
23+
*/
24+
#ifndef TVM_TIR_DATA_TYPE_REWRITER_H_
25+
#define TVM_TIR_DATA_TYPE_REWRITER_H_
26+
27+
#include <tvm/tir/stmt_functor.h>
28+
29+
#include <unordered_map>
30+
31+
namespace tvm {
32+
namespace tir {
33+
34+
/*!
35+
* \brief Legalize the data types of expressions to make sure they are consistent with other
36+
* parts of the program.
37+
*
38+
* It enforces the following rules:
39+
* - The data type of the index variable in a loop must be consistent with the data type of the loop
40+
* bounds.
41+
* - The data type of the binary and ternary expressions must be consistent with the data types of
42+
* each of their operands.
43+
* - The data type of the bounds and binding values of block iter vars must be consistent with the
44+
* data type of the block iter vars.
45+
*
46+
* Usually we enforce the consistency of data types when constructing the IR nodes. However, such
47+
* inconsistency may happen as a result of IR mutation in some passes. This class can be used as
48+
* a base class of such passes to ensure the consistency of data types.
49+
*/
50+
class DataTypeLegalizer : public StmtExprMutator {
51+
protected:
52+
Stmt VisitStmt_(const ForNode* op) override;
53+
Stmt VisitStmt_(const AttrStmtNode* op) override;
54+
Stmt VisitStmt_(const BlockRealizeNode* op) override;
55+
Stmt VisitStmt_(const BlockNode* op) override;
56+
PrimExpr VisitExpr_(const SelectNode* op) override;
57+
PrimExpr VisitExpr_(const RampNode* op) override;
58+
PrimExpr VisitExpr_(const AddNode* op) override;
59+
PrimExpr VisitExpr_(const SubNode* op) override;
60+
PrimExpr VisitExpr_(const MulNode* op) override;
61+
PrimExpr VisitExpr_(const DivNode* op) override;
62+
PrimExpr VisitExpr_(const ModNode* op) override;
63+
PrimExpr VisitExpr_(const FloorDivNode* op) override;
64+
PrimExpr VisitExpr_(const FloorModNode* op) override;
65+
PrimExpr VisitExpr_(const MinNode* op) override;
66+
PrimExpr VisitExpr_(const MaxNode* op) override;
67+
PrimExpr VisitExpr_(const EQNode* op) override;
68+
PrimExpr VisitExpr_(const NENode* op) override;
69+
PrimExpr VisitExpr_(const LTNode* op) override;
70+
PrimExpr VisitExpr_(const LENode* op) override;
71+
PrimExpr VisitExpr_(const GTNode* op) override;
72+
PrimExpr VisitExpr_(const GENode* op) override;
73+
PrimExpr VisitExpr_(const CallNode* op) override;
74+
75+
using StmtExprMutator::VisitExpr_;
76+
using StmtExprMutator::VisitStmt_;
77+
78+
// a map from IterVar before rewrite to that after rewrite,
79+
// ensures one old IterVar maps to exactly one new IterVar
80+
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
81+
};
82+
83+
/*!
84+
* \brief Data type rewriter for buffer indices.
85+
*
86+
* Detect the components of buffer indices that should be considered for data type rewriting.
87+
* This class doesn't perform actual rewriting of data types. During recursive visiting, the
88+
* internal flags `is_enabled_` and `is_condition_` are used to indicate whether the current
89+
* expression is a buffer index or a conditional expression, which can be used in the sub-classes to
90+
* implement different rewriting rules.
91+
*/
92+
class IndexDataTypeRewriter : public DataTypeLegalizer {
93+
using Parent = DataTypeLegalizer;
94+
95+
protected:
96+
Stmt VisitStmt_(const BlockRealizeNode* op) override;
97+
Stmt VisitStmt_(const BlockNode* op) override;
98+
Stmt VisitStmt_(const BufferStoreNode* op) override;
99+
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
100+
Array<PrimExpr> VisitIndices(Array<PrimExpr> indices);
101+
Stmt VisitStmt_(const IfThenElseNode* op) override;
102+
Stmt VisitStmt_(const DeclBufferNode* op) override;
103+
Stmt VisitStmt_(const AllocateNode* op) override;
104+
PrimExpr VisitExpr_(const EQNode* op) override;
105+
PrimExpr VisitExpr_(const NENode* op) override;
106+
PrimExpr VisitExpr_(const LTNode* op) override;
107+
PrimExpr VisitExpr_(const LENode* op) override;
108+
PrimExpr VisitExpr_(const GTNode* op) override;
109+
PrimExpr VisitExpr_(const GENode* op) override;
110+
PrimExpr VisitExpr_(const CallNode* op) override;
111+
Stmt VisitStmt_(const ForNode* op) override;
112+
113+
using DataTypeLegalizer::VisitExpr_;
114+
using DataTypeLegalizer::VisitStmt_;
115+
116+
Buffer VisitBuffer(const Buffer& buffer);
117+
Buffer GetRemappedBuffer(const Buffer& buffer);
118+
Map<String, ObjectRef> VisitBlockAnnotations(const Map<String, ObjectRef>& annotations);
119+
BufferRegion VisitBufferRegion(const BufferRegion& region);
120+
IterVar VisitIterVar(const IterVar& iter_var);
121+
// indicator of index expr to rewrite
122+
bool is_enabled_{false};
123+
// indicator of condition
124+
bool is_condition_{false};
125+
126+
Map<Var, Var> var_remap_;
127+
Map<Buffer, Buffer> buffer_remap_;
128+
};
129+
130+
/*!
131+
* \brief Normalize the data types of buffer shapes and indices to the same data type.
132+
*
133+
* This pass rewrites the data types of buffer shapes and indices to the specified data type. It
134+
* assumes the specified data type is large enough to hold the original ranges of buffer shapes and
135+
* indices.
136+
*/
137+
class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
138+
public:
139+
explicit IndexDataTypeNormalizer(DataType target_data_type);
140+
PrimFunc Rewrite(PrimFunc func);
141+
142+
private:
143+
PrimExpr VisitExpr_(const IntImmNode* op) final;
144+
PrimExpr VisitExpr_(const VarNode* op) final;
145+
PrimExpr VisitExpr_(const SizeVarNode* op) final;
146+
147+
DataType target_data_type_ = DataType::Int(64);
148+
};
149+
150+
} // namespace tir
151+
} // namespace tvm
152+
153+
#endif // TVM_TIR_DATA_TYPE_REWRITER_H_

include/tvm/tir/stmt.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,7 @@ class IfThenElse : public Stmt {
858858
Span span = Span());
859859

860860
TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
861+
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode);
861862
};
862863

863864
/*!

include/tvm/tir/stmt_functor.h

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -485,56 +485,6 @@ bool ContainsNode(const Stmt& stmt) {
485485
return visitor.contains_node;
486486
}
487487

488-
/*!
489-
* \brief Legalize the data types of expressions to make sure they are consistent with other
490-
* parts of the program.
491-
*
492-
* It enforces the following rules:
493-
* - The data type of the index variable in a loop must be consistent with the data type of the loop
494-
* bounds.
495-
* - The data type of the binary and ternary expressions must be consistent with the data types of
496-
* each of their operands.
497-
* - The data type of the bounds and binding values of block iter vars must be consistent with the
498-
* data type of the block iter vars.
499-
*
500-
* Usually we enforce the consistency of data types when constructing the IR nodes. However, such
501-
* inconsistency may happen as a result of IR mutation in some passes. This class can be used as
502-
* base class of such passes to ensure the consistency of data types.
503-
*/
504-
class DataTypeLegalizer : public StmtExprMutator {
505-
protected:
506-
Stmt VisitStmt_(const ForNode* op) override;
507-
508-
Stmt VisitStmt_(const AttrStmtNode* op) override;
509-
Stmt VisitStmt_(const BlockRealizeNode* op) override;
510-
Stmt VisitStmt_(const BlockNode* op) override;
511-
PrimExpr VisitExpr_(const SelectNode* op) override;
512-
PrimExpr VisitExpr_(const RampNode* op) override;
513-
PrimExpr VisitExpr_(const AddNode* op) override;
514-
PrimExpr VisitExpr_(const SubNode* op) override;
515-
PrimExpr VisitExpr_(const MulNode* op) override;
516-
PrimExpr VisitExpr_(const DivNode* op) override;
517-
PrimExpr VisitExpr_(const ModNode* op) override;
518-
PrimExpr VisitExpr_(const FloorDivNode* op) override;
519-
PrimExpr VisitExpr_(const FloorModNode* op) override;
520-
PrimExpr VisitExpr_(const MinNode* op) override;
521-
PrimExpr VisitExpr_(const MaxNode* op) override;
522-
PrimExpr VisitExpr_(const EQNode* op) override;
523-
PrimExpr VisitExpr_(const NENode* op) override;
524-
PrimExpr VisitExpr_(const LTNode* op) override;
525-
PrimExpr VisitExpr_(const LENode* op) override;
526-
PrimExpr VisitExpr_(const GTNode* op) override;
527-
PrimExpr VisitExpr_(const GENode* op) override;
528-
PrimExpr VisitExpr_(const CallNode* op) override;
529-
530-
using StmtExprMutator::VisitExpr_;
531-
using StmtExprMutator::VisitStmt_;
532-
533-
// a map from IterVar before rewrite to that after rewrite,
534-
// ensures one old IterVar maps to exactly one new IterVar
535-
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
536-
};
537-
538488
} // namespace tir
539489
} // namespace tvm
540490

python/tvm/te/operation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
# pylint: disable=invalid-name
2121
from numbers import Integral as _Integral
22-
from typing import List
22+
from typing import List, Optional
2323

2424
import tvm._ffi
2525
import tvm.arith._ffi_api
@@ -566,7 +566,9 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):
566566
return tvm.tir.IterVar(dom, name, 2, thread_tag, span)
567567

568568

569-
def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc:
569+
def create_prim_func(
570+
ops: List[_tensor.Tensor], index_dtype_override: Optional[str] = None
571+
) -> tvm.tir.PrimFunc:
570572
"""Create a TensorIR PrimFunc from tensor expression
571573
572574
Parameters
@@ -618,4 +620,4 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
618620
"""
619621
if not isinstance(ops, (list, tuple, Array)):
620622
ops = [ops]
621-
return _ffi_api.CreatePrimFunc(ops)
623+
return _ffi_api.CreatePrimFunc(ops, index_dtype_override)

src/relay/backend/utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ Optional<tir::PrimFunc> DefaultTIRConverterImpl(const Array<te::Tensor>& args,
416416
return NullOpt;
417417
}
418418
}
419-
PrimFunc func = te::CreatePrimFuncWithConstants(args, constants);
419+
PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, DataType::Int(64));
420420
bool dynamic_loop_extent = false;
421421
tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void {
422422
if (const auto* loop = obj.as<tir::ForNode>()) {

src/te/operation/create_primfunc.cc

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <tvm/arith/analyzer.h>
2323
#include <tvm/ir/name_supply.h>
2424
#include <tvm/runtime/registry.h>
25+
#include <tvm/tir/data_type_rewriter.h>
2526
#include <tvm/tir/function.h>
2627
#include <tvm/tir/stmt_functor.h>
2728

@@ -486,7 +487,8 @@ PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
486487
}
487488

488489
PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
489-
const Array<runtime::NDArray>& constants) {
490+
const Array<runtime::NDArray>& constants,
491+
std::optional<DataType> index_dtype_override) {
490492
// Information used in CreatePrimFunc and its sub-functions.
491493
CreateFuncInfo info(arg_list);
492494
// Root body stmts.
@@ -508,14 +510,27 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
508510
// Step 4. Create func and complete prim func.
509511
auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info);
510512
func = tir::BindParams(func, constants);
511-
return LayoutFreePlaceholdersNormalizer().Process(std::move(func));
513+
if (index_dtype_override.has_value()) {
514+
func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func));
515+
}
516+
auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func));
517+
return result;
512518
}
513519

514-
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
515-
return CreatePrimFuncWithConstants(arg_list, {});
520+
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
521+
std::optional<DataType> index_dtype_override) {
522+
return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override);
516523
}
517524

518-
TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc);
525+
TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
526+
Array<te::Tensor> arg_list = args[0];
527+
std::optional<DataType> index_dtype_override{std::nullopt};
528+
// Add conversion to make std::optional compatible with FFI.
529+
if (args[1].type_code() != kTVMNullptr) {
530+
index_dtype_override = args[1].operator DataType();
531+
}
532+
*ret = CreatePrimFunc(arg_list, index_dtype_override);
533+
});
519534

520535
} // namespace tir
521536
} // namespace tvm

src/te/operation/create_primfunc.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,23 @@
2424
#include <tvm/te/tensor.h>
2525
#include <tvm/tir/function.h>
2626

27+
#include <optional>
28+
2729
namespace tvm {
2830
namespace tir {
2931

3032
/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
31-
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);
33+
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
34+
std::optional<DataType> index_dtype_override = std::nullopt);
3235

3336
/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the
3437
* constants array is N, the last N tensors in arg_list will be treated as constant tensors.
3538
* Constant tensors will not be part of the parameters of the created PrimFunc, instead constants
3639
* will be embedded in the body as AllocateConstNode.
3740
*/
3841
PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
39-
const Array<runtime::NDArray>& constants);
42+
const Array<runtime::NDArray>& constants,
43+
std::optional<DataType> index_dtype_override = std::nullopt);
4044

4145
} // namespace tir
4246
} // namespace tvm

0 commit comments

Comments
 (0)