diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index b133e55db025..cea4958f3223 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit b133e55db025c962e30b045a6a3b937e9c03ca14 +Subproject commit cea4958f3223e6c14e9058eea85b6fa4ef9c4009 diff --git a/include/tvm/relay/vm/vm.h b/include/tvm/relay/vm/vm.h index 4f179689cfff..250aabde4c5f 100644 --- a/include/tvm/relay/vm/vm.h +++ b/include/tvm/relay/vm/vm.h @@ -11,102 +11,13 @@ #include #include #include +#include namespace tvm { namespace relay { namespace vm { -using runtime::NDArray; - -enum struct VMObjectTag { - kTensor, - kClosure, - kDatatype, - kExternalFunc, -}; - -inline std::string VMObjectTagString(VMObjectTag tag) { - switch (tag) { - case VMObjectTag::kClosure: - return "Closure"; - case VMObjectTag::kDatatype: - return "Datatype"; - case VMObjectTag::kTensor: - return "Tensor"; - case VMObjectTag::kExternalFunc: - return "ExternalFunction"; - default: - LOG(FATAL) << "Object tag is not supported."; - return ""; - } -} - -// TODO(@jroesch): Use intrusive pointer. 
-struct VMObjectCell { - VMObjectTag tag; - VMObjectCell(VMObjectTag tag) : tag(tag) {} - VMObjectCell() {} - virtual ~VMObjectCell() {} -}; - -struct VMTensorCell : public VMObjectCell { - tvm::runtime::NDArray data; - VMTensorCell(const tvm::runtime::NDArray& data) - : VMObjectCell(VMObjectTag::kTensor), data(data) {} -}; - -struct VMObject { - std::shared_ptr ptr; - VMObject(std::shared_ptr ptr) : ptr(ptr) {} - VMObject() : ptr() {} - VMObject(const VMObject& obj) : ptr(obj.ptr) {} - VMObjectCell* operator->() { - return this->ptr.operator->(); - } -}; - -struct VMDatatypeCell : public VMObjectCell { - size_t tag; - std::vector fields; - - VMDatatypeCell(size_t tag, const std::vector& fields) - : VMObjectCell(VMObjectTag::kDatatype), tag(tag), fields(fields) {} -}; - -struct VMClosureCell : public VMObjectCell { - size_t func_index; - std::vector free_vars; - - VMClosureCell(size_t func_index, const std::vector& free_vars) - : VMObjectCell(VMObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {} -}; - - -inline VMObject VMTensor(const tvm::runtime::NDArray& data) { - auto ptr = std::make_shared(data); - return std::dynamic_pointer_cast(ptr); -} - -inline VMObject VMDatatype(size_t tag, const std::vector& fields) { - auto ptr = std::make_shared(tag, fields); - return std::dynamic_pointer_cast(ptr); -} - -inline VMObject VMTuple(const std::vector& fields) { - return VMDatatype(0, fields); -} - -inline VMObject VMClosure(size_t func_index, std::vector free_vars) { - auto ptr = std::make_shared(func_index, free_vars); - return std::dynamic_pointer_cast(ptr); -} - -inline NDArray ToNDArray(const VMObject& obj) { - CHECK(obj.ptr.get()); - CHECK(obj.ptr->tag == VMObjectTag::kTensor) << "Expect Tensor, Got " << VMObjectTagString(obj.ptr->tag); - std::shared_ptr o = std::dynamic_pointer_cast(obj.ptr); - return o->data; -} +using namespace tvm::runtime; enum struct Opcode { Push, @@ -235,8 +146,8 @@ struct VirtualMachine { std::vector packed_funcs; 
std::vector functions; std::vector frames; - std::vector stack; - std::vector constants; + std::vector stack; + std::vector constants; // Frame State size_t func_index; @@ -255,8 +166,8 @@ struct VirtualMachine { void InvokeGlobal(const VMFunction& func); void Run(); - VMObject Invoke(const VMFunction& func, const std::vector& args); - VMObject Invoke(const GlobalVar& global, const std::vector& args); + Object Invoke(const VMFunction& func, const std::vector& args); + Object Invoke(const GlobalVar& global, const std::vector& args); // Ignore the method that dumps register info at compile-time if debugging // mode is not enabled. diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 1e608497d4dd..a511d00ff374 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -85,7 +85,7 @@ typedef enum { kStr = 11U, kBytes = 12U, kNDArrayContainer = 13U, - kVMObject = 14U, + kObject = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 03b21238870b..94f0e25da81c 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2017 by Contributors * \file tvm/runtime/ndarray.h - * \brief Abstract device memory management API + * \brief A device-independent managed NDArray abstraction. */ #ifndef TVM_RUNTIME_NDARRAY_H_ #define TVM_RUNTIME_NDARRAY_H_ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h new file mode 100644 index 000000000000..7743e8d419ce --- /dev/null +++ b/include/tvm/runtime/object.h @@ -0,0 +1,79 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/runtime/object.h + * \brief A managed object in the TVM runtime. 
+ */ +#ifndef TVM_RUNTIME_OBJECT_H_ +#define TVM_RUNTIME_OBJECT_H_ + +#include + +namespace tvm { +namespace runtime { + +enum struct ObjectTag { + kTensor, + kClosure, + kDatatype, + kExternalFunc +}; + +std::ostream& operator<<(std::ostream& os, const ObjectTag&); + +// TODO(@jroesch): Use intrusive pointer. +struct ObjectCell { + ObjectTag tag; + ObjectCell(ObjectTag tag) : tag(tag) {} + ObjectCell() {} + virtual ~ObjectCell() {} +}; + +/*! + * \brief A managed object in the TVM runtime. + * + * For example a tuple, list, closure, and so on. + * + * Maintains a reference count for the object. + */ +class Object { +public: + std::shared_ptr ptr; + Object(std::shared_ptr ptr) : ptr(ptr) {} + Object() : ptr() {} + Object(const Object& obj) : ptr(obj.ptr) {} + ObjectCell* operator->() { + return this->ptr.operator->(); + } +}; + +struct TensorCell : public ObjectCell { + NDArray data; + TensorCell(const NDArray& data) + : ObjectCell(ObjectTag::kTensor), data(data) {} +}; + +struct DatatypeCell : public ObjectCell { + size_t tag; + std::vector fields; + + DatatypeCell(size_t tag, const std::vector& fields) + : ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {} +}; + +struct ClosureCell : public ObjectCell { + size_t func_index; + std::vector free_vars; + + ClosureCell(size_t func_index, const std::vector& free_vars) + : ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {} +}; + +Object TensorObj(const NDArray& data); +Object DatatypeObj(size_t tag, const std::vector& fields); +Object TupleObj(const std::vector& fields); +Object ClosureObj(size_t func_index, std::vector free_vars); +NDArray ToNDArray(const Object& obj); + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 6ed9c6dd7c66..a35dd0d5d5b7 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -21,6 +21,7 @@ #include 
"c_runtime_api.h" #include "module.h" #include "ndarray.h" +#include "object.h" #include "node_base.h" namespace HalideIR { @@ -40,12 +41,6 @@ namespace tvm { // forward declarations class Integer; -namespace relay { -namespace vm { - struct VMObject; -} -} - namespace runtime { // forward declarations class TVMArgs; @@ -589,7 +584,7 @@ class TVMArgValue : public TVMPODValue_ { inline operator tvm::Integer() const; // get internal node ptr, if it is node inline NodePtr& node_sptr(); - operator relay::vm::VMObject() const; + operator runtime::Object() const; }; /*! @@ -724,7 +719,7 @@ class TVMRetValue : public TVMPODValue_ { return *this; } - TVMRetValue& operator=(relay::vm::VMObject other); + TVMRetValue& operator=(runtime::Object other); TVMRetValue& operator=(PackedFunc f) { this->SwitchToClass(kFuncHandle, f); @@ -821,7 +816,7 @@ class TVMRetValue : public TVMPODValue_ { kNodeHandle, *other.template ptr >()); break; } - case kVMObject: { + case kObject: { throw dmlc::Error("here"); } default: { @@ -871,7 +866,7 @@ class TVMRetValue : public TVMPODValue_ { static_cast(value_.v_handle)->DecRef(); break; } - // case kModuleHandle: delete ptr(); break; + // case kModuleHandle: delete ptr(); break; } if (type_code_ > kExtBegin) { #if TVM_RUNTIME_HEADER_ONLY @@ -901,7 +896,7 @@ inline const char* TypeCode2Str(int type_code) { case kFuncHandle: return "FunctionHandle"; case kModuleHandle: return "ModuleHandle"; case kNDArrayContainer: return "NDArrayContainer"; - case kVMObject: return "VMObject"; + case kObject: return "Object"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; } diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 4801db7d3f3c..fb8c58bf431e 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -473,7 +473,7 @@ def __init__(self, mod): self.define_list_map() self.define_list_foldl() self.define_list_foldr() - # self.define_list_concat() + self.define_list_concat() 
self.define_list_filter() self.define_list_zip() self.define_list_rev() @@ -489,9 +489,10 @@ def __init__(self, mod): self.define_nat_add() self.define_list_length() self.define_list_nth() + self.define_list_update() self.define_list_sum() - self.define_tree_adt() + self.define_tree_adt() self.define_tree_map() self.define_tree_size() diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 55770d116596..6643ef0906fc 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -74,7 +73,7 @@ struct APIAttrGetter : public AttrVisitor { found_ref_object = true; } } - void Visit(const char* key, relay::vm::VMObject* value) final { + void Visit(const char* key, runtime::Object* value) final { if (skey == key) { *ret = value[0]; found_ref_object = true; @@ -115,7 +114,7 @@ struct APIAttrDir : public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); } - void Visit(const char* key, relay::vm::VMObject* value) final { + void Visit(const char* key, runtime::Object* value) final { names->push_back(key); } }; diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index aba5fdacdfdc..3e640a19abfb 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -35,8 +34,8 @@ inline Type String2Type(std::string s) { return TVMType2Type(runtime::String2TVMType(s)); } -using relay::vm::VMObject; -using relay::vm::VMObjectCell; +using runtime::Object; +using runtime::ObjectCell; // indexer to index all the ndoes class NodeIndexer : public AttrVisitor { @@ -45,8 +44,8 @@ class NodeIndexer : public AttrVisitor { std::vector node_list{nullptr}; std::unordered_map tensor_index; std::vector tensor_list; - std::unordered_map vm_obj_index; - std::vector vm_obj_list; + std::unordered_map vm_obj_index; + std::vector vm_obj_list; void Visit(const char* key, double* value) 
final {} void Visit(const char* key, int64_t* value) final {} @@ -68,8 +67,8 @@ class NodeIndexer : public AttrVisitor { tensor_list.push_back(ptr); } - void Visit(const char* key, VMObject* value) final { - VMObjectCell* ptr = value->ptr.get(); + void Visit(const char* key, Object* value) final { + ObjectCell* ptr = value->ptr.get(); if (vm_obj_index.count(ptr)) return; CHECK_EQ(vm_obj_index.size(), vm_obj_list.size()); vm_obj_index[ptr] = vm_obj_list.size(); @@ -159,7 +158,7 @@ class JSONAttrGetter : public AttrVisitor { public: const std::unordered_map* node_index_; const std::unordered_map* tensor_index_; - const std::unordered_map* vm_obj_index_; + const std::unordered_map* vm_obj_index_; JSONNode* node_; void Visit(const char* key, double* value) final { @@ -194,7 +193,7 @@ class JSONAttrGetter : public AttrVisitor { node_->attrs[key] = std::to_string( tensor_index_->at(const_cast((*value).operator->()))); } - void Visit(const char* key, VMObject* value) final { + void Visit(const char* key, Object* value) final { node_->attrs[key] = std::to_string( vm_obj_index_->at(value->ptr.get())); } @@ -251,7 +250,7 @@ class JSONAttrSetter : public AttrVisitor { public: const std::vector >* node_list_; const std::vector* tensor_list_; - const std::vector* vm_obj_list_; + const std::vector* vm_obj_list_; JSONNode* node_; @@ -307,7 +306,7 @@ class JSONAttrSetter : public AttrVisitor { CHECK_LE(index, tensor_list_->size()); *value = tensor_list_->at(index); } - void Visit(const char* key, VMObject* value) final { + void Visit(const char* key, Object* value) final { size_t index; ParseValue(key, &index); CHECK_LE(index, vm_obj_list_->size()); @@ -490,8 +489,8 @@ class NodeAttrSetter : public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { *value = GetAttr(key).operator runtime::NDArray(); } - void Visit(const char* key, VMObject* value) final { - *value = GetAttr(key).operator VMObject(); + void Visit(const char* key, Object* value) final { + 
*value = GetAttr(key).operator Object(); } private: diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 40fcee68edc7..a38e7e3ff48f 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -660,6 +660,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { LOG(FATAL) << "do not allow NDarray as argument"; } + void Visit(const char* key, runtime::Object* obj) final { + LOG(FATAL) << "do not allow Object as argument"; + } private: Doc& doc_; diff --git a/src/relay/vm/compiler.cc b/src/relay/vm/compiler.cc index beec0d55fd0d..19aaa60bf956 100644 --- a/src/relay/vm/compiler.cc +++ b/src/relay/vm/compiler.cc @@ -527,7 +527,7 @@ VirtualMachine CompileModule(const Module& mod_ref) { vm.constants.resize(context.const_map.size()); for (auto pair : context.const_map) { - vm.constants[pair.second] = VMTensor(pair.first->data); + vm.constants[pair.second] = TensorObj(pair.first->data); } for (auto named_func : mod->functions) { diff --git a/src/relay/vm/vm.cc b/src/relay/vm/vm.cc index 2a00e39e7e38..8c4f327345d2 100644 --- a/src/relay/vm/vm.cc +++ b/src/relay/vm/vm.cc @@ -21,15 +21,15 @@ using namespace tvm::runtime; namespace tvm { // Packed Function extensions. 
-TVMRetValue& runtime::TVMRetValue::operator=(relay::vm::VMObject other) { - this->SwitchToClass(kVMObject, other); +TVMRetValue& runtime::TVMRetValue::operator=(relay::vm::Object other) { + this->SwitchToClass(kObject, other); return *this; } -runtime::TVMArgValue::operator relay::vm::VMObject() const { - if (type_code_ == kNull) return relay::vm::VMObject(nullptr); - TVM_CHECK_TYPE_CODE(type_code_, kVMObject); - return *ptr(); +runtime::TVMArgValue::operator relay::vm::Object() const { + if (type_code_ == kNull) return relay::vm::Object(nullptr); + TVM_CHECK_TYPE_CODE(type_code_, kObject); + return *ptr(); } namespace relay { @@ -333,7 +333,7 @@ size_t VirtualMachine::PopFrame() { CHECK(0 <= stack_size - fr.sp); // Copy return value to the position past last function's frame - VMObject return_value = stack.back(); + Object return_value = stack.back(); // stack[fr.sp] = stack[stack_size - 1]; // Resize value stack. @@ -374,7 +374,7 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func) { bp = stack.size() - func.params ; } -VMObject VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { +Object VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { RELAY_LOG(INFO) << "Executing function " << func.name << " bp " << bp << std::endl; for (auto arg : args) { @@ -390,8 +390,8 @@ VMObject VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { +Object VirtualMachine::Invoke(const GlobalVar& global, + const std::vector& args) { auto func_index = this->global_map[global]; RELAY_LOG(INFO) << "Invoke Global " << global << " at index " << func_index << std::endl; @@ -399,7 +399,7 @@ VMObject VirtualMachine::Invoke(const GlobalVar& global, } void InvokePacked(const PackedFunc& func, size_t arg_count, size_t output_size, - std::vector& stack) { + std::vector& stack) { auto stack_end = stack.size() - 1; RELAY_LOG(INFO) << "arg_count: " << arg_count << " output_size: " << output_size; CHECK(arg_count <= stack.size()); 
@@ -443,10 +443,10 @@ template typename std::enable_if::type VirtualMachine::DumpStack() { RELAY_LOG(INFO) << "DumpStack---\n"; for (size_t i = bp; i < stack.size(); ++i) { - RELAY_LOG(INFO) << i << " " << VMObjectTagString(stack[i]->tag) << " "; + RELAY_LOG(INFO) << i << " " << stack[i]->tag << " "; switch (stack[i]->tag) { - case VMObjectTag::kTensor: { - VMTensorCell* tensor = (VMTensorCell*)stack[i].operator->(); + case ObjectTag::kTensor: { + TensorCell* tensor = (TensorCell*)stack[i].operator->(); RELAY_LOG(INFO) << "dimensions=" << tensor->data->ndim; if (tensor->data->ndim == 0) { RELAY_LOG(INFO) << " " << TensorValueNode::make(tensor->data); @@ -454,8 +454,8 @@ typename std::enable_if::type VirtualMachine::DumpStack() { RELAY_LOG(INFO) << " \n"; break; } - case VMObjectTag::kDatatype: { - VMDatatypeCell* datatype = (VMDatatypeCell*)stack[i].operator->(); + case ObjectTag::kDatatype: { + DatatypeCell* datatype = (DatatypeCell*)stack[i].operator->(); RELAY_LOG(INFO) << "fields: " << datatype->fields.size(); RELAY_LOG(INFO) << "\n"; break; @@ -512,8 +512,8 @@ void VirtualMachine::Run() { case Opcode::InvokeClosure: { auto object = stack.back(); stack.pop_back(); - CHECK(object->tag == VMObjectTag::kClosure); - const std::shared_ptr& closure = std::dynamic_pointer_cast(object.ptr); + CHECK(object->tag == ObjectTag::kClosure); + const std::shared_ptr& closure = std::dynamic_pointer_cast(object.ptr); for (auto free_var : closure->free_vars) { stack.push_back(free_var); } @@ -523,8 +523,8 @@ void VirtualMachine::Run() { case Opcode::GetField: { auto object = stack[bp + instr.object_offset]; DumpStack(); - CHECK(object->tag == VMObjectTag::kDatatype) << "Object is not data type object " << bp << " " << instr.object_offset << " " << (int)object->tag; - const std::shared_ptr& tuple = std::dynamic_pointer_cast(object.ptr); + CHECK(object->tag == ObjectTag::kDatatype) << "Object is not data type object " << bp << " " << instr.object_offset << " " << (int)object->tag; 
+ const std::shared_ptr& tuple = std::dynamic_pointer_cast(object.ptr); auto field = tuple->fields[instr.field_index]; stack.push_back(field); pc++; @@ -563,22 +563,22 @@ void VirtualMachine::Run() { shape.assign(ti.shape, ti.shape + ti.ndim); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = NDArray::Empty(shape, ti.dtype, ctxs[0], allocator); - stack.push_back(VMTensor(data)); + stack.push_back(TensorObj(data)); pc++; goto main_loop; } case Opcode::AllocDatatype: { - std::vector fields; + std::vector fields; size_t stack_size = stack.size(); for (size_t i = 0; i < instr.num_fields; ++i) { fields.push_back(stack[stack_size - instr.num_fields + i]); } - stack.push_back(VMDatatype(instr.constructor_tag, fields)); + stack.push_back(DatatypeObj(instr.constructor_tag, fields)); pc++; goto main_loop; } case Opcode::AllocClosure: { - std::vector free_vars; + std::vector free_vars; auto field_start = stack.size() - instr.num_freevar; // Optimize this code. for (size_t i = 0; i < instr.num_freevar; i++) { @@ -587,7 +587,7 @@ void VirtualMachine::Run() { for (size_t i = 0; i < instr.num_freevar; i++) { stack.pop_back(); } - stack.push_back(VMClosure(instr.func_index, free_vars)); + stack.push_back(ClosureObj(instr.func_index, free_vars)); DumpStack(); pc++; goto main_loop; @@ -649,16 +649,16 @@ VirtualMachine VirtualMachine::FromModule(const Module& module, /*! \brief Convert from an array of relay.Value into VM compatible objects. 
*/ -void ConvertArgsToVM(tvm::Array args, std::vector& out) { +void ConvertArgsToVM(tvm::Array args, std::vector& out) { for (auto arg : args) { if (auto tensor = arg.as()) { - out.push_back(VMTensor(tensor->data)); + out.push_back(TensorObj(tensor->data)); } else if (auto tuple = arg.as()) { - std::vector fields; + std::vector fields; for (auto field : tuple->fields) { ConvertArgsToVM({field}, fields); } - out.push_back(VMDatatype(0, fields)); + out.push_back(DatatypeObj(0, fields)); } else { LOG(FATAL) << "unknown case: " << arg; } @@ -667,8 +667,8 @@ void ConvertArgsToVM(tvm::Array args, std::vector& out) { /*! \brief Convert from an array of relay.Value into VM compatible objects. */ -VMObject ValueToVM(Value value) { - std::vector out; +Object ValueToVM(Value value) { + std::vector out; ConvertArgsToVM({value}, out); CHECK_LT(out.size(), 2); return out[0]; @@ -676,13 +676,13 @@ VMObject ValueToVM(Value value) { using TagNameMap = std::unordered_map; -Value VMToValue(TagNameMap& tag_index_map, VMObject obj) { +Value VMToValue(TagNameMap& tag_index_map, Object obj) { switch (obj->tag) { - case VMObjectTag::kTensor: { + case ObjectTag::kTensor: { return TensorValueNode::make(ToNDArray(obj)); } - case VMObjectTag::kDatatype: { - auto data_type = std::dynamic_pointer_cast(obj.ptr); + case ObjectTag::kDatatype: { + auto data_type = std::dynamic_pointer_cast(obj.ptr); tvm::Array fields; for (size_t i = 0; i < data_type->fields.size(); ++i) { @@ -697,9 +697,9 @@ Value VMToValue(TagNameMap& tag_index_map, VMObject obj) { } } -std::tuple +std::tuple EvaluateModule(const Module& module, const std::vector ctxs, - const std::vector& vm_args) { + const std::vector& vm_args) { VirtualMachine vm = VirtualMachine::FromModule(module, ctxs); //TODO(zhiics) This measurement is for temporary usage. Remove it later. We //need to introduce a better profiling method. 
@@ -707,7 +707,7 @@ EvaluateModule(const Module& module, const std::vector ctxs, RELAY_LOG(INFO) << "Entry function is " << module->entry_func << std::endl; auto start = std::chrono::high_resolution_clock::now(); #endif // ENABLE_PROFILING - std::tuple res = + std::tuple res = std::make_tuple(vm.Invoke(module->entry_func, vm_args), vm.tag_index_map); #if ENABLE_PROFILING auto end = std::chrono::high_resolution_clock::now(); @@ -732,27 +732,34 @@ TVM_REGISTER_API("relay._vm._VMToValue") TVM_REGISTER_API("relay._vm._Tensor") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = VMTensor(args[0]); + *ret = TensorObj(args[0]); }); TVM_REGISTER_API("relay._vm._Tuple") .set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector fields; + std::vector fields; for (auto i = 0; i < args.size(); i++) { fields.push_back(args[i]); } - *ret = VMTuple(fields); + *ret = TupleObj(fields); }); -TVM_REGISTER_API("relay._vm._VMObjectTag") +template +std::string ToString(const T& t) { + std::stringstream s; + s << t; + return s.str(); +} + +TVM_REGISTER_API("relay._vm._ObjectTag") .set_body([](TVMArgs args, TVMRetValue* ret) { - VMObject obj = args[0]; - *ret = VMObjectTagString(obj->tag); + Object obj = args[0]; + *ret = ToString(obj->tag); }); // TVM_REGISTER_API("relay._vm._Datatype") // .set_body([](TVMArgs args, TVMRetValue* ret) { -// *ret = VMDatatype(args[0]); +// *ret = DatatypeObj(args[0]); // }); TVM_REGISTER_API("relay._vm._evaluate_vm") @@ -773,9 +780,9 @@ TVM_REGISTER_API("relay._vm._evaluate_vm") LOG(FATAL) << "expected function or module"; } - std::vector vm_args; + std::vector vm_args; for (auto i = 3; i < args.size(); i++) { - VMObject obj = args[i]; + Object obj = args[i]; vm_args.push_back(obj); } diff --git a/src/runtime/object.cc b/src/runtime/object.cc new file mode 100644 index 000000000000..ab7719c4a195 --- /dev/null +++ b/src/runtime/object.cc @@ -0,0 +1,61 @@ +/*! 
 + * Copyright (c) 2019 by Contributors + * \file object.cc + * \brief TVM runtime object used by VM. + */ + +#include +#include +#include + +namespace tvm { +namespace runtime { + +std::ostream& operator<<(std::ostream& os, const ObjectTag& tag) { + switch (tag) { + case ObjectTag::kClosure: + os << "Closure"; + break; + case ObjectTag::kDatatype: + os << "Datatype"; + break; + case ObjectTag::kTensor: + os << "Tensor"; + break; + case ObjectTag::kExternalFunc: + os << "ExternalFunction"; + break; + default: + LOG(FATAL) << "Invalid object tag: found " << static_cast(tag); + } + return os; +} + +Object TensorObj(const NDArray& data) { + auto ptr = std::make_shared(data); + return std::dynamic_pointer_cast(ptr); +} + +Object DatatypeObj(size_t tag, const std::vector& fields) { + auto ptr = std::make_shared(tag, fields); + return std::dynamic_pointer_cast(ptr); +} + +Object TupleObj(const std::vector& fields) { + return DatatypeObj(0, fields); +} + +Object ClosureObj(size_t func_index, std::vector free_vars) { + auto ptr = std::make_shared(func_index, free_vars); + return std::dynamic_pointer_cast(ptr); +} + +NDArray ToNDArray(const Object& obj) { + CHECK(obj.ptr.get()); + CHECK(obj.ptr->tag == ObjectTag::kTensor) << "Expected tensor, found " << obj.ptr->tag; + std::shared_ptr o = std::dynamic_pointer_cast(obj.ptr); + return o->data; +} + +} // namespace runtime +} // namespace tvm