@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
2828 * \brief Tensor structure representing a possible input,
2929 * or intermediate computation result.
3030 */
31- class Tensor : public FunctionRef {
31+ class Tensor : public NodeRef {
3232 public:
3333 /* ! \brief default constructor, used internally */
3434 Tensor () {}
35- explicit Tensor (std::shared_ptr<Node> n) : FunctionRef(n) {}
36- /* !
37- * \brief constructor of input tensor
38- * \param shape Shape of the tensor.
39- * \param name optional name of the Tensor.
40- * \param dtype The data type of the input tensor.
41- */
42- explicit Tensor (Array<Expr> shape,
43- std::string name = " tensor" ,
44- Type dtype = Float(32 ));
35+ explicit Tensor (std::shared_ptr<Node> n) : NodeRef(n) {}
4536 /* !
4637 * \brief access the internal node container
4738 * \return the pointer to the internal node container
@@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
116107};
117108
118109/* ! \brief Operation that produces tensors */
119- class Operation : public NodeRef {
110+ class Operation : public FunctionRef {
120111 public:
121112 /* ! \brief default constructor */
122113 Operation () {}
123- explicit Operation (std::shared_ptr<Node> n) : NodeRef (n) {}
114+ explicit Operation (std::shared_ptr<Node> n) : FunctionRef (n) {}
124115 /* !
125116 * \brief access the internal node container
126117 * \return the pointer to the internal node container
@@ -137,12 +128,10 @@ class Operation : public NodeRef {
137128};
138129
139130/* ! \brief Node to represent a tensor */
140- class TensorNode : public FunctionBaseNode {
131+ class TensorNode : public Node {
141132 public:
142133 /* ! \brief The shape of the tensor */
143134 Array<Expr> shape;
144- /* ! \brief optional name of the tensor */
145- std::string name;
146135 /* ! \brief data type in the content of the tensor */
147136 Type dtype;
148137 /* ! \brief the source operation, can be None */
@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {
154143
155144 void VisitAttrs (AttrVisitor* v) final {
156145 v->Visit (" shape" , &shape);
157- v->Visit (" name" , &name);
158146 v->Visit (" dtype" , &dtype);
159147 v->Visit (" op" , &op);
160148 v->Visit (" value_index" , &value_index);
161149 }
162- const std::string& func_name () const final {
163- return name;
164- }
165- int outputs () const final {
166- return 1 ;
167- }
168150 static Tensor make (Array<Expr> shape,
169- std::string name,
170151 Type dtype,
171152 Operation op,
172153 int value_index);
@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
178159/* !
179160 * \brief base class of operation node.
180161 */
181- class OperationNode : public Node {
162+ class OperationNode : public FunctionBaseNode {
182163 public:
183164 /* ! \brief optional name of the operation */
184165 std::string name;
166+ /* ! \return name of the operation */
167+ const std::string& func_name () const final {
168+ return name;
169+ }
170+ /* ! \return number of outputs of this op */
171+ virtual int num_outputs () const = 0;
185172 /* ! \return the list of iteration variable at root */
186173 virtual Array<IterVar> root_iter_vars () const = 0;
187- /* ! \return number of outputs of this op */
188- virtual size_t num_outputs () const = 0;
189- /* ! \return name of i-th output */
190- virtual std::string output_name (size_t i) const = 0;
191174 /* ! \return type of i-th output */
192175 virtual Type output_dtype (size_t i) const = 0;
193176 /* ! \return shape of i-th output */
0 commit comments