Skip to content

Commit 9966735

Browse files
committed
Fix location emission
1 parent 0cf4cb8 commit 9966735

File tree

10 files changed

+139
-52
lines changed

10 files changed

+139
-52
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py
271271

272272
# Printed TIR code on disk
273273
*.tir
274+
275+
# GDB history file
276+
.gdb_history

gallery/tutorial/debug_tir.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
.. _tutorial-topi:
19+
20+
Debugging TIR
21+
=============
22+
23+
"""
24+
25+
# sphinx_gallery_start_ignore
26+
from tvm import testing
27+
28+
testing.utils.install_request_hook(depth=3)
29+
# sphinx_gallery_end_ignore
30+
31+
import tvm
32+
import tvm.testing
33+
import numpy as np
34+
from tvm.script import tir as T
35+
36+
# Installing dependencies
37+
#
38+
# .. code-block:: bash
39+
#
40+
# pip install -q tensorflow
41+
# apt-get -qq install curl
42+
43+
44+
@tvm.script.ir_module
45+
class MyModule:
46+
@T.prim_func
47+
def main(a: T.handle, b: T.handle):
48+
# We exchange data between function by handles, which are similar to pointer.
49+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
50+
# Create buffer from handles.
51+
A = T.match_buffer(a, (8,), dtype="float32")
52+
B = T.match_buffer(b, (8,), dtype="float32")
53+
for i in range(8):
54+
# A block is an abstraction for computation.
55+
with T.block("B"):
56+
# Define a spatial block iterator and bind it to value i.
57+
vi = T.axis.spatial(8, i)
58+
assert 1 == 0, "Some numbers"
59+
B[vi] = A[vi] + 1.0
60+
61+
62+
print("Actually starting ------")
63+
with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}):
64+
runtime_module = tvm.build(MyModule, target="llvm")
65+
66+
# print(runtime_module.get_source())
67+
print(type(runtime_module))
68+
69+
a = tvm.nd.array(np.arange(8).astype("float32"))
70+
b = tvm.nd.array(np.zeros((8,)).astype("float32"))
71+
print("EXECUTING ------")
72+
runtime_module(a, b)
73+
print(a)
74+
print(b)

src/printer/tir_text_printer_debug.cc

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,65 @@
2525

2626
#include "tir_text_printer_debug.h"
2727

28+
#include <optional>
2829
#include <string>
2930

3031
#include "text_printer.h"
3132

3233
namespace tvm {
3334
namespace tir {
3435

35-
std::string span_text(const Span& span) {
36+
std::optional<std::string> span_text(const Span& span) {
3637
if (!span.defined()) {
37-
return "missing";
38+
return std::nullopt;
3839
}
39-
std::string source("file");
40+
41+
std::string source("main.tir");
42+
// TODO(driazati): This segfaults even with a guard around source_name, so the
43+
// filename always defaults to main.tir (llvm ignores this filename anyways)
44+
// if (span->source_name.defined()) {
45+
// source = span->source_name->name;
46+
// }
4047
return source + ":" + std::to_string(span->line) + ":" + std::to_string(span->column);
4148
}
4249

50+
template <typename ObjectPtr>
51+
void add_all_relevant_lines(const std::vector<std::tuple<const ObjectPtr*, size_t>>& data,
52+
size_t current_line, Doc* output) {
53+
for (const auto& item : data) {
54+
if (std::get<1>(item) != current_line - 1) {
55+
// Item is not relevant for this line, skip it
56+
continue;
57+
}
58+
59+
// Print out the item's span info if present
60+
auto text = span_text(std::get<0>(item)->span);
61+
if (text.has_value()) {
62+
output << *text;
63+
} else {
64+
output << "missing";
65+
}
66+
output << ", ";
67+
}
68+
}
69+
4370
Doc TIRTextPrinterDebug::NewLine() {
4471
current_line_ += 1;
4572

46-
return TIRTextPrinter::NewLine();
73+
if (!show_spans_) {
74+
return TIRTextPrinter::NewLine();
75+
}
76+
77+
Doc output;
78+
79+
output << " [";
80+
81+
add_all_relevant_lines(exprs_by_line_, current_line_, &output);
82+
add_all_relevant_lines(stmts_by_line_, current_line_, &output);
83+
84+
output << "]" << TIRTextPrinter::NewLine();
85+
86+
return output;
4787
}
4888

4989
#define X(TypeName) \

src/printer/tir_text_printer_debug.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ namespace tir {
3737

3838
class TIRTextPrinterDebug : public TIRTextPrinter {
3939
public:
40-
TIRTextPrinterDebug() : TIRTextPrinter(false, &meta_), current_line_(1) {}
40+
explicit TIRTextPrinterDebug(bool show_spans)
41+
: TIRTextPrinter(false, &meta_), current_line_(1), show_spans_(show_spans) {}
4142

4243
std::vector<std::tuple<const PrimExprNode*, size_t>> GetExprsByLine() const {
4344
return exprs_by_line_;
@@ -61,6 +62,9 @@ class TIRTextPrinterDebug : public TIRTextPrinter {
6162
// Line that the printer is currently printing
6263
size_t current_line_;
6364

65+
// Whether to include spans relevant to each line before a newline or not
66+
bool show_spans_;
67+
6468
// Record of all stmts and exprs and their corresponding line
6569
std::vector<std::tuple<const StmtNode*, size_t>> stmts_by_line_;
6670
std::vector<std::tuple<const PrimExprNode*, size_t>> exprs_by_line_;

src/target/llvm/codegen_cpu.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,6 @@ llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lo
952952
}
953953

954954
llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) {
955-
EmitDebugLocation(op);
956955
ICHECK_EQ(op->args.size(), 6U);
957956
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
958957
op->args[4].as<IntImmNode>()->value, true);
@@ -1388,7 +1387,6 @@ void CodeGenCPU::AddStartupFunction() {
13881387
}
13891388

13901389
llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
1391-
EmitDebugLocation(op);
13921390
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
13931391
return CreateCallPacked(op, true /* use_string_lookup */);
13941392
} else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) {

src/target/llvm/codegen_llvm.cc

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,6 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
11891189
}
11901190

11911191
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
1192-
EmitDebugLocation(op);
11931192
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
11941193
ICHECK_GE(op->args.size(), 2U);
11951194
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
@@ -1226,7 +1225,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
12261225
} else if (op->op.same_as(builtin::bitwise_not())) {
12271226
return builder_->CreateNot(MakeValue(op->args[0]));
12281227
} else if (op->op.same_as(builtin::bitwise_xor())) {
1229-
EmitDebugLocation(op);
12301228
return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
12311229
} else if (op->op.same_as(builtin::shift_left())) {
12321230
return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
@@ -1353,29 +1351,20 @@ void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::V
13531351
}
13541352

13551353
// Visitors
1356-
llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) {
1357-
EmitDebugLocation(op);
1358-
return GetVarValue(op);
1359-
}
1354+
llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }
13601355

13611356
llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
1362-
EmitDebugLocation(op);
13631357
return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
13641358
}
13651359
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
1366-
EmitDebugLocation(op);
13671360
return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
13681361
}
13691362

13701363
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
1371-
EmitDebugLocation(op);
13721364
return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
13731365
}
13741366

1375-
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
1376-
EmitDebugLocation(op);
1377-
return GetConstString(op->value);
1378-
}
1367+
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); }
13791368

13801369
#define DEFINE_CODEGEN_BINARY_OP(Op) \
13811370
llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
@@ -1397,7 +1386,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
13971386
} \
13981387
} \
13991388
llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
1400-
EmitDebugLocation(op); \
14011389
return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
14021390
}
14031391

@@ -1417,7 +1405,6 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
14171405
} \
14181406
} \
14191407
llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
1420-
EmitDebugLocation(op); \
14211408
return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
14221409
}
14231410

@@ -1427,7 +1414,6 @@ DEFINE_CODEGEN_CMP_OP(GT);
14271414
DEFINE_CODEGEN_CMP_OP(GE);
14281415

14291416
llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
1430-
EmitDebugLocation(op);
14311417
llvm::Value* a = MakeValue(op->a);
14321418
llvm::Value* b = MakeValue(op->b);
14331419
if (op->dtype.is_int()) {
@@ -1441,7 +1427,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
14411427
}
14421428

14431429
llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
1444-
EmitDebugLocation(op);
14451430
llvm::Value* a = MakeValue(op->a);
14461431
llvm::Value* b = MakeValue(op->b);
14471432
if (op->dtype.is_int()) {
@@ -1455,21 +1440,18 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
14551440
}
14561441

14571442
llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
1458-
EmitDebugLocation(op);
14591443
llvm::Value* a = MakeValue(op->a);
14601444
llvm::Value* b = MakeValue(op->b);
14611445
return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
14621446
}
14631447

14641448
llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
1465-
EmitDebugLocation(op);
14661449
llvm::Value* a = MakeValue(op->a);
14671450
llvm::Value* b = MakeValue(op->b);
14681451
return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
14691452
}
14701453

14711454
llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
1472-
EmitDebugLocation(op);
14731455
llvm::Value* a = MakeValue(op->a);
14741456
llvm::Value* b = MakeValue(op->b);
14751457
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
@@ -1480,7 +1462,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
14801462
}
14811463

14821464
llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
1483-
EmitDebugLocation(op);
14841465
llvm::Value* a = MakeValue(op->a);
14851466
llvm::Value* b = MakeValue(op->b);
14861467
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
@@ -1491,28 +1472,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
14911472
}
14921473

14931474
llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
1494-
EmitDebugLocation(op);
14951475
return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
14961476
}
14971477

14981478
llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
1499-
EmitDebugLocation(op);
15001479
return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
15011480
}
15021481

15031482
llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
1504-
EmitDebugLocation(op);
15051483
return builder_->CreateNot(MakeValue(op->a));
15061484
}
15071485

15081486
llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
1509-
EmitDebugLocation(op);
15101487
return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value),
15111488
MakeValue(op->false_value));
15121489
}
15131490

15141491
llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
1515-
EmitDebugLocation(op);
15161492
auto it = let_binding_.find(op->var);
15171493
if (it != let_binding_.end()) {
15181494
ICHECK(deep_equal_(it->second->value, op->value))
@@ -1630,7 +1606,6 @@ void CodeGenLLVM::BufferAccessHelper(
16301606
}
16311607

16321608
llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
1633-
EmitDebugLocation(op);
16341609
DataType value_dtype = op->dtype;
16351610

16361611
std::vector<llvm::Value*> loads;
@@ -1668,7 +1643,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
16681643
}
16691644

16701645
llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
1671-
EmitDebugLocation(op);
16721646
if (auto* ptr_op = op->op.as<OpNode>()) {
16731647
auto call_op = GetRef<Op>(ptr_op);
16741648
if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
@@ -1695,7 +1669,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
16951669
}
16961670

16971671
llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
1698-
EmitDebugLocation(op);
16991672
llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
17001673
for (int i = 0; i < op->lanes; ++i) {
17011674
vec = builder_->CreateInsertElement(
@@ -1705,7 +1678,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
17051678
}
17061679

17071680
llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
1708-
EmitDebugLocation(op);
17091681
std::vector<llvm::Value*> vecs(op->vectors.size());
17101682
int total_lanes = 0;
17111683
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
@@ -1730,7 +1702,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
17301702
}
17311703

17321704
llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
1733-
EmitDebugLocation(op);
17341705
return CreateBroadcast(MakeValue(op->value), op->lanes);
17351706
}
17361707

0 commit comments

Comments
 (0)