From a2325ec7e991ea4cb86c1cde9a74d74109ff8983 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Tue, 11 Apr 2023 12:14:27 +0800 Subject: [PATCH 1/2] [IR][SIBuilder] - Add SIBuilder to handle the span propagation between passes - Add SequentialSpan for multiple source expressions conversion between passes - Add test cases for SIBuilder and SequentialSpan --- include/tvm/ir/si_builder.h | 103 +++++++++ include/tvm/ir/source_map.h | 46 +++- python/tvm/ir/__init__.py | 1 + python/tvm/ir/base.py | 14 ++ python/tvm/relay/__init__.py | 1 + python/tvm/relay/base.py | 2 +- src/ir/si_builder.cc | 341 ++++++++++++++++++++++++++++++ src/ir/source_map.cc | 60 ++++++ tests/cpp/si_builder_test.cc | 399 +++++++++++++++++++++++++++++++++++ 9 files changed, 965 insertions(+), 2 deletions(-) create mode 100644 include/tvm/ir/si_builder.h create mode 100644 src/ir/si_builder.cc create mode 100644 tests/cpp/si_builder_test.cc diff --git a/include/tvm/ir/si_builder.h b/include/tvm/ir/si_builder.h new file mode 100644 index 000000000000..57ce4563d719 --- /dev/null +++ b/include/tvm/ir/si_builder.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/si_builder.h + * \brief build a source info during rewriting expressions. + */ +#ifndef TVM_IR_SI_BUILDER_H_ +#define TVM_IR_SI_BUILDER_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { + +/*! + * \brief SIBuilder provides helper APIs for filling spans, + * particularly useful for one-to-many, many-to-one and many-to-many pass transformations. + */ +class SIBuilder { + public: + /*! + * \brief Create SIBuilder from a given span + */ + explicit SIBuilder(const Span& span = Span()); + + /*! + * \brief Create SIBuilder from a given span sequence + */ + explicit SIBuilder(const Array& spans = Array()); + explicit SIBuilder(const std::initializer_list& init); + + /*! + * \brief Create SIBuilder via a subgraph, + * Will construct span based on the exprs in the subgraph. Including the inputs exprs. + * + * \param entry Entry expr for subgraph + * \param inputs End exprs for subgraph + */ + template ::value>> + explicit SIBuilder(const T& entry, const tvm::Array& inputs = {}); + explicit SIBuilder(const tir::Stmt& entry, const tvm::Array& inputs = {}); + explicit SIBuilder(const tir::Stmt& entry, const tvm::Array& inputs = {}); + + ~SIBuilder(); + + SIBuilder(const SIBuilder&) = delete; + SIBuilder& operator=(const SIBuilder&) = delete; + + /*! + * \brief create new source info based on the given span or subgraph. + * + * \return The given span, or reconstructed span from subgraph. + */ + Span CreateSpan() const; + + /*! + * \brief Recursively fill all span of exprs in subgraph from entry until inputs. + * + * \param entry Entry expr for subgraph. + * \param inputs End exprs for subgraph, will not be filled with new span. + */ + template ::value>> + void RecursivelyFillSpan( + const T& entry, const std::unordered_set& inputs) const; + + void RecursivelyFillSpan( + const tir::Stmt& entry, + const std::unordered_set& inputs) const; + void RecursivelyFillSpan( + const tir::Stmt& entry, + const std::unordered_set& inputs) const; + + private: + struct Impl; + std::unique_ptr impl_; + + std::unique_ptr CreateImpl(const Span& span); +}; + +} // namespace tvm + +#endif // TVM_IR_SI_BUILDER_H_ diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 536099f3114b..9b3041f3c000 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -114,7 +114,7 @@ class SpanNode : public Object { } static constexpr const char* _type_key = "Span"; - TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object); }; class Span : public ObjectRef { @@ -127,6 +127,50 @@ class Span : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; +/*! + * \brief Store a list of spans for an expr generated from mulitple source exprs + */ +class SequentialSpanNode : public SpanNode { + public: + /*! \brief The original source list of spans to construct a sequential span. */ + Array spans; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) { + SpanNode::VisitAttrs(v); + v->Visit("spans", &spans); + } + + static constexpr const char* _type_key = "SequentialSpan"; + TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode); + + bool SEqualReduce(const SequentialSpanNode* other, SEqualReducer equal) const { + if (spans.size() != other->spans.size()) { + return false; + } + + for (size_t i = 0, e = spans.size(); i != e; ++i) { + if (!StructuralEqual()(spans[i], other->spans[i])) { + return false; + } + } + return true; + } +}; + +/*! + * \brief Reference class of SequentialSpanNode. + * \sa SequentialSpanNode + */ +class SequentialSpan : public Span { + public: + TVM_DLL SequentialSpan(Array spans); + + TVM_DLL SequentialSpan(std::initializer_list init); + + TVM_DEFINE_OBJECT_REF_METHODS(SequentialSpan, Span, SequentialSpanNode); +}; + /*! \brief A program source in any language. * * Could represent the source from an ML framework or a source diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 4f63cbecd9d1..5875f4bad831 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -25,6 +25,7 @@ Node, SourceName, Span, + SequentialSpan, assert_structural_equal, load_json, save_json, diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 5f3a679591d1..f52eb97704a3 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -69,6 +69,20 @@ def __init__(self, source_name, line, end_line, column, end_column): ) +@register_object("SequentialSpan") +class SequentialSpan(Object): + """Specifies a location in a source program. + + Parameters + ---------- + spans : Array + The array of spans. + """ + + def __init__(self, spans): + self.__init_handle_by_constructor__(_ffi_api.SequentialSpan, spans) + + @register_object class EnvFunc(Object): """Environment function. diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 02eec18d3013..ef2b515c3be2 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -73,6 +73,7 @@ # Span Span = base.Span +SequentialSpan = base.SequentialSpan SourceName = base.SourceName # Type diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 8667bfb1dfdc..460746f94f1f 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -20,7 +20,7 @@ import tvm._ffi from tvm.ir import Node as RelayNode -from tvm.ir import SourceName, Span +from tvm.ir import SourceName, Span, SequentialSpan from tvm.runtime import Object from . import _ffi_api diff --git a/src/ir/si_builder.cc b/src/ir/si_builder.cc new file mode 100644 index 000000000000..e149c2900128 --- /dev/null +++ b/src/ir/si_builder.cc @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/si_builder.cc + * \brief Implementation for building a source info during rewriting expressions. + */ +#include +#include +#include + +#include + +namespace tvm { + +using RelayExprSet = std::unordered_set; +using PrimExprSet = std::unordered_set; +using StmtSet = std::unordered_set; + +class RelayCollapse : public relay::ExprVisitor { + public: + explicit RelayCollapse(const RelayExprSet& inputs = {}) : inputs_(inputs) {} + + Span Collapse(const relay::Expr& entry); + + void VisitExpr(const relay::Expr& expr) final; + + private: + Array spans_; + const RelayExprSet& inputs_; +}; + +void RelayCollapse::VisitExpr(const relay::Expr& expr) { + if (visit_counter_.count(expr.get())) { + return; + } + if (expr->span.defined()) { + spans_.push_back(expr->span); + } + if (inputs_.find(expr) != inputs_.end()) { + // becuase it returns directly, it should be recorded as visted manually. + visit_counter_.insert({expr.get(), 1}); + return; + } + relay::ExprVisitor::VisitExpr(expr); +} + +Span RelayCollapse::Collapse(const relay::Expr& entry) { + VisitExpr(entry); + return SequentialSpan(spans_); +} + +class RelayRecursivelyFill : public relay::ExprMutator { + public: + explicit RelayRecursivelyFill(const Span& span, const RelayExprSet& inputs = {}) + : span_(span), inputs_(inputs) {} + + void Fill(const relay::Expr& entry); + + relay::Expr VisitExpr(const relay::Expr& expr) final; + + private: + const Span& span_; + const RelayExprSet& inputs_; +}; + +relay::Expr RelayRecursivelyFill::VisitExpr(const relay::Expr& expr) { + if (inputs_.find(expr) != inputs_.end()) { + return expr; + } + // Skip op node. Align with python frontend + if (!expr.as()) { + expr->span = span_; + } + + return relay::ExprMutator::VisitExpr(expr); +} + +void RelayRecursivelyFill::Fill(const relay::Expr& entry) { Mutate(entry); } + +class TirCollapse : public tir::StmtExprVisitor { + public: + explicit TirCollapse(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {}) + : expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {} + + void VisitExpr(const PrimExpr& expr) final; + void VisitStmt(const tir::Stmt& stmt) final; + + bool IsInput(const PrimExpr& expr); + bool IsInput(const tir::Stmt& stmt); + + Span Collapse(const PrimExpr& expr); + Span Collapse(const tir::Stmt& stmt); + + private: + Array spans_; + std::unordered_map visit_counter_; + const PrimExprSet& expr_inputs_; + const StmtSet& stmt_inputs_; +}; + +Span TirCollapse::Collapse(const PrimExpr& expr) { + operator()(expr); + return SequentialSpan(spans_); +} + +Span TirCollapse::Collapse(const tir::Stmt& stmt) { + operator()(stmt); + return SequentialSpan(spans_); +} + +bool TirCollapse::IsInput(const PrimExpr& expr) { + return expr_inputs_.find(expr) != expr_inputs_.end(); +} + +bool TirCollapse::IsInput(const tir::Stmt& stmt) { + return stmt_inputs_.find(stmt) != stmt_inputs_.end(); +} + +void TirCollapse::VisitExpr(const PrimExpr& expr) { + if (visit_counter_.count(expr.get())) { + return; + } + if (expr->span.defined()) { + spans_.push_back(expr->span); + } + if (IsInput(expr)) { + // becuase it returns directly, it should be recorded as visted manually. + visit_counter_.insert({expr.get(), 1}); + return; + } + StmtExprVisitor::VisitExpr(expr); +} + +void TirCollapse::VisitStmt(const tir::Stmt& stmt) { + if (visit_counter_.count(stmt.get())) { + return; + } + if (stmt->span.defined()) { + spans_.push_back(stmt->span); + } + if (IsInput(stmt)) { + // becuase it returns directly, it should be recorded as visted manually. + visit_counter_.insert({stmt.get(), 1}); + return; + } + StmtExprVisitor::VisitStmt(stmt); +} + +class TirRecursivelyFill : public tir::StmtExprMutator { + public: + TirRecursivelyFill(const Span& span, const PrimExprSet& expr_inputs = {}, + const StmtSet& stmt_inputs = {}) + : span_(span), expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {} + + tir::Stmt Fill(const tir::Stmt& s) { return operator()(s); } + PrimExpr Fill(const PrimExpr& e) { return operator()(e); } + + bool IsInput(const PrimExpr& expr); + bool IsInput(const tir::Stmt& stmt); + + PrimExpr VisitExpr(const PrimExpr& expr) final; + tir::Stmt VisitStmt(const tir::Stmt& stmt) final; + + private: + const Span& span_; + const PrimExprSet& expr_inputs_; + const StmtSet& stmt_inputs_; +}; + +tir::Stmt TirRecursivelyFill::VisitStmt(const tir::Stmt& stmt) { + if (IsInput(stmt)) { + return stmt; + } + stmt->span = span_; + return StmtExprMutator::VisitStmt(stmt); +} + +bool TirRecursivelyFill::IsInput(const PrimExpr& expr) { + return expr_inputs_.find(expr) != expr_inputs_.end(); +} + +bool TirRecursivelyFill::IsInput(const tir::Stmt& stmt) { + return stmt_inputs_.find(stmt) != stmt_inputs_.end(); +} + +PrimExpr TirRecursivelyFill::VisitExpr(const PrimExpr& expr) { + if (IsInput(expr)) { + return expr; + } + expr->span = span_; + return StmtExprMutator::VisitExpr(expr); +} + +struct SIBuilder::Impl { + virtual Span CreateSpan() const = 0; + virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const = 0; + virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const = 0; + virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const = 0; + virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const = 0; + virtual void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) = 0; + virtual void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) = 0; + virtual void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) = 0; + virtual void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) = 0; +}; + +SIBuilder::~SIBuilder() = default; + +Span SIBuilder::CreateSpan() const { return impl_->CreateSpan(); } + +template <> +void SIBuilder::RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const { + impl_->RecursivelyFillSpan(entry, inputs); +} + +template <> +void SIBuilder::RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const { + impl_->RecursivelyFillSpan(entry, inputs); +} + +void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const { + impl_->RecursivelyFillSpan(entry, inputs); +} + +void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const { + impl_->RecursivelyFillSpan(entry, inputs); +} + +std::unique_ptr SIBuilder::CreateImpl(const Span& span) { + struct NullImpl : public SIBuilder::Impl { + Span CreateSpan() const final { return Span(); } + + void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final{}; + void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final{}; + void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final{}; + void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final{}; + void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) final{}; + void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) final{}; + void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final{}; + void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) final{}; + }; + + struct Impl : public SIBuilder::Impl { + explicit Impl(const Span& span) : span_(span) {} + + Span CreateSpan() const final { return span_; } + + void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final { + RelayRecursivelyFill(CreateSpan(), inputs).Fill(entry); + } + + void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final { + TirRecursivelyFill(CreateSpan(), inputs).Fill(entry); + } + + void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final { + TirRecursivelyFill(CreateSpan(), inputs).Fill(entry); + } + + void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final { + TirRecursivelyFill(CreateSpan(), {}, inputs).Fill(entry); + } + + void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) final { + span_ = RelayCollapse(inputs).Collapse(entry); + } + + void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) final { + span_ = TirCollapse(inputs).Collapse(entry); + } + + void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final { + span_ = TirCollapse(inputs).Collapse(entry); + } + + void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) final { + span_ = TirCollapse({}, inputs).Collapse(entry); + } + + private: + Span span_; + }; + + const bool enable_si_builder = transform::PassContext::Current() + ->GetConfig("ir.enable_si_builder", Bool(false)) + .value(); + + if (enable_si_builder) { + return std::make_unique(span); + } + + return std::make_unique(); +} + +SIBuilder::SIBuilder(const Span& span) : impl_(CreateImpl(span)) {} +SIBuilder::SIBuilder(const Array& spans) : impl_(CreateImpl(SequentialSpan(spans))) {} +SIBuilder::SIBuilder(const std::initializer_list& init) + : impl_(CreateImpl(SequentialSpan(Array(init)))) {} + +template <> +SIBuilder::SIBuilder(const relay::Expr& expr, const Array& inputs) + : impl_(CreateImpl(Span())) { + impl_->CollapseSpan(expr, RelayExprSet(inputs.begin(), inputs.end())); +} + +template <> +SIBuilder::SIBuilder(const PrimExpr& expr, const Array& inputs) + : impl_(CreateImpl(Span())) { + impl_->CollapseSpan(expr, PrimExprSet(inputs.begin(), inputs.end())); +} + +SIBuilder::SIBuilder(const tir::Stmt& s, const Array& inputs) + : impl_(CreateImpl(Span())) { + impl_->CollapseSpan(s, PrimExprSet(inputs.begin(), inputs.end())); +} + +SIBuilder::SIBuilder(const tir::Stmt& s, const Array& inputs) + : impl_(CreateImpl(Span())) { + impl_->CollapseSpan(s, StmtSet(inputs.begin(), inputs.end())); +} + +// Register build pipeline related options +TVM_REGISTER_PASS_CONFIG_OPTION("ir.enable_si_builder", Bool); + +} // namespace tvm diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 8b913906ea42..721a30affa3f 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -88,11 +88,58 @@ Span Span::Merge(const Span& other) const { TVM_REGISTER_NODE_TYPE(SpanNode); +SequentialSpan::SequentialSpan(tvm::Array spans) { + auto n = make_object(); + tvm::Array tmp_spans; + for (const Span& s : spans) { + if (const SequentialSpanNode* seq_s = s.as()) { + tmp_spans.insert(tmp_spans.end(), seq_s->spans.begin(), seq_s->spans.end()); + } else { + tmp_spans.push_back(s); + } + } + n->spans = std::move(tmp_spans); + + n->line = 0; + n->end_line = 0; + n->column = 0; + n->end_column = 0; + + data_ = std::move(n); +} + +SequentialSpan::SequentialSpan(std::initializer_list init) { + auto n = make_object(); + tvm::Array spans = tvm::Array(init); + tvm::Array tmp_spans; + for (const Span& s : spans) { + if (const SequentialSpanNode* seq_s = s.as()) { + tmp_spans.insert(tmp_spans.end(), seq_s->spans.begin(), seq_s->spans.end()); + } else { + tmp_spans.push_back(s); + } + } + n->spans = std::move(tmp_spans); + + n->line = 0; + n->end_line = 0; + n->column = 0; + n->end_column = 0; + + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(SequentialSpanNode); + TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line, int column, int end_column) { return Span(source_name, line, end_line, column, end_column); }); +TVM_REGISTER_GLOBAL("ir.SequentialSpan").set_body_typed([](tvm::Array spans) { + return SequentialSpan(spans); +}); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); @@ -100,6 +147,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << node->column << ", " << node->end_column << ")"; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + + p->stream << "SequentailSpan([ "; + int index = 0; + const int last = node->spans.size() - 1; + while (index < last) { + p->stream << node->spans[index++] << ", "; + } + p->stream << node->spans[last] << " ])"; + }); + TVM_REGISTER_NODE_TYPE(SourceNode); /*! \brief Construct a source from a string. */ diff --git a/tests/cpp/si_builder_test.cc b/tests/cpp/si_builder_test.cc new file mode 100644 index 000000000000..4bbd1acd8393 --- /dev/null +++ b/tests/cpp/si_builder_test.cc @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +tvm::Span _CreateSpan(std::string text) { + return tvm::Span(tvm::SourceName::Get(text), 0, 0, 0, 0); +} + +class RelayCheckSpan : public tvm::relay::ExprVisitor { + public: + std::vector tmp_result_; + std::vector lhs_spans_; + std::vector rhs_spans_; + + std::vector CollectSpan(tvm::relay::Expr expr) { + tmp_result_.clear(); + VisitExpr(expr); + return tmp_result_; + } + + void Check(tvm::relay::Expr lhs, tvm::relay::Expr rhs) { + tvm::relay::Function lhs_f = + tvm::relay::Function(tvm::relay::FreeVars(lhs), lhs, tvm::relay::Type(), {}); + tvm::relay::Function rhs_f = + tvm::relay::Function(tvm::relay::FreeVars(rhs), rhs, tvm::relay::Type(), {}); + EXPECT_TRUE(tvm::StructuralEqual()(lhs_f, rhs_f)); + lhs_spans_ = CollectSpan(lhs); + rhs_spans_ = CollectSpan(rhs); + + EXPECT_EQ(lhs_spans_.size(), rhs_spans_.size()); + for (std::size_t i = 0; i != lhs_spans_.size(); i++) { + EXPECT_TRUE(tvm::StructuralEqual()(lhs_spans_[i], rhs_spans_[i])); + } + } + + void VisitExpr(const tvm::relay::Expr& expr) { + if (expr->span.defined()) { + tmp_result_.push_back(expr->span); + } + using TParent = ExprFunctor; + TParent::VisitExpr(expr); + visit_counter_.emplace(expr.get(), 1); + } +}; + +TEST(SIBuilder, SequentialSpan) { + using namespace tvm; + Array ingredients = {_CreateSpan("first"), _CreateSpan("second"), _CreateSpan("third")}; + + SequentialSpan seq_span_1{ingredients[0], ingredients[1]}; + EXPECT_EQ(seq_span_1->spans.size(), 2); + for (std::size_t i = 0; i != seq_span_1->spans.size(); i++) { + EXPECT_EQ(seq_span_1->spans[i], ingredients[i]); + } + + // nested SequentialSpan test + SequentialSpan seq_span_2{seq_span_1, ingredients[2]}; + EXPECT_EQ(seq_span_2->spans.size(), 3); + for (std::size_t i = 0; i != seq_span_2->spans.size(); i++) { + EXPECT_EQ(seq_span_2->spans[i], ingredients[i]); + } + + // Array constructor test + Array tvm_array(ingredients); + SequentialSpan seq_span_3(tvm_array); + EXPECT_EQ(seq_span_3->spans.size(), 3); + for (std::size_t i = 0; i != seq_span_3->spans.size(); i++) { + EXPECT_EQ(seq_span_3->spans[i], ingredients[i]); + } +} + +TEST(SIBuilder, CreateSapn) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + Span span_1 = _CreateSpan("first"); + { + SIBuilder si_builder(span_1); + EXPECT_EQ(span_1, si_builder.CreateSpan()); + } + + Span span_2 = _CreateSpan("second"); + Array ingredients = {span_1, span_2}; + SequentialSpan seq_span_1{ingredients[0], ingredients[1]}; + { + SIBuilder si_builder_1(seq_span_1); + SIBuilder si_builder_2({span_1, span_2}); + SIBuilder si_builder_3{span_1, span_2}; + + Span created_span_1 = si_builder_1.CreateSpan(); + Span created_span_2 = si_builder_2.CreateSpan(); + Span created_span_3 = si_builder_3.CreateSpan(); + + auto created_seq_span_1 = created_span_1.as(); + auto created_seq_span_2 = created_span_2.as(); + auto created_seq_span_3 = created_span_3.as(); + EXPECT_EQ(created_seq_span_1->spans.size(), 2); + EXPECT_EQ(created_seq_span_2->spans.size(), 2); + EXPECT_EQ(created_seq_span_3->spans.size(), 2); + for (std::size_t i = 0; i != 2; i++) { + EXPECT_EQ(created_seq_span_1->spans[i], ingredients[i]); + EXPECT_EQ(created_seq_span_2->spans[i], ingredients[i]); + EXPECT_EQ(created_seq_span_3->spans[i], ingredients[i]); + } + } +} + +TEST(SIBuilder, DisableSIBuilder) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(false)); + tvm::With ctx_scope(pass_ctx); + Span span_1 = _CreateSpan("first"); + { + SIBuilder si_builder(span_1); + EXPECT_NE(span_1, si_builder.CreateSpan()); + } +} + +TEST(SIBuilder, RelayRecursivelyFill) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + Span test_span = _CreateSpan("test_span"); + Span a_node_span = _CreateSpan("a_node"); + + auto tensor_type = relay::TensorType({2, 3}, tvm::DataType::Float(32)); + relay::Expr add_op = relay::Op::Get("add"); + relay::Expr relu_op = relay::Op::Get("nn.relu"); + relay::Expr leaky_relu_op = relay::Op::Get("nn.leaky_relu"); + // Reset span of OpNode. Because a relay Op Node is a static reference, any change on it will + // be assigned the original object. + add_op->span = Span(); + relu_op->span = Span(); + leaky_relu_op->span = Span(); + + relay::Expr a = relay::Var("a", tensor_type, a_node_span); + relay::Expr x = relay::Call(relu_op, {a}, tvm::Attrs(), {}); + relay::Expr y = relay::Call(leaky_relu_op, {x}, tvm::Attrs(), {}); + relay::Expr z = relay::Call(add_op, {y, x}, tvm::Attrs(), {}); + + relay::Expr expected_a = relay::Var("a", tensor_type, a_node_span); + relay::Expr expected_x = relay::Call(relu_op, {expected_a}, tvm::Attrs(), {}, test_span); + relay::Expr expected_y = relay::Call(leaky_relu_op, {expected_x}, tvm::Attrs(), {}, test_span); + relay::Expr expected_z = + relay::Call(add_op, {expected_y, expected_x}, tvm::Attrs(), {}, test_span); + + SIBuilder si_builder(test_span); + si_builder.RecursivelyFillSpan(z, {a}); + RelayCheckSpan checker; + checker.Check(z, expected_z); +} + +TEST(SIBuilder, RelayCollapse) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + Span a_node_span = _CreateSpan("a_node"); + Span x_node_span = _CreateSpan("x_node"); + Span y_node_span = _CreateSpan("y_node"); + Span z_node_span = _CreateSpan("z_node"); + std::vector target = {z_node_span, y_node_span, x_node_span, a_node_span}; + + auto tensor_type = relay::TensorType({2, 3}, tvm::DataType::Float(32)); + relay::Expr add_op = relay::Op::Get("add"); + relay::Expr relu_op = relay::Op::Get("nn.relu"); + relay::Expr leaky_relu_op = relay::Op::Get("nn.leaky_relu"); + // Reset span of OpNode. Because a relay Op Node is a static reference, any change on it will + // be assigned the original object. + add_op->span = Span(); + relu_op->span = Span(); + leaky_relu_op->span = Span(); + + relay::Expr a = relay::Var("a", tensor_type, a_node_span); + relay::Expr x = relay::Call(relu_op, {a}, tvm::Attrs(), {}, x_node_span); + relay::Expr y = relay::Call(leaky_relu_op, {x}, tvm::Attrs(), {}, y_node_span); + relay::Expr z = relay::Call(add_op, {y, x}, tvm::Attrs(), {}, z_node_span); + + SIBuilder si_builder(z, {a}); + Span created_span = si_builder.CreateSpan(); + auto created_seq_span = created_span.as(); + EXPECT_EQ(created_seq_span->spans.size(), 4); + for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) { + EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i])); + } +} + +TEST(SIBuilder, TirCollapsePrimExpr) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + Span a_node_span = _CreateSpan("a_node"); + Span b_node_span = _CreateSpan("b_node"); + Span x_node_span = _CreateSpan("x_node"); + Span add_1_node_span = _CreateSpan("add_1_node"); + Span add_2_node_span = _CreateSpan("add_2_node"); + Span z_node_span = _CreateSpan("z_node"); + std::vector target = {z_node_span, add_2_node_span, add_1_node_span, x_node_span, + a_node_span}; + tir::Var a("a"); + tir::Var b("b"); + auto x = a + b; + auto add_1 = x + 1; + auto add_2 = add_1 + 2; + auto z = max(add_2, 100); + x->span = x_node_span; + a->span = a_node_span; + b->span = b_node_span; + add_1->span = add_1_node_span; + add_2->span = add_2_node_span; + z->span = z_node_span; + + SIBuilder si_builder(z, {x}); + Span created_span = si_builder.CreateSpan(); + auto created_seq_span = created_span.as(); + + EXPECT_EQ(created_seq_span->spans.size(), 4); + for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) { + EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i])); + } +} + +TEST(SIBuilder, TirCollapseStmtWithPrimInput) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + Span a_node_span = _CreateSpan("a_node"); + Span b_node_span = _CreateSpan("b_node"); + Span x_node_span = _CreateSpan("x_node"); + Span z_node_span = _CreateSpan("z_plus_1"); + Span stmt_node_span = _CreateSpan("stmt_node"); + std::vector target = {stmt_node_span, z_node_span, x_node_span}; + tir::Var a("a"); + tir::Var b("b"); + auto x = a + b; + x->span = x_node_span; + auto fmaketest = [&]() { + auto z = x + 1; + z->span = z_node_span; + tir::Stmt ret = te::Evaluate(z); + return ret; + }; + auto stmt = fmaketest(); + stmt->span = stmt_node_span; + SIBuilder si_builder(stmt, {x}); + Span created_span = si_builder.CreateSpan(); + auto created_seq_span = created_span.as(); + + EXPECT_EQ(created_seq_span->spans.size(), 3); + for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) { + EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i])); + } +} + +TEST(SIBuilder, TirCollapseStmtWithStmtInput) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + Span zero_node_span = _CreateSpan("zero_node"); + Span body_node_span = _CreateSpan("body_node"); + Span init_node_span = _CreateSpan("init_node"); + Span block_node_span = _CreateSpan("block_node"); + std::vector target = {block_node_span, init_node_span, body_node_span}; + + tir::Stmt zero = tir::Evaluate(Integer(0), zero_node_span); + tir::Stmt body = tir::Evaluate(Integer(1), body_node_span); + tir::Stmt init = tir::IfThenElse(tir::const_true(), zero, zero, init_node_span); + tir::Block block({}, {}, {}, "block", body, init, Array(), + Array(), Map(), block_node_span); + SIBuilder si_builder(block, {init}); + Span created_span = si_builder.CreateSpan(); + auto created_seq_span = created_span.as(); + + EXPECT_EQ(created_seq_span->spans.size(), 3); + for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) { + EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i])); + } +} + +TEST(SIBuilder, TirRecursivelyFillPrimExpr) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + Span test_span = _CreateSpan("test_span"); + tir::Var a("a"); + tir::Var b("b"); + auto x = a + b; + auto add_1 = x + 1; + auto add_2 = add_1 + 2; + auto z = max(add_2, 100); + + SIBuilder si_builder(test_span); + si_builder.RecursivelyFillSpan(z, {a, b}); + EXPECT_TRUE(!a->span.defined()); + EXPECT_TRUE(!b->span.defined()); + EXPECT_TRUE(StructuralEqual()(x->span, test_span)); + EXPECT_TRUE(StructuralEqual()(add_1->span, test_span)); + EXPECT_TRUE(StructuralEqual()(add_2->span, test_span)); + EXPECT_TRUE(StructuralEqual()(z->span, test_span)); + + ObjectRef tmp = z; + PrimExpr zz = Downcast(tmp); + std::ostringstream os; + os << z; + EXPECT_TRUE(zz.same_as(z)); + EXPECT_EQ(os.str(), "T.max(a + b + 1 + 2, 100)"); +} + +TEST(SIBuilder, TirRecursivelyFillStmtWithPrimInput) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + Span test_span = _CreateSpan("test_span"); + tir::Var a("a"); + tir::Var b("b"); + auto x = a + b; + auto z = x + 1; + tir::Stmt stmt = te::Evaluate(z); + SIBuilder si_builder(test_span); + const std::unordered_set inputs = {a, b}; + si_builder.RecursivelyFillSpan(stmt, inputs); + + EXPECT_TRUE(!a->span.defined()); + EXPECT_TRUE(!b->span.defined()); + EXPECT_TRUE(StructuralEqual()(x->span, test_span)); + EXPECT_TRUE(StructuralEqual()(z->span, test_span)); + EXPECT_TRUE(StructuralEqual()(stmt->span, test_span)); + + ObjectRef tmp = z; + PrimExpr zz = Downcast(tmp); + std::ostringstream os; + os << z; + EXPECT_TRUE(zz.same_as(z)); + EXPECT_EQ(os.str(), "a + b + 1"); +} + +TEST(SIBuilder, TirRecursivelyFillStmtWithStmtInput) { + using namespace tvm; + auto pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); + tvm::With ctx_scope(pass_ctx); + tir::Stmt zero = tir::Evaluate(Integer(0)); + tir::Stmt init = tir::IfThenElse(tir::const_true(), zero, zero); + tir::Stmt body = tir::Evaluate(Integer(1)); + tir::Block block(/*iter_vars=*/{}, /*reads=*/{}, + /*writes=*/{}, /*name_hint=*/"block", /*body=*/body, + /*init=*/init); + + Span test_span = _CreateSpan("test_span"); + const std::unordered_set inputs = {init}; + SIBuilder si_builder(test_span); + si_builder.RecursivelyFillSpan(block, {init}); + EXPECT_TRUE(!zero->span.defined()); + EXPECT_TRUE(!init->span.defined()); + EXPECT_TRUE(StructuralEqual()(body->span, test_span)); + EXPECT_TRUE(StructuralEqual()(block->span, test_span)); + + tir::Stmt expected_zero = tir::Evaluate(Integer(0)); + tir::Stmt expected_init = tir::IfThenElse(tir::const_true(), zero, zero); + tir::Stmt expected_body = tir::Evaluate(Integer(1)); + tir::Block expected_block(/*iter_vars=*/{}, /*reads=*/{}, + /*writes=*/{}, /*name_hint=*/"block", /*body=*/expected_body, + /*init=*/expected_init); + EXPECT_TRUE(tvm::StructuralEqual()(block, expected_block)); +} From 13a3bc70ffb2ca865725c2141616e8de0ba2a73e Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Fri, 19 May 2023 09:39:55 +0800 Subject: [PATCH 2/2] [IR][SIBuilder] - Make null implementation as base class - Add comments and change naming based on reviewing --- include/tvm/ir/si_builder.h | 10 +-- python/tvm/ir/base.py | 5 +- src/ir/si_builder.cc | 114 +++++++++++++++-------------------- tests/cpp/si_builder_test.cc | 26 ++++---- 4 files changed, 71 insertions(+), 84 deletions(-) diff --git a/include/tvm/ir/si_builder.h b/include/tvm/ir/si_builder.h index 57ce4563d719..ab5f2d450fe4 100644 --- a/include/tvm/ir/si_builder.h +++ b/include/tvm/ir/si_builder.h @@ -34,8 +34,8 @@ namespace tvm { /*! - * \brief SIBuilder provides helper APIs for filling spans, - * particularly useful for one-to-many, many-to-one and many-to-many pass transformations. + * \brief Source Information Builder, SIBuilder provides helper APIs for filling spans, + * particularly useful for one-to-many, many-to-one and many-to-many IR transformations. */ class SIBuilder { public: @@ -68,11 +68,11 @@ class SIBuilder { SIBuilder& operator=(const SIBuilder&) = delete; /*! - * \brief create new source info based on the given span or subgraph. + * \brief build a span of source information, which is based on the given span or subgraph. * - * \return The given span, or reconstructed span from subgraph. + * \return the built span */ - Span CreateSpan() const; + Span Build() const; /*! * \brief Recursively fill all span of exprs in subgraph from entry until inputs. diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index f52eb97704a3..21a5ed657675 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -71,7 +71,10 @@ def __init__(self, source_name, line, end_line, column, end_column): @register_object("SequentialSpan") class SequentialSpan(Object): - """Specifies a location in a source program. + """A sequence of source spans + + This span is specific for an expression, which is from multiple expressions + after an IR transform. Parameters ---------- diff --git a/src/ir/si_builder.cc b/src/ir/si_builder.cc index e149c2900128..c3375436f137 100644 --- a/src/ir/si_builder.cc +++ b/src/ir/si_builder.cc @@ -33,11 +33,12 @@ using RelayExprSet = std::unordered_set; using StmtSet = std::unordered_set; -class RelayCollapse : public relay::ExprVisitor { +class RelayCollectSpans : public relay::ExprVisitor { public: - explicit RelayCollapse(const RelayExprSet& inputs = {}) : inputs_(inputs) {} + explicit RelayCollectSpans(const RelayExprSet& inputs = {}) : inputs_(inputs) {} - Span Collapse(const relay::Expr& entry); + // From entry to inputs, recursively collect spans. The spans of inputs are included. + Span CollectSpans(const relay::Expr& entry); void VisitExpr(const relay::Expr& expr) final; @@ -46,7 +47,7 @@ class RelayCollapse : public relay::ExprVisitor { const RelayExprSet& inputs_; }; -void RelayCollapse::VisitExpr(const relay::Expr& expr) { +void RelayCollectSpans::VisitExpr(const relay::Expr& expr) { if (visit_counter_.count(expr.get())) { return; } @@ -61,7 +62,7 @@ void RelayCollapse::VisitExpr(const relay::Expr& expr) { relay::ExprVisitor::VisitExpr(expr); } -Span RelayCollapse::Collapse(const relay::Expr& entry) { +Span RelayCollectSpans::CollectSpans(const relay::Expr& entry) { VisitExpr(entry); return SequentialSpan(spans_); } @@ -71,6 +72,7 @@ class RelayRecursivelyFill : public relay::ExprMutator { explicit RelayRecursivelyFill(const Span& span, const RelayExprSet& inputs = {}) : span_(span), inputs_(inputs) {} + // From entry until inputs, recursively fill spans into expressions. Inputs are not filled. void Fill(const relay::Expr& entry); relay::Expr VisitExpr(const relay::Expr& expr) final; @@ -94,9 +96,9 @@ relay::Expr RelayRecursivelyFill::VisitExpr(const relay::Expr& expr) { void RelayRecursivelyFill::Fill(const relay::Expr& entry) { Mutate(entry); } -class TirCollapse : public tir::StmtExprVisitor { +class TirCollectSpans : public tir::StmtExprVisitor { public: - explicit TirCollapse(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {}) + explicit TirCollectSpans(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {}) : expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {} void VisitExpr(const PrimExpr& expr) final; @@ -105,8 +107,10 @@ class TirCollapse : public tir::StmtExprVisitor { bool IsInput(const PrimExpr& expr); bool IsInput(const tir::Stmt& stmt); - Span Collapse(const PrimExpr& expr); - Span Collapse(const tir::Stmt& stmt); + // From entry to inputs, recursively collect spans. The spans of inputs are included. + Span CollectSpans(const PrimExpr& expr); + // From entry to inputs, recursively collect spans. The spans of inputs are included. + Span CollectSpans(const tir::Stmt& stmt); private: Array spans_; @@ -115,25 +119,25 @@ class TirCollapse : public tir::StmtExprVisitor { const StmtSet& stmt_inputs_; }; -Span TirCollapse::Collapse(const PrimExpr& expr) { +Span TirCollectSpans::CollectSpans(const PrimExpr& expr) { operator()(expr); return SequentialSpan(spans_); } -Span TirCollapse::Collapse(const tir::Stmt& stmt) { +Span TirCollectSpans::CollectSpans(const tir::Stmt& stmt) { operator()(stmt); return SequentialSpan(spans_); } -bool TirCollapse::IsInput(const PrimExpr& expr) { +bool TirCollectSpans::IsInput(const PrimExpr& expr) { return expr_inputs_.find(expr) != expr_inputs_.end(); } -bool TirCollapse::IsInput(const tir::Stmt& stmt) { +bool TirCollectSpans::IsInput(const tir::Stmt& stmt) { return stmt_inputs_.find(stmt) != stmt_inputs_.end(); } -void TirCollapse::VisitExpr(const PrimExpr& expr) { +void TirCollectSpans::VisitExpr(const PrimExpr& expr) { if (visit_counter_.count(expr.get())) { return; } @@ -148,7 +152,7 @@ void TirCollapse::VisitExpr(const PrimExpr& expr) { StmtExprVisitor::VisitExpr(expr); } -void TirCollapse::VisitStmt(const tir::Stmt& stmt) { +void TirCollectSpans::VisitStmt(const tir::Stmt& stmt) { if (visit_counter_.count(stmt.get())) { return; } @@ -169,7 +173,9 @@ class TirRecursivelyFill : public tir::StmtExprMutator { const StmtSet& stmt_inputs = {}) : span_(span), expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {} + // From entry until inputs, recursively fill spans into expressions. Inputs are not filled. tir::Stmt Fill(const tir::Stmt& s) { return operator()(s); } + // From entry until inputs, recursively fill spans into expressions. Inputs are not filled. PrimExpr Fill(const PrimExpr& e) { return operator()(e); } bool IsInput(const PrimExpr& expr); @@ -209,20 +215,20 @@ PrimExpr TirRecursivelyFill::VisitExpr(const PrimExpr& expr) { } struct SIBuilder::Impl { - virtual Span CreateSpan() const = 0; - virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const = 0; - virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const = 0; - virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const = 0; - virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const = 0; - virtual void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) = 0; - virtual void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) = 0; - virtual void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) = 0; - virtual void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) = 0; + virtual Span Build() const { return Span(); } + virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const {} + virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const {} + virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const {} + virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const {} + virtual void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) {} + virtual void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) {} + virtual void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) {} + virtual void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) {} }; SIBuilder::~SIBuilder() = default; -Span SIBuilder::CreateSpan() const { return impl_->CreateSpan(); } +Span SIBuilder::Build() const { return impl_->Build(); } template <> void SIBuilder::RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const { @@ -243,54 +249,32 @@ void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& input } std::unique_ptr SIBuilder::CreateImpl(const Span& span) { - struct NullImpl : public SIBuilder::Impl { - Span CreateSpan() const final { return Span(); } - - void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final{}; - void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final{}; - void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final{}; - void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final{}; - void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) final{}; - void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) final{}; - void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final{}; - void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) final{}; - }; - struct Impl : public SIBuilder::Impl { explicit Impl(const Span& span) : span_(span) {} - - Span CreateSpan() const final { return span_; } - + Span Build() const final { return span_; } void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final { - RelayRecursivelyFill(CreateSpan(), inputs).Fill(entry); + RelayRecursivelyFill(Build(), inputs).Fill(entry); } - void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final { - TirRecursivelyFill(CreateSpan(), inputs).Fill(entry); + TirRecursivelyFill(Build(), inputs).Fill(entry); } - void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final { - TirRecursivelyFill(CreateSpan(), inputs).Fill(entry); + TirRecursivelyFill(Build(), inputs).Fill(entry); } - void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final { - TirRecursivelyFill(CreateSpan(), {}, inputs).Fill(entry); + TirRecursivelyFill(Build(), {}, inputs).Fill(entry); } - - void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) final { - span_ = RelayCollapse(inputs).Collapse(entry); + void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) final { + span_ = RelayCollectSpans(inputs).CollectSpans(entry); } - - void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) final { - span_ = TirCollapse(inputs).Collapse(entry); + void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) final { + span_ = TirCollectSpans(inputs).CollectSpans(entry); } - - void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final { - span_ = TirCollapse(inputs).Collapse(entry); + void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final { + span_ = TirCollectSpans(inputs).CollectSpans(entry); } - - void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) final { - span_ = TirCollapse({}, inputs).Collapse(entry); + void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) final { + span_ = TirCollectSpans({}, inputs).CollectSpans(entry); } private: @@ -305,7 +289,7 @@ std::unique_ptr SIBuilder::CreateImpl(const Span& span) { return std::make_unique(span); } - return std::make_unique(); + return std::make_unique(); } SIBuilder::SIBuilder(const Span& span) : impl_(CreateImpl(span)) {} @@ -316,23 +300,23 @@ SIBuilder::SIBuilder(const std::initializer_list& init) template <> SIBuilder::SIBuilder(const relay::Expr& expr, const Array& inputs) : impl_(CreateImpl(Span())) { - impl_->CollapseSpan(expr, RelayExprSet(inputs.begin(), inputs.end())); + impl_->CollectSpansSpan(expr, RelayExprSet(inputs.begin(), inputs.end())); } template <> SIBuilder::SIBuilder(const PrimExpr& expr, const Array& inputs) : impl_(CreateImpl(Span())) { - impl_->CollapseSpan(expr, PrimExprSet(inputs.begin(), inputs.end())); + impl_->CollectSpansSpan(expr, PrimExprSet(inputs.begin(), inputs.end())); } SIBuilder::SIBuilder(const tir::Stmt& s, const Array& inputs) : impl_(CreateImpl(Span())) { - impl_->CollapseSpan(s, PrimExprSet(inputs.begin(), inputs.end())); + impl_->CollectSpansSpan(s, PrimExprSet(inputs.begin(), inputs.end())); } SIBuilder::SIBuilder(const tir::Stmt& s, const Array& inputs) : impl_(CreateImpl(Span())) { - impl_->CollapseSpan(s, StmtSet(inputs.begin(), inputs.end())); + impl_->CollectSpansSpan(s, StmtSet(inputs.begin(), inputs.end())); } // Register build pipeline related options diff --git a/tests/cpp/si_builder_test.cc b/tests/cpp/si_builder_test.cc index 4bbd1acd8393..f65debaa6b17 100644 --- a/tests/cpp/si_builder_test.cc +++ b/tests/cpp/si_builder_test.cc @@ -103,7 +103,7 @@ TEST(SIBuilder, CreateSapn) { Span span_1 = _CreateSpan("first"); { SIBuilder si_builder(span_1); - EXPECT_EQ(span_1, si_builder.CreateSpan()); + EXPECT_EQ(span_1, si_builder.Build()); } Span span_2 = _CreateSpan("second"); @@ -114,9 +114,9 @@ TEST(SIBuilder, CreateSapn) { SIBuilder si_builder_2({span_1, span_2}); SIBuilder si_builder_3{span_1, span_2}; - Span created_span_1 = si_builder_1.CreateSpan(); - Span created_span_2 = si_builder_2.CreateSpan(); - Span created_span_3 = si_builder_3.CreateSpan(); + Span created_span_1 = si_builder_1.Build(); + Span created_span_2 = si_builder_2.Build(); + Span created_span_3 = si_builder_3.Build(); auto created_seq_span_1 = created_span_1.as(); auto created_seq_span_2 = created_span_2.as(); @@ -140,7 +140,7 @@ TEST(SIBuilder, DisableSIBuilder) { Span span_1 = _CreateSpan("first"); { SIBuilder si_builder(span_1); - EXPECT_NE(span_1, si_builder.CreateSpan()); + EXPECT_NE(span_1, si_builder.Build()); } } @@ -179,7 +179,7 @@ TEST(SIBuilder, RelayRecursivelyFill) { checker.Check(z, expected_z); } -TEST(SIBuilder, RelayCollapse) { +TEST(SIBuilder, RelayCollectSpans) { using namespace tvm; auto pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); @@ -206,7 +206,7 @@ TEST(SIBuilder, RelayCollapse) { relay::Expr z = relay::Call(add_op, {y, x}, tvm::Attrs(), {}, z_node_span); SIBuilder si_builder(z, {a}); - Span created_span = si_builder.CreateSpan(); + Span created_span = si_builder.Build(); auto created_seq_span = created_span.as(); EXPECT_EQ(created_seq_span->spans.size(), 4); for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) { @@ -214,7 +214,7 @@ TEST(SIBuilder, RelayCollapse) { } } -TEST(SIBuilder, TirCollapsePrimExpr) { +TEST(SIBuilder, TirCollectSpansPrimExpr) { using namespace tvm; auto pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); @@ -241,7 +241,7 @@ TEST(SIBuilder, TirCollapsePrimExpr) { z->span = z_node_span; SIBuilder si_builder(z, {x}); - Span created_span = si_builder.CreateSpan(); + Span created_span = si_builder.Build(); auto created_seq_span = created_span.as(); EXPECT_EQ(created_seq_span->spans.size(), 4); @@ -250,7 +250,7 @@ TEST(SIBuilder, TirCollapsePrimExpr) { } } -TEST(SIBuilder, TirCollapseStmtWithPrimInput) { +TEST(SIBuilder, TirCollectSpansStmtWithPrimInput) { using namespace tvm; auto pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); @@ -274,7 +274,7 @@ TEST(SIBuilder, TirCollapseStmtWithPrimInput) { auto stmt = fmaketest(); stmt->span = stmt_node_span; SIBuilder si_builder(stmt, {x}); - Span created_span = si_builder.CreateSpan(); + Span created_span = si_builder.Build(); auto created_seq_span = created_span.as(); EXPECT_EQ(created_seq_span->spans.size(), 3); @@ -283,7 +283,7 @@ TEST(SIBuilder, TirCollapseStmtWithPrimInput) { } } -TEST(SIBuilder, TirCollapseStmtWithStmtInput) { +TEST(SIBuilder, TirCollectSpansStmtWithStmtInput) { using namespace tvm; auto pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); @@ -300,7 +300,7 @@ TEST(SIBuilder, TirCollapseStmtWithStmtInput) { tir::Block block({}, {}, {}, "block", body, init, Array(), Array(), Map(), block_node_span); SIBuilder si_builder(block, {init}); - Span created_span = si_builder.CreateSpan(); + Span created_span = si_builder.Build(); auto created_seq_span = created_span.as(); EXPECT_EQ(created_seq_span->spans.size(), 3);