Skip to content

Commit

Permalink
Refactor code to use new names
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Mar 26, 2019
1 parent 2041efa commit 02c0823
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 175 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
Submodule HalideIR updated 1 files
+1 −1 src/tvm/node/node.h
101 changes: 6 additions & 95 deletions include/tvm/relay/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,102 +11,13 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include <tvm/runtime/memory_manager.h>
#include <tvm/runtime/object.h>

namespace tvm {
namespace relay {
namespace vm {

using runtime::NDArray;

enum struct VMObjectTag {
kTensor,
kClosure,
kDatatype,
kExternalFunc,
};

inline std::string VMObjectTagString(VMObjectTag tag) {
switch (tag) {
case VMObjectTag::kClosure:
return "Closure";
case VMObjectTag::kDatatype:
return "Datatype";
case VMObjectTag::kTensor:
return "Tensor";
case VMObjectTag::kExternalFunc:
return "ExternalFunction";
default:
LOG(FATAL) << "Object tag is not supported.";
return "";
}
}

// TODO(@jroesch): Use intrusive pointer.
struct VMObjectCell {
VMObjectTag tag;
VMObjectCell(VMObjectTag tag) : tag(tag) {}
VMObjectCell() {}
virtual ~VMObjectCell() {}
};

struct VMTensorCell : public VMObjectCell {
tvm::runtime::NDArray data;
VMTensorCell(const tvm::runtime::NDArray& data)
: VMObjectCell(VMObjectTag::kTensor), data(data) {}
};

struct VMObject {
std::shared_ptr<VMObjectCell> ptr;
VMObject(std::shared_ptr<VMObjectCell> ptr) : ptr(ptr) {}
VMObject() : ptr() {}
VMObject(const VMObject& obj) : ptr(obj.ptr) {}
VMObjectCell* operator->() {
return this->ptr.operator->();
}
};

struct VMDatatypeCell : public VMObjectCell {
size_t tag;
std::vector<VMObject> fields;

VMDatatypeCell(size_t tag, const std::vector<VMObject>& fields)
: VMObjectCell(VMObjectTag::kDatatype), tag(tag), fields(fields) {}
};

struct VMClosureCell : public VMObjectCell {
size_t func_index;
std::vector<VMObject> free_vars;

VMClosureCell(size_t func_index, const std::vector<VMObject>& free_vars)
: VMObjectCell(VMObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {}
};


inline VMObject VMTensor(const tvm::runtime::NDArray& data) {
auto ptr = std::make_shared<VMTensorCell>(data);
return std::dynamic_pointer_cast<VMObjectCell>(ptr);
}

inline VMObject VMDatatype(size_t tag, const std::vector<VMObject>& fields) {
auto ptr = std::make_shared<VMDatatypeCell>(tag, fields);
return std::dynamic_pointer_cast<VMObjectCell>(ptr);
}

inline VMObject VMTuple(const std::vector<VMObject>& fields) {
return VMDatatype(0, fields);
}

inline VMObject VMClosure(size_t func_index, std::vector<VMObject> free_vars) {
auto ptr = std::make_shared<VMClosureCell>(func_index, free_vars);
return std::dynamic_pointer_cast<VMObjectCell>(ptr);
}

inline NDArray ToNDArray(const VMObject& obj) {
CHECK(obj.ptr.get());
CHECK(obj.ptr->tag == VMObjectTag::kTensor) << "Expect Tensor, Got " << VMObjectTagString(obj.ptr->tag);
std::shared_ptr<VMTensorCell> o = std::dynamic_pointer_cast<VMTensorCell>(obj.ptr);
return o->data;
}
using namespace tvm::runtime;

enum struct Opcode {
Push,
Expand Down Expand Up @@ -235,8 +146,8 @@ struct VirtualMachine {
std::vector<PackedFunc> packed_funcs;
std::vector<VMFunction> functions;
std::vector<VMFrame> frames;
std::vector<VMObject> stack;
std::vector<VMObject> constants;
std::vector<Object> stack;
std::vector<Object> constants;

// Frame State
size_t func_index;
Expand All @@ -255,8 +166,8 @@ struct VirtualMachine {
void InvokeGlobal(const VMFunction& func);
void Run();

VMObject Invoke(const VMFunction& func, const std::vector<VMObject>& args);
VMObject Invoke(const GlobalVar& global, const std::vector<VMObject>& args);
Object Invoke(const VMFunction& func, const std::vector<Object>& args);
Object Invoke(const GlobalVar& global, const std::vector<Object>& args);

// Ignore the method that dumps register info at compile-time if debugging
// mode is not enabled.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ typedef enum {
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kVMObject = 14U,
kObject = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/ndarray.h
* \brief Abstract device memory management API
* \brief A device-indpendent managed NDArray abstraction.
*/
#ifndef TVM_RUNTIME_NDARRAY_H_
#define TVM_RUNTIME_NDARRAY_H_
Expand Down
79 changes: 79 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/runtime/object.h
* \brief A managed object in the TVM runtime.
*/
#ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_

#include <tvm/runtime/ndarray.h>

namespace tvm {
namespace runtime {

enum struct ObjectTag {
kTensor,
kClosure,
kDatatype,
kExternalFunc
};

std::ostream& operator<<(std::ostream& os, const ObjectTag&);

// TODO(@jroesch): Use intrusive pointer.
struct ObjectCell {
ObjectTag tag;
ObjectCell(ObjectTag tag) : tag(tag) {}
ObjectCell() {}
virtual ~ObjectCell() {}
};

/*!
* \brief A managed object in the TVM runtime.
*
* For example a tuple, list, closure, and so on.
*
* Maintains a reference count for the object.
*/
class Object {
public:
std::shared_ptr<ObjectCell> ptr;
Object(std::shared_ptr<ObjectCell> ptr) : ptr(ptr) {}
Object() : ptr() {}
Object(const Object& obj) : ptr(obj.ptr) {}
ObjectCell* operator->() {
return this->ptr.operator->();
}
};

struct TensorCell : public ObjectCell {
NDArray data;
TensorCell(const NDArray& data)
: ObjectCell(ObjectTag::kTensor), data(data) {}
};

struct DatatypeCell : public ObjectCell {
size_t tag;
std::vector<Object> fields;

DatatypeCell(size_t tag, const std::vector<Object>& fields)
: ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {}
};

struct ClosureCell : public ObjectCell {
size_t func_index;
std::vector<Object> free_vars;

ClosureCell(size_t func_index, const std::vector<Object>& free_vars)
: ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {}
};

Object TensorObj(const NDArray& data);
Object DatatypeObj(size_t tag, const std::vector<Object>& fields);
Object TupleObj(const std::vector<Object>& fields);
Object ClosureObj(size_t func_index, std::vector<Object> free_vars);
NDArray ToNDArray(const Object& obj);

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_
17 changes: 6 additions & 11 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "c_runtime_api.h"
#include "module.h"
#include "ndarray.h"
#include "object.h"
#include "node_base.h"

namespace HalideIR {
Expand All @@ -40,12 +41,6 @@ namespace tvm {
// forward declarations
class Integer;

namespace relay {
namespace vm {
struct VMObject;
}
}

namespace runtime {
// forward declarations
class TVMArgs;
Expand Down Expand Up @@ -589,7 +584,7 @@ class TVMArgValue : public TVMPODValue_ {
inline operator tvm::Integer() const;
// get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr();
operator relay::vm::VMObject() const;
operator runtime::Object() const;
};

/*!
Expand Down Expand Up @@ -724,7 +719,7 @@ class TVMRetValue : public TVMPODValue_ {
return *this;
}

TVMRetValue& operator=(relay::vm::VMObject other);
TVMRetValue& operator=(runtime::Object other);

TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
Expand Down Expand Up @@ -821,7 +816,7 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kVMObject: {
case kObject: {
throw dmlc::Error("here");
}
default: {
Expand Down Expand Up @@ -871,7 +866,7 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
// case kModuleHandle: delete ptr<relay::vm::VMObject>(); break;
// case kModuleHandle: delete ptr<runtime::Object>(); break;
}
if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY
Expand Down Expand Up @@ -901,7 +896,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kVMObject: return "VMObject";
case kObject: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def __init__(self, mod):
self.define_list_map()
self.define_list_foldl()
self.define_list_foldr()
# self.define_list_concat()
self.define_list_concat()
self.define_list_filter()
self.define_list_zip()
self.define_list_rev()
Expand All @@ -489,9 +489,10 @@ def __init__(self, mod):
self.define_nat_add()
self.define_list_length()
self.define_list_nth()
self.define_list_update()
self.define_list_sum()
self.define_tree_adt()

self.define_tree_adt()
self.define_tree_map()
self.define_tree_size()

Expand Down
5 changes: 2 additions & 3 deletions src/api/dsl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <dmlc/thread_local.h>
#include <tvm/api_registry.h>
#include <tvm/attrs.h>
#include <tvm/relay/vm/vm.h>
#include <vector>
#include <string>
#include <exception>
Expand Down Expand Up @@ -74,7 +73,7 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object = true;
}
}
void Visit(const char* key, relay::vm::VMObject* value) final {
void Visit(const char* key, runtime::Object* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
Expand Down Expand Up @@ -115,7 +114,7 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
void Visit(const char* key, relay::vm::VMObject* value) final {
void Visit(const char* key, runtime::Object* value) final {
names->push_back(key);
}
};
Expand Down
Loading

0 comments on commit 02c0823

Please sign in to comment.