Skip to content

Commit 4267fbf

Browse files
chunit-quicJoey Tsai
andauthored
[IR][SIBuilder] (#14574)
* [IR][SIBuilder] - Add SIBuilder to handle the span propagation between passes - Add SequentialSpan for multiple source expressions conversion between passes - Add test cases for SIBuilder and SequentialSpan * [IR][SIBuilder] - Make null implementation as base class - Add comments and change naming based on reviewing --------- Co-authored-by: Joey Tsai <[email protected]>
1 parent 43f06ca commit 4267fbf

File tree

9 files changed

+952
-2
lines changed

9 files changed

+952
-2
lines changed

include/tvm/ir/si_builder.h

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/ir/si_builder.h
22+
* \brief build a source info during rewriting expressions.
23+
*/
24+
#ifndef TVM_IR_SI_BUILDER_H_
25+
#define TVM_IR_SI_BUILDER_H_
26+
27+
#include <tvm/relay/expr.h>
28+
#include <tvm/relay/expr_functor.h>
29+
#include <tvm/tir/stmt.h>
30+
31+
#include <memory>
32+
#include <unordered_set>
33+
34+
namespace tvm {
35+
36+
/*!
37+
* \brief Source Information Builder, SIBuilder provides helper APIs for filling spans,
38+
* particularly useful for one-to-many, many-to-one and many-to-many IR transformations.
39+
*/
40+
class SIBuilder {
41+
public:
42+
/*!
43+
* \brief Create SIBuilder from a given span
44+
*/
45+
explicit SIBuilder(const Span& span = Span());
46+
47+
/*!
48+
* \brief Create SIBuilder from a given span sequence
49+
*/
50+
explicit SIBuilder(const Array<Span>& spans = Array<Span>());
51+
explicit SIBuilder(const std::initializer_list<Span>& init);
52+
53+
/*!
54+
* \brief Create SIBuilder via a subgraph,
55+
* Will construct span based on the exprs in the subgraph. Including the inputs exprs.
56+
*
57+
* \param entry Entry expr for subgraph
58+
* \param inputs End exprs for subgraph
59+
*/
60+
template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
61+
explicit SIBuilder(const T& entry, const tvm::Array<T>& inputs = {});
62+
explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<PrimExpr>& inputs = {});
63+
explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<tir::Stmt>& inputs = {});
64+
65+
~SIBuilder();
66+
67+
SIBuilder(const SIBuilder&) = delete;
68+
SIBuilder& operator=(const SIBuilder&) = delete;
69+
70+
/*!
71+
* \brief build a span of source information, which is based on the given span or subgraph.
72+
*
73+
* \return the built span
74+
*/
75+
Span Build() const;
76+
77+
/*!
78+
* \brief Recursively fill all span of exprs in subgraph from entry until inputs.
79+
*
80+
* \param entry Entry expr for subgraph.
81+
* \param inputs End exprs for subgraph, will not be filled with new span.
82+
*/
83+
template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
84+
void RecursivelyFillSpan(
85+
const T& entry, const std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
86+
87+
void RecursivelyFillSpan(
88+
const tir::Stmt& entry,
89+
const std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
90+
void RecursivelyFillSpan(
91+
const tir::Stmt& entry,
92+
const std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
93+
94+
private:
95+
struct Impl;
96+
std::unique_ptr<Impl> impl_;
97+
98+
std::unique_ptr<Impl> CreateImpl(const Span& span);
99+
};
100+
101+
} // namespace tvm
102+
103+
#endif // TVM_IR_SI_BUILDER_H_

include/tvm/ir/source_map.h

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class SpanNode : public Object {
114114
}
115115

116116
static constexpr const char* _type_key = "Span";
117-
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
117+
TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object);
118118
};
119119

120120
class Span : public ObjectRef {
@@ -127,6 +127,50 @@ class Span : public ObjectRef {
127127
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
128128
};
129129

130+
/*!
131+
* \brief Store a list of spans for an expr generated from mulitple source exprs
132+
*/
133+
class SequentialSpanNode : public SpanNode {
134+
public:
135+
/*! \brief The original source list of spans to construct a sequential span. */
136+
Array<Span> spans;
137+
138+
// override attr visitor
139+
void VisitAttrs(AttrVisitor* v) {
140+
SpanNode::VisitAttrs(v);
141+
v->Visit("spans", &spans);
142+
}
143+
144+
static constexpr const char* _type_key = "SequentialSpan";
145+
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode);
146+
147+
bool SEqualReduce(const SequentialSpanNode* other, SEqualReducer equal) const {
148+
if (spans.size() != other->spans.size()) {
149+
return false;
150+
}
151+
152+
for (size_t i = 0, e = spans.size(); i != e; ++i) {
153+
if (!StructuralEqual()(spans[i], other->spans[i])) {
154+
return false;
155+
}
156+
}
157+
return true;
158+
}
159+
};
160+
161+
/*!
162+
* \brief Reference class of SequentialSpanNode.
163+
* \sa SequentialSpanNode
164+
*/
165+
class SequentialSpan : public Span {
166+
public:
167+
TVM_DLL SequentialSpan(Array<Span> spans);
168+
169+
TVM_DLL SequentialSpan(std::initializer_list<Span> init);
170+
171+
TVM_DEFINE_OBJECT_REF_METHODS(SequentialSpan, Span, SequentialSpanNode);
172+
};
173+
130174
/*! \brief A program source in any language.
131175
*
132176
* Could represent the source from an ML framework or a source

python/tvm/ir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Node,
2626
SourceName,
2727
Span,
28+
SequentialSpan,
2829
assert_structural_equal,
2930
load_json,
3031
save_json,

python/tvm/ir/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,23 @@ def __init__(self, source_name, line, end_line, column, end_column):
6969
)
7070

7171

72+
@register_object("SequentialSpan")
73+
class SequentialSpan(Object):
74+
"""A sequence of source spans
75+
76+
This span is specific for an expression, which is from multiple expressions
77+
after an IR transform.
78+
79+
Parameters
80+
----------
81+
spans : Array
82+
The array of spans.
83+
"""
84+
85+
def __init__(self, spans):
86+
self.__init_handle_by_constructor__(_ffi_api.SequentialSpan, spans)
87+
88+
7289
@register_object
7390
class EnvFunc(Object):
7491
"""Environment function.

python/tvm/relay/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373

7474
# Span
7575
Span = base.Span
76+
SequentialSpan = base.SequentialSpan
7677
SourceName = base.SourceName
7778

7879
# Type

python/tvm/relay/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import tvm._ffi
2222
from tvm.ir import Node as RelayNode
23-
from tvm.ir import SourceName, Span
23+
from tvm.ir import SourceName, Span, SequentialSpan
2424
from tvm.runtime import Object
2525

2626
from . import _ffi_api

0 commit comments

Comments
 (0)