Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/relax/tir_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class MatchResult : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode);
};

using FCodegen = ffi::TypedFunction<Array<ObjectRef>(Array<MatchResult> match_results)>;
using FCodegen = ffi::TypedFunction<Array<ffi::Any>(Array<MatchResult> match_results)>;
} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_TIR_PATTERN_H_
8 changes: 4 additions & 4 deletions include/tvm/runtime/profiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ class MetricCollectorNode : public Object {
* \returns A set of metric names and the associated values. Values must be
* one of DurationNode, PercentNode, CountNode, or StringObj.
*/
virtual Map<String, ffi::Any> Stop(ObjectRef obj) = 0;
virtual Map<String, ffi::Any> Stop(ffi::ObjectRef obj) = 0;

virtual ~MetricCollectorNode() {}

Expand All @@ -340,7 +340,7 @@ struct CallFrame {
/*! Runtime of the function or op */
Timer timer;
/*! Extra performance metrics */
std::unordered_map<std::string, ObjectRef> extra_metrics;
std::unordered_map<std::string, ffi::Any> extra_metrics;
/*! User defined metric collectors. Each pair is the MetricCollector and its
* associated data (returned from MetricCollector.Start).
*/
Expand Down Expand Up @@ -404,12 +404,12 @@ class Profiler {
* `StartCall` and `StopCall` must be nested properly.
*/
void StartCall(String name, Device dev,
std::unordered_map<std::string, ObjectRef> extra_metrics = {});
std::unordered_map<std::string, ffi::Any> extra_metrics = {});
/*! \brief Stop the last `StartCall`.
* \param extra_metrics Optional additional profiling information to add to
* the frame (input sizes, allocations).
*/
void StopCall(std::unordered_map<std::string, ObjectRef> extra_metrics = {});
void StopCall(std::unordered_map<std::string, ffi::Any> extra_metrics = {});
/*! \brief A report of total runtime between `Start` and `Stop` as
* well as individual statistics for each `StartCall`-`StopCall` pair.
* \returns A `Report` that can either be formatted as CSV (with `.AsCSV`)
Expand Down
8 changes: 3 additions & 5 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class LiteralDocNode : public ExprDocNode {
* - String
* - null
*/
ObjectRef value;
ffi::Any value;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
Expand All @@ -265,16 +265,14 @@ class LiteralDocNode : public ExprDocNode {
*/
class LiteralDoc : public ExprDoc {
protected:
explicit LiteralDoc(ObjectRef value, const Optional<ObjectPath>& object_path);
explicit LiteralDoc(ffi::Any value, const Optional<ObjectPath>& object_path);

public:
/*!
* \brief Create a LiteralDoc to represent None/null/empty value.
* \param p The object path
*/
static LiteralDoc None(const Optional<ObjectPath>& p) {
return LiteralDoc(ObjectRef(nullptr), p);
}
static LiteralDoc None(const Optional<ObjectPath>& p) { return LiteralDoc(ffi::Any(nullptr), p); }
/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
Expand Down
20 changes: 14 additions & 6 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class IRDocsifierNode : public Object {
/*! \brief Mapping from a var to its info */
std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
/*! \brief Metadata printing */
std::unordered_map<String, Array<ObjectRef>> metadata;
std::unordered_map<String, Array<ffi::Any>> metadata;
/*! \brief GlobalInfo printing */
std::unordered_map<String, Array<GlobalInfo>> global_infos;
/*! \brief The variable names used already */
Expand Down Expand Up @@ -206,7 +206,7 @@ class IRDocsifierNode : public Object {
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;
/*! \brief Add a TVM object to the metadata section*/
ExprDoc AddMetadata(const ObjectRef& obj);
ExprDoc AddMetadata(const ffi::Any& obj);
/*! \brief Add a GlobalInfo to the global_infos map.
* \param name The name of key of global_infos.
* \param ginfo The GlobalInfo to be added.
Expand Down Expand Up @@ -275,7 +275,7 @@ inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const Ob
const PrinterConfig& cfg) {
if (cfg->obj_to_annotate.count(obj)) {
if (const auto* stmt = d.as<StmtDocNode>()) {
if (stmt->comment.defined()) {
if (stmt->comment.has_value()) {
stmt->comment = stmt->comment.value() + "\n" + cfg->obj_to_annotate.at(obj);
} else {
stmt->comment = cfg->obj_to_annotate.at(obj);
Expand All @@ -295,7 +295,7 @@ inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const Ob
String attn = pair.second;
if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) {
if (const auto* stmt = d.as<StmtDocNode>()) {
if (stmt->comment.defined()) {
if (stmt->comment.has_value()) {
stmt->comment = stmt->comment.value() + "\n" + attn;
} else {
stmt->comment = attn;
Expand All @@ -319,8 +319,16 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const ObjectPath& path) con
return Downcast<TDoc>(LiteralDoc::Int(value.as<int64_t>().value(), path));
case ffi::TypeIndex::kTVMFFIFloat:
return Downcast<TDoc>(LiteralDoc::Float(value.as<double>().value(), path));
case ffi::TypeIndex::kTVMFFIStr:
return Downcast<TDoc>(LiteralDoc::Str(value.as<String>().value(), path));
case ffi::TypeIndex::kTVMFFIStr: {
std::string string_value = value.cast<std::string>();
bool has_multiple_lines = string_value.find_first_of('\n') != std::string::npos;
if (has_multiple_lines) {
Doc d = const_cast<IRDocsifierNode*>(this)->AddMetadata(string_value);
// TODO(tqchen): cross check AddDocDecoration
return Downcast<TDoc>(d);
}
return Downcast<TDoc>(LiteralDoc::Str(string_value, path));
}
case ffi::TypeIndex::kTVMFFIDataType:
return Downcast<TDoc>(LiteralDoc::DataType(value.as<runtime::DataType>().value(), path));
case ffi::TypeIndex::kTVMFFIDevice:
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class LetStmt : public Stmt {
class AttrStmtNode : public StmtNode {
public:
/*! \brief this is attribute about certain node */
ObjectRef node;
ffi::Any node;
/*! \brief the type key of the attribute */
String attr_key;
/*! \brief The attribute value, value is well defined at current scope. */
Expand All @@ -142,7 +142,7 @@ class AttrStmtNode : public StmtNode {
*/
class AttrStmt : public Stmt {
public:
TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
TVM_DLL AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode);
Expand Down
14 changes: 7 additions & 7 deletions src/contrib/msc/core/ir/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional<Expr>& bin
if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
const auto& func = Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
const auto& name_opt = func->GetAttr<String>(relax::attr::kComposite);
if (name_opt.defined()) {
if (name_opt.has_value()) {
attrs = FuncAttrGetter().GetAttrs(func);
}
} else if (call_node->op->IsInstance<VarNode>()) {
Expand Down Expand Up @@ -760,7 +760,7 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const DataflowVa

void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) {
const auto& name_opt = val->GetAttr<String>(relax::attr::kComposite);
ICHECK(name_opt.defined()) << "Unexpected target func without composite";
ICHECK(name_opt.has_value()) << "Unexpected target func without composite";
ICHECK(config_.target.size() > 0 && StringUtils::StartsWith(name_opt.value(), config_.target))
<< "Target should be given for target function";
target_funcs_.Set(binding->var, GetRef<Function>(val));
Expand All @@ -770,26 +770,26 @@ const std::tuple<String, String, String> GraphBuilder::ParseFunc(const Function&
String node_name, optype, layout;
const auto& name_opt = func->GetAttr<String>(msc_attr::kUnique);
// get node_name
if (name_opt.defined()) {
if (name_opt.has_value()) {
node_name = name_opt.value();
}
// get optype
const auto& codegen_opt = func->GetAttr<String>(relax::attr::kCodegen);
const auto& optype_opt = func->GetAttr<String>(msc_attr::kOptype);
const auto& composite_opt = func->GetAttr<String>(relax::attr::kComposite);
if (codegen_opt.defined()) {
if (codegen_opt.has_value()) {
optype = codegen_opt.value();
} else if (optype_opt.defined()) {
} else if (optype_opt.has_value()) {
optype = optype_opt.value();
} else if (composite_opt.defined()) {
} else if (composite_opt.has_value()) {
optype = composite_opt.value();
if (config_.target.size() > 0) {
optype = StringUtils::Replace(composite_opt.value(), config_.target + ".", "");
}
}
// get layout
const auto& layout_opt = func->GetAttr<String>(msc_attr::kLayout);
if (layout_opt.defined()) {
if (layout_opt.has_value()) {
layout = layout_opt.value();
}
return std::make_tuple(node_name, optype, layout);
Expand Down
6 changes: 3 additions & 3 deletions src/contrib/msc/core/printer/cpp_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ namespace contrib {
namespace msc {

void CppPrinter::PrintTypedDoc(const LiteralDoc& doc) {
const ObjectRef& value = doc->value;
const ffi::Any& value = doc->value;
bool defined = false;
if (!value.defined()) {
if (value == nullptr) {
output_ << "nullptr";
defined = true;
} else if (const auto* int_imm = value.as<IntImmNode>()) {
Expand Down Expand Up @@ -217,7 +217,7 @@ void CppPrinter::PrintTypedDoc(const ClassDoc& doc) {
}

void CppPrinter::PrintTypedDoc(const CommentDoc& doc) {
if (doc->comment.defined()) {
if (doc->comment.has_value()) {
output_ << "// " << doc->comment.value();
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/msc/core/printer/msc_base_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ void MSCBasePrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
}

void MSCBasePrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) {
if (stmt->comment.defined()) {
if (stmt->comment.has_value()) {
if (multi_lines) {
for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) {
PrintDoc(CommentDoc(l));
Expand Down
6 changes: 3 additions & 3 deletions src/contrib/msc/core/printer/prototxt_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tvm {
namespace contrib {
namespace msc {

LiteralDoc PrototxtPrinter::ToLiteralDoc(const ObjectRef& obj) {
LiteralDoc PrototxtPrinter::ToLiteralDoc(const ffi::Any& obj) {
if (obj.as<ffi::StringObj>()) {
return LiteralDoc::Str(Downcast<String>(obj), std::nullopt);
} else if (obj.as<IntImmNode>()) {
Expand All @@ -51,7 +51,7 @@ DictDoc PrototxtPrinter::ToDictDoc(const Map<String, ffi::Any>& dict) {
if (pair.second.as<DictDocNode>()) {
values.push_back(Downcast<DictDoc>(pair.second));
} else {
values.push_back(ToLiteralDoc(pair.second.cast<ObjectRef>()));
values.push_back(ToLiteralDoc(pair.second));
}
}
return DictDoc(keys, values);
Expand All @@ -65,7 +65,7 @@ DictDoc PrototxtPrinter::ToDictDoc(const std::vector<std::pair<String, Any>>& di
if (pair.second.as<DictDocNode>()) {
values.push_back(Downcast<DictDoc>(pair.second));
} else {
values.push_back(ToLiteralDoc(pair.second.cast<ObjectRef>()));
values.push_back(ToLiteralDoc(pair.second));
}
}
return DictDoc(keys, values);
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/msc/core/printer/prototxt_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class PrototxtPrinter : public MSCBasePrinter {
explicit PrototxtPrinter(const std::string& options = "") : MSCBasePrinter(options) {}

/*! \brief Change object to LiteralDoc*/
static LiteralDoc ToLiteralDoc(const ObjectRef& obj);
static LiteralDoc ToLiteralDoc(const ffi::Any& obj);

/*! \brief Change map to DictDoc*/
static DictDoc ToDictDoc(const Map<String, ffi::Any>& dict);
Expand Down
10 changes: 5 additions & 5 deletions src/contrib/msc/core/printer/python_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace contrib {
namespace msc {

void PythonPrinter::PrintTypedDoc(const LiteralDoc& doc) {
const ObjectRef& value = doc->value;
const ffi::Any& value = doc->value;
bool defined = false;
if (!value.defined()) {
if (value == nullptr) {
output_ << "None";
defined = true;
} else if (const auto* int_imm = value.as<IntImmNode>()) {
Expand Down Expand Up @@ -176,7 +176,7 @@ void PythonPrinter::PrintTypedDoc(const FunctionDoc& doc) {

output_ << ":";

if (doc->comment.defined()) {
if (doc->comment.has_value()) {
IncreaseIndent();
MaybePrintComment(doc, true);
DecreaseIndent();
Expand All @@ -197,7 +197,7 @@ void PythonPrinter::PrintTypedDoc(const ClassDoc& doc) {
}

void PythonPrinter::PrintTypedDoc(const CommentDoc& doc) {
if (doc->comment.defined()) {
if (doc->comment.has_value()) {
output_ << "# " << doc->comment.value();
}
}
Expand Down Expand Up @@ -234,7 +234,7 @@ void PythonPrinter::PrintTypedDoc(const SwitchDoc& doc) {
}

void PythonPrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) {
if (stmt->comment.defined() && multi_lines) {
if (stmt->comment.has_value() && multi_lines) {
NewLine();
output_ << "\"\"\"";
for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) {
Expand Down
11 changes: 5 additions & 6 deletions src/contrib/msc/core/transform/bind_named_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeNamedBindings(

Map<relax::Var, relax::Expr> relax_var_remap;

auto normalize_key = [&](ObjectRef obj) -> relax::Var {
auto normalize_key = [&](ffi::Any obj) -> relax::Var {
if (auto opt_str = obj.as<String>()) {
std::string str = opt_str.value();
auto it = string_lookup.find(str);
Expand Down Expand Up @@ -77,18 +77,17 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeNamedBindings(
LOG(FATAL)
<< "Expected bound parameter to be a relax::Var, "
<< " or a string that uniquely identifies a relax::Var param within the function. "
<< "However, received object " << obj << " of type " << obj->GetTypeKey();
<< "However, received object " << obj << " of type " << obj.GetTypeKey();
}
};
auto normalize_value = [&](Var key, ObjectRef obj) -> relax::Expr {
auto normalize_value = [&](Var key, ffi::Any obj) -> relax::Expr {
if (auto opt = obj.as<relax::Expr>()) {
return opt.value();
} else if (auto opt = obj.as<runtime::NDArray>()) {
const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint());
return Constant(opt.value(), StructInfo(), span);
} else {
LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
<< " into relax expression";
LOG(FATAL) << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression";
}
};

Expand Down Expand Up @@ -130,7 +129,7 @@ IRModule BindNamedParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef>
if (relax_f->GetLinkageType() == LinkageType::kExternal) {
// Use global_symbol if it's external linkage
Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol.defined() && gsymbol.value() == func_name) {
if (gsymbol.has_value() && gsymbol.value() == func_name) {
Function f_after_bind = FunctionBindNamedParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
Expand Down
4 changes: 2 additions & 2 deletions src/contrib/msc/core/transform/fuse_tuple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TupleFuser : public ExprMutator {
main_var = gv;
} else {
const auto& name_opt = func->GetAttr<String>(attr::kComposite);
if (name_opt.defined() && StringUtils::StartsWith(name_opt.value(), target_)) {
if (name_opt.has_value() && StringUtils::StartsWith(name_opt.value(), target_)) {
target_funcs_.Set(gv, Downcast<Function>(func));
}
}
Expand All @@ -76,7 +76,7 @@ class TupleFuser : public ExprMutator {
if (arg->IsInstance<TupleNode>()) {
String tuple_name;
const auto& name_opt = target_funcs_[val->op]->GetAttr<String>(msc_attr::kUnique);
if (name_opt.defined()) {
if (name_opt.has_value()) {
if (val->args.size() == 1) {
tuple_name = name_opt.value() + "_input";
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/msc/core/transform/inline_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ParamsInliner : public ExprMutator {
}
if (struct_info->IsInstance<FuncStructInfoNode>()) {
const auto& optype_opt = func->GetAttr<String>(msc_attr::kOptype);
ICHECK(optype_opt.defined())
ICHECK(optype_opt.has_value())
<< "Can not find attr " << msc_attr::kOptype << " form extern func";
extern_types_.Set(p, optype_opt.value());
continue;
Expand Down
4 changes: 2 additions & 2 deletions src/contrib/msc/core/transform/set_byoc_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ByocNameSetter : public ExprMutator {
continue;
}
const auto& name_opt = func->GetAttr<String>(attr::kCodegen);
if (name_opt.defined() && name_opt.value() == target_) {
if (name_opt.has_value() && name_opt.value() == target_) {
const String& func_name = target_ + "_" + std::to_string(func_cnt);
const auto& new_func = Downcast<Function>(VisitExpr(func));
builder_->UpdateFunction(gv, WithAttr(new_func, msc_attr::kUnique, func_name));
Expand All @@ -75,7 +75,7 @@ class ByocNameSetter : public ExprMutator {
if (val->op->IsInstance<relax::VarNode>()) {
ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op;
const auto& name_opt = local_funcs_[val->op]->GetAttr<String>(msc_attr::kUnique);
if (name_opt.defined()) {
if (name_opt.has_value()) {
val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value());
}
}
Expand Down
Loading
Loading