Skip to content

Commit c0f148a

Browse files
authored
[TIR][Analysis] Implement IdentifyMemCpy analysis function (#13947)
1 parent df429c5 commit c0f148a

File tree

3 files changed

+669
-0
lines changed

3 files changed

+669
-0
lines changed

include/tvm/tir/analysis.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,15 @@
3131
#include <tvm/tir/op_attr_types.h>
3232
#include <tvm/tir/stmt.h>
3333

34+
#include <optional>
3435
#include <string>
3536

3637
namespace tvm {
38+
39+
namespace arith {
40+
class Analyzer;
41+
}
42+
3743
namespace tir {
3844

3945
/*!
@@ -203,6 +209,29 @@ TVM_DLL Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
203209
TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
204210
const Map<Var, Buffer>& buffer_var_map);
205211

212+
/*! \brief Helper struct for return value of IdentifyMemCpy
213+
*
214+
* This helper struct is not strictly necessary, as `IdentifyMemCpy`
215+
* could instead return a `std::pair<BufferRegion, BufferRegion>`.
216+
* However, that would introduce ambiguity between the two unnamed
217+
* regions.
218+
*/
219+
struct MemCpyDetails {
220+
BufferRegion source;
221+
BufferRegion dest;
222+
};
223+
224+
/*! \brief Identify whether a For loop is semantically equivalent to MemCpy
225+
*
226+
* \param loop The loop to be checked
227+
*
228+
* \param analyzer The analyzer with which to check any algebraic expressions
229+
*
230+
* \returns The source and destination regions being copied, if the
231+
* loop is equivalent to memcpy. Otherwise, returns nullopt.
232+
*/
233+
TVM_DLL std::optional<MemCpyDetails> IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer);
234+
206235
/*!
207236
* \brief Calculate the expresion complexity based on number of symbols it contains.
208237
* \param expr The expr to be calculated.
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tir/analysis/identify_memcpy.cc
22+
* \brief Check if a loop nest is equivalent to memcpy
23+
*/
24+
25+
#include <tvm/arith/bound.h>
26+
#include <tvm/arith/iter_affine_map.h>
27+
#include <tvm/runtime/container/optional.h>
28+
#include <tvm/tir/analysis.h>
29+
#include <tvm/tir/buffer.h>
30+
#include <tvm/tir/stmt.h>
31+
32+
#include <optional>
33+
#include <sstream>
34+
#include <string>
35+
#include <variant>
36+
37+
#include "../../arith/ir_visitor_with_analyzer.h"
38+
39+
namespace tvm {
40+
namespace tir {
41+
42+
std::variant<MemCpyDetails, std::string> IdentifyMemCpyImpl(const For& loop,
43+
arith::Analyzer* analyzer) {
44+
Map<Var, arith::IntSet> loop_intervals;
45+
Map<Var, Range> loop_ranges;
46+
PrimExpr total_loop_iterations = 1;
47+
48+
// Walk through the loop nest, stopping at the first loop whose body
49+
// is not a loop.
50+
Stmt stmt = loop;
51+
while (auto* for_node = stmt.as<ForNode>()) {
52+
loop_ranges.Set(for_node->loop_var, Range::FromMinExtent(for_node->min, for_node->extent));
53+
loop_intervals.Set(for_node->loop_var,
54+
arith::IntSet::FromMinExtent(for_node->min, for_node->extent));
55+
total_loop_iterations = total_loop_iterations * for_node->extent;
56+
57+
stmt = for_node->body;
58+
}
59+
60+
BufferStore store;
61+
if (auto* ptr = stmt.as<BufferStoreNode>()) {
62+
store = GetRef<BufferStore>(ptr);
63+
} else {
64+
return static_cast<const std::stringstream&>(
65+
std::stringstream()
66+
<< "Expected innermost loop to have BufferStore body, but instead found " << stmt)
67+
.str();
68+
}
69+
70+
BufferLoad load;
71+
if (auto* ptr = store->value.as<BufferLoadNode>()) {
72+
load = GetRef<BufferLoad>(ptr);
73+
} else {
74+
return static_cast<const std::stringstream&>(
75+
std::stringstream()
76+
<< "Expected BufferStore's value to be BufferLoad, but instead found "
77+
<< store->value)
78+
.str();
79+
}
80+
81+
// Now, we have a BufferStore whose value is a BufferLoad. Because
82+
// non-flat physical indices are target-dependent, only handle cases
83+
// where the buffer will be flattened to a 1-d physical buffer.
84+
Array<PrimExpr> flattened_dst = store->buffer.OffsetOf(store->indices);
85+
Array<PrimExpr> flattened_src = load->buffer.OffsetOf(load->indices);
86+
87+
if (flattened_dst.size() != 1 || flattened_src.size() != 1) {
88+
return static_cast<const std::stringstream&>(
89+
std::stringstream()
90+
<< "Expected flattened dimension of src/dest to be 1, but found"
91+
<< flattened_src.size() << "-d src and " << flattened_dst.size() << "-d dst")
92+
.str();
93+
}
94+
PrimExpr src_index = flattened_src[0];
95+
PrimExpr dst_index = flattened_dst[0];
96+
97+
// First check, do the input/output form affine subsets of their
98+
// respective buffers?
99+
//
100+
// For example, should exclude the following, indices are not affine
101+
//
102+
// for i in T.serial(16):
103+
// B[i] = A[T.abs(i-8)]
104+
105+
auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, Bool(true),
106+
arith::IterMapLevel::Bijective, analyzer);
107+
if (src_iter_map->errors.size()) {
108+
return static_cast<const std::stringstream&>(std::stringstream()
109+
<< "arith::DetectIterMap(src) returned "
110+
<< src_iter_map->errors.size() << " errors: ["
111+
<< src_iter_map->errors << "]"
112+
<< " for src_index = " << src_index)
113+
.str();
114+
}
115+
auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, Bool(true),
116+
arith::IterMapLevel::Bijective, analyzer);
117+
if (dst_iter_map->errors.size()) {
118+
return static_cast<const std::stringstream&>(std::stringstream()
119+
<< "arith::DetectIterMap(dst) returned "
120+
<< dst_iter_map->errors.size() << " errors: ["
121+
<< dst_iter_map->errors << "]"
122+
<< " for dst_index = " << dst_index)
123+
.str();
124+
}
125+
126+
// Second check, are those affine subsets contiguous? If so, then
127+
// the index expressions will visit every location between the min
128+
// and the max. This checks surjectivity over a linear region,
129+
// which may not be the same as DetectIterMap's check of
130+
// surjectivity over the affine subset.
131+
//
132+
// For example, should exclude the following, doesn't touch all
133+
// output locations within the output region touched.
134+
//
135+
// for i in T.serial(16):
136+
// B[2*i] = A[i]
137+
//
138+
// Similarly, should exclude the following, doesn't touch all
139+
// input locations within the input region touched.
140+
//
141+
// for i in T.serial(16):
142+
// B[i] = A[2*i]
143+
total_loop_iterations = analyzer->Simplify(total_loop_iterations);
144+
auto src_interval = analyzer->int_set(src_index, loop_intervals);
145+
auto dst_interval = analyzer->int_set(dst_index, loop_intervals);
146+
147+
if (!src_interval.HasLowerBound() || !src_interval.HasUpperBound()) {
148+
return static_cast<const std::stringstream&>(std::stringstream()
149+
<< "Expected known bounds for src, but found "
150+
<< src_interval << " for expression " << src_index)
151+
.str();
152+
}
153+
if (!dst_interval.HasLowerBound() || !dst_interval.HasUpperBound()) {
154+
return static_cast<const std::stringstream&>(std::stringstream()
155+
<< "Expected known bounds for dst, but found "
156+
<< dst_interval << " for expression " << dst_index)
157+
.str();
158+
}
159+
160+
{
161+
PrimExpr must_prove = total_loop_iterations == src_interval.max() - src_interval.min() + 1;
162+
PrimExpr simplified = analyzer->Simplify(must_prove);
163+
if (!analyzer->CanProve(simplified)) {
164+
return static_cast<const std::stringstream&>(
165+
std::stringstream()
166+
<< "Mismatch between loop iterations (" << total_loop_iterations
167+
<< ") and number of src indices touched (" << src_interval
168+
<< ". Equality to prove simplified to " << simplified)
169+
.str();
170+
}
171+
}
172+
{
173+
PrimExpr must_prove = total_loop_iterations == dst_interval.max() - dst_interval.min() + 1;
174+
PrimExpr simplified = analyzer->Simplify(must_prove);
175+
if (!analyzer->CanProve(simplified)) {
176+
return static_cast<const std::stringstream&>(
177+
std::stringstream()
178+
<< "Mismatch between loop iterations (" << total_loop_iterations
179+
<< ") and number of dst indices touched (" << dst_interval
180+
<< ". Equality to prove simplified to " << simplified)
181+
.str();
182+
}
183+
}
184+
185+
// Third check, is there a transformation applied between the input
186+
// and output iterators?
187+
//
188+
// For example, the following would pass all checks so far, but
189+
// converts between row-major and column-major layouts, and could
190+
// not be specified as a memcpy.
191+
//
192+
// for i,j in T.grid(4,4):
193+
// B[i,j] = A[j,i]
194+
195+
auto src_iter_sum = src_iter_map->indices[0];
196+
auto dst_iter_sum = dst_iter_map->indices[0];
197+
198+
if (src_iter_sum->args.size() != dst_iter_sum->args.size()) {
199+
return static_cast<const std::stringstream&>(
200+
std::stringstream()
201+
<< "IterMap for src/dst unpacked to different number of IterSplitExpr: "
202+
<< src_iter_sum->args.size() << " for src, " << dst_iter_sum->args.size()
203+
<< " for dst. "
204+
<< "IterMaps were detected as src = " << src_iter_sum << ", dst = " << dst_iter_sum)
205+
.str();
206+
}
207+
std::vector<arith::IterSplitExpr> src_iter_terms(src_iter_sum->args.begin(),
208+
src_iter_sum->args.end());
209+
std::vector<arith::IterSplitExpr> dst_iter_terms(dst_iter_sum->args.begin(),
210+
dst_iter_sum->args.end());
211+
212+
auto make_comparison_tuple = [](const arith::IterSplitExpr& expr) {
213+
auto as_int_or_zero = [](auto& val) -> int64_t {
214+
if (auto* as_int = val.template as<IntImmNode>()) {
215+
return as_int->value;
216+
} else {
217+
return 0;
218+
}
219+
};
220+
return std::tuple{
221+
static_cast<bool>(expr->scale.as<IntImmNode>()), as_int_or_zero(expr->scale),
222+
static_cast<bool>(expr->extent.as<IntImmNode>()), as_int_or_zero(expr->lower_factor),
223+
static_cast<bool>(expr->lower_factor.as<IntImmNode>()), as_int_or_zero(expr->lower_factor),
224+
};
225+
};
226+
auto sorting_function = [&make_comparison_tuple](const arith::IterSplitExpr& lhs,
227+
const arith::IterSplitExpr& rhs) -> bool {
228+
return make_comparison_tuple(lhs) < make_comparison_tuple(rhs);
229+
};
230+
std::sort(src_iter_terms.begin(), src_iter_terms.end(), sorting_function);
231+
std::sort(dst_iter_terms.begin(), dst_iter_terms.end(), sorting_function);
232+
233+
for (size_t i = 0; i < src_iter_terms.size(); i++) {
234+
const arith::IterSplitExpr& src_term = src_iter_terms[i];
235+
const arith::IterSplitExpr& dst_term = dst_iter_terms[i];
236+
237+
if (!analyzer->CanProve(
238+
arith::NormalizeIterMapToExpr(src_term->source->source == dst_term->source->source))) {
239+
return static_cast<const std::stringstream&>(
240+
std::stringstream()
241+
<< "Term " << i << " had different source, src_term->source = " << src_term->source
242+
<< ", dst_term->source = " << dst_term->source)
243+
.str();
244+
}
245+
if (!analyzer->CanProve(src_term->lower_factor == dst_term->lower_factor)) {
246+
return static_cast<const std::stringstream&>(
247+
std::stringstream()
248+
<< "Term " << i << " had different lower_factor, src_term->lower_factor = "
249+
<< src_term->lower_factor
250+
<< ", dst_term->lower_factor = " << dst_term->lower_factor)
251+
.str();
252+
}
253+
if (!analyzer->CanProve(src_term->extent == dst_term->extent)) {
254+
return static_cast<const std::stringstream&>(
255+
std::stringstream()
256+
<< "Term " << i << " had different extent, src_term->extent = " << src_term->extent
257+
<< ", dst_term->extent = " << dst_term->extent)
258+
.str();
259+
}
260+
if (!analyzer->CanProve(src_term->scale == dst_term->scale)) {
261+
return static_cast<const std::stringstream&>(
262+
std::stringstream()
263+
<< "Term " << i << " had different scale, src_term->scale = " << src_term->scale
264+
<< ", dst_term->scale = " << dst_term->scale)
265+
.str();
266+
}
267+
}
268+
269+
BufferRegion src_region(load->buffer, arith::DomainTouched(loop, load->buffer, true, true));
270+
BufferRegion dst_region(store->buffer, arith::DomainTouched(loop, store->buffer, true, true));
271+
272+
return MemCpyDetails{src_region, dst_region};
273+
}
274+
275+
std::optional<MemCpyDetails> IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer) {
276+
auto result = IdentifyMemCpyImpl(loop, analyzer);
277+
if (auto* ptr = std::get_if<MemCpyDetails>(&result)) {
278+
return *ptr;
279+
} else {
280+
return std::nullopt;
281+
}
282+
}
283+
284+
// Expose the IdentifyMemCpy functionality to Python API for purpose of unit testing.
285+
TVM_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) {
286+
Array<ObjectRef> output;
287+
288+
struct Visitor : arith::IRVisitorWithAnalyzer {
289+
explicit Visitor(Array<ObjectRef>* output) : output(output) {}
290+
Array<ObjectRef>* output;
291+
292+
private:
293+
using IRVisitorWithAnalyzer::VisitStmt_;
294+
void VisitStmt_(const ForNode* op) override {
295+
For loop = GetRef<For>(op);
296+
auto result = IdentifyMemCpyImpl(loop, &analyzer_);
297+
if (auto* ptr = std::get_if<MemCpyDetails>(&result)) {
298+
output->push_back(Array{ptr->source, ptr->dest});
299+
} else if (auto* ptr = std::get_if<std::string>(&result)) {
300+
output->push_back(StringImm(*ptr));
301+
} else {
302+
LOG(FATAL) << "Internal error, unhandled std::variant type";
303+
}
304+
305+
IRVisitorWithAnalyzer::VisitStmt_(op);
306+
}
307+
};
308+
309+
Visitor visitor(&output);
310+
visitor(stmt);
311+
312+
return output;
313+
});
314+
315+
} // namespace tir
316+
} // namespace tvm

0 commit comments

Comments
 (0)