Skip to content

Commit 33082e0

Browse files
authored
[runtime] Add Metadata classes for AOTExecutor (#10282)
* Add new Metadata classes and base implementation. * These were autogenerated in the original PR, but checking them in as plain code until we can revisit the auto-generator approach. * address masa comments * Add documentation per Manupa's comments, and move kMetadataVersion namespace. * remove get_name function, used for debugging * clang-format
1 parent 91b2e91 commit 33082e0

File tree

7 files changed

+973
-0
lines changed

7 files changed

+973
-0
lines changed

include/tvm/runtime/metadata.h

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/runtime/metadata.h
22+
* \brief Defines types which can be used in Metadata.
23+
*/
24+
#ifndef TVM_RUNTIME_METADATA_H_
25+
#define TVM_RUNTIME_METADATA_H_
26+
27+
#include <inttypes.h>
28+
#ifdef __cplusplus
29+
#include <memory>
30+
#include <string>
31+
#include <vector>
32+
#endif
33+
#include <tvm/runtime/c_runtime_api.h>
34+
#ifdef __cplusplus
35+
#include <tvm/runtime/metadata_base.h>
36+
#endif
37+
#include <tvm/support/span.h>
38+
39+
// Version number recorded in emitted artifacts for runtime checking.
40+
#define TVM_METADATA_VERSION 1
41+
42+
namespace tvm {
43+
namespace runtime {
44+
namespace metadata {
45+
/*!
46+
* \brief Version of metadata emitted and understood by this compiler/runtime.
47+
* Should be populated into the `version` field of all TVMMetadata.
48+
*/
49+
static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION;
50+
} // namespace metadata
51+
} // namespace runtime
52+
} // namespace tvm
53+
54+
#ifdef __cplusplus
55+
extern "C" {
56+
#endif
57+
58+
/*!
59+
* \brief Top-level metadata structure. Holds all other metadata types.
60+
*/
61+
struct TVMMetadata {
62+
/*! \brief Version identifier for this metadata. */
63+
int64_t version;
64+
/*! \brief Inputs to the AOT run_model function.
65+
* The order of the elements is the same as in the arguments to run_model. That is to say,
66+
* this array specifies the first `num_inputs` arguments to run_model.
67+
*/
68+
const struct TVMTensorInfo* inputs;
69+
/*! \brief Number of elements in `inputs` array. */
70+
int64_t num_inputs;
71+
/*! \brief Outputs of the AOT run_model function.
72+
* The order of the elements is the same as in the arguments to run_model. That is to say,
73+
* this array specifies the last `num_outputs` arguments to run_model.
74+
*/
75+
const struct TVMTensorInfo* outputs;
76+
/*! \brief Number of elements in `outputs` array. */
77+
int64_t num_outputs;
78+
/*! \brief Name of the model, as passed to tvm.relay.build. */
79+
const char* mod_name;
80+
};
81+
82+
/*!
83+
* \brief Describes one tensor argument to `run_model`.
84+
* NOTE: while TIR allows for other types of arguments, such as scalars, the AOT run_model
85+
* function does not currently accept these. Therefore it's not possible to express those
86+
* in this metadata. A future patch may modify this.
87+
*/
88+
struct TVMTensorInfo {
89+
/*! \brief Name of the tensor, as specified in the Relay program. */
90+
const char* name;
91+
/*! \brief Shape of the tensor. */
92+
const int64_t* shape;
93+
/*! \brief Rank of this tensor. */
94+
int64_t num_shape;
95+
/*! \brief Data type of one element of this tensor. */
96+
DLDataType dtype;
97+
};
98+
#ifdef __cplusplus
99+
} // extern "C"
100+
#include <tvm/runtime/object.h>
101+
namespace tvm {
102+
namespace runtime {
103+
namespace metadata {
104+
105+
class Metadata;
106+
class TensorInfo;
107+
108+
class MetadataNode : public MetadataBaseNode {
109+
public:
110+
explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {}
111+
static constexpr const char* _type_key = "metadata.MetadataNode";
112+
inline int64_t version() const { return int64_t(data_->version); }
113+
inline int64_t num_inputs() const { return data_->num_inputs; }
114+
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
115+
inline int64_t num_outputs() const { return data_->num_outputs; }
116+
ArrayAccessor<struct TVMTensorInfo, TensorInfo> outputs();
117+
inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); }
118+
const struct ::TVMMetadata* data() const { return data_; }
119+
TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode);
120+
121+
private:
122+
const struct ::TVMMetadata* data_;
123+
};
124+
125+
class Metadata : public MetadataBase {
126+
public:
127+
explicit Metadata(const struct ::TVMMetadata* data);
128+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode);
129+
};
130+
131+
class TensorInfoNode : public MetadataBaseNode {
132+
public:
133+
explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {}
134+
static constexpr const char* _type_key = "metadata.TensorInfoNode";
135+
inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); }
136+
inline int64_t num_shape() const { return data_->num_shape; }
137+
inline ::tvm::support::Span<const int64_t, int64_t> shape() const {
138+
return ::tvm::support::Span<const int64_t, int64_t>(data_->shape,
139+
data_->shape + data_->num_shape);
140+
}
141+
inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); }
142+
const struct ::TVMTensorInfo* data() const { return data_; }
143+
TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode);
144+
145+
private:
146+
const struct ::TVMTensorInfo* data_;
147+
};
148+
149+
class TensorInfo : public MetadataBase {
150+
public:
151+
explicit TensorInfo(const struct ::TVMTensorInfo* data);
152+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode);
153+
};
154+
155+
} // namespace metadata
156+
} // namespace runtime
157+
} // namespace tvm
158+
#endif // defined(__cplusplus)
159+
160+
#endif // TVM_RUNTIME_METADATA_H_
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/runtime/metadata_base.h
22+
* \brief Defines types which can be used in Metadata.
23+
*/
24+
#ifndef TVM_RUNTIME_METADATA_BASE_H_
25+
#define TVM_RUNTIME_METADATA_BASE_H_
26+
27+
#include <tvm/ir/expr.h>
28+
#include <tvm/runtime/object.h>
29+
30+
#include <memory>
31+
#include <string>
32+
#include <utility>
33+
#include <vector>
34+
35+
namespace tvm {
36+
namespace runtime {
37+
namespace metadata {
38+
39+
/*!
40+
* \brief Common base class for all Metadata.
41+
*
42+
* This class is used in the visitor classes as a internal check to ensure that verify that all
43+
* parts of the Metadata struct used in codegen are Metadata objects.
44+
*/
45+
class MetadataBaseNode : public ::tvm::runtime::Object {
46+
public:
47+
static constexpr const char* _type_key = "metadata.MetadataBaseNode";
48+
TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
49+
};
50+
51+
/*! \brief Reference class for the common MetadataBaseNode class. */
52+
class MetadataBase : public ::tvm::runtime::ObjectRef {
53+
public:
54+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode);
55+
};
56+
57+
template <typename C, class Ref>
58+
class ArrayAccessor;
59+
60+
/*! \brief An iterator implementation that lazily instantiates the C++ wrapping Metadata class. */
61+
template <typename C, class Ref>
62+
class ArrayIterator {
63+
public:
64+
ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent)
65+
: index_{index}, parent_{parent} {}
66+
67+
inline Ref operator*() { return (*parent_)[index_]; }
68+
69+
inline ArrayIterator<C, Ref>& operator++() {
70+
if (index_ < parent_->size()) {
71+
index_++;
72+
}
73+
74+
return *this;
75+
}
76+
77+
inline bool operator==(const ArrayIterator<C, Ref>& other) const {
78+
return parent_ == other.parent_ && index_ == other.index_;
79+
}
80+
81+
inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); }
82+
83+
private:
84+
size_t index_;
85+
const ArrayAccessor<C, Ref>* parent_;
86+
};
87+
88+
/*! \brief A span-like class which permits access to Array fields with complex elements.
89+
* These array fields should be accessed from C++ using the Metadata wrapper classes. This class
90+
* lazily instantiates those wrappers as they are accessed.
91+
*/
92+
template <typename C, class Ref>
93+
class ArrayAccessor {
94+
public:
95+
using value_type = Ref;
96+
using iterator = ArrayIterator<C, Ref>;
97+
using const_iterator = iterator;
98+
99+
template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type>
100+
ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {}
101+
102+
inline size_t size() const { return num_data_; }
103+
104+
inline Ref operator[](size_t index) const {
105+
if (index >= num_data_) {
106+
throw std::runtime_error("Index out of range");
107+
}
108+
109+
return Ref(&data_[index]);
110+
}
111+
112+
inline ArrayIterator<C, Ref> begin() const { return ArrayIterator<C, Ref>{0, this}; }
113+
114+
inline ArrayIterator<C, Ref> end() const { return ArrayIterator<C, Ref>{num_data_, this}; }
115+
116+
private:
117+
const C* data_;
118+
size_t num_data_;
119+
};
120+
121+
/*! \brief A specialization of ArrayAccessor for String.
122+
* This class is needed because the String constructor signature is different from the typical
123+
* Metadata subclass.
124+
*/
125+
template <>
126+
class ArrayAccessor<const char*, ::tvm::runtime::String> {
127+
public:
128+
using value_type = ::tvm::runtime::String;
129+
using iterator = ArrayIterator<const char*, ::tvm::runtime::String>;
130+
using const_iterator = iterator;
131+
132+
ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {}
133+
134+
inline size_t size() const { return num_data_; }
135+
136+
inline ::tvm::runtime::String operator[](size_t index) const {
137+
if (index >= num_data_) {
138+
throw std::runtime_error("Index out of range");
139+
}
140+
return ::tvm::runtime::String(data_[index]);
141+
}
142+
143+
inline ArrayIterator<const char*, ::tvm::runtime::String> begin() const {
144+
return ArrayIterator<const char*, ::tvm::runtime::String>{0, this};
145+
}
146+
147+
inline ArrayIterator<const char*, ::tvm::runtime::String> end() const {
148+
return ArrayIterator<const char*, ::tvm::runtime::String>{num_data_, this};
149+
}
150+
151+
private:
152+
const char** data_;
153+
size_t num_data_;
154+
};
155+
156+
/*! \brief Enumerates the primitive types which can be part of a Metadata instance.
157+
*
158+
* These are separate from TIR DataType because TIR does not model structs.
159+
*/
160+
enum MetadataTypeIndex : uint8_t {
161+
kUint64 = 0,
162+
kInt64 = 1,
163+
kBool = 2,
164+
kString = 3,
165+
kHandle = 4,
166+
kMetadata = 5,
167+
};
168+
169+
/*! \brief Container for arrays in the metadata.
170+
*
171+
* Type information is needed when emitting arrays. This container augments the data field with
172+
* the necessary typing information.
173+
*/
174+
class MetadataArrayNode : public MetadataBaseNode {
175+
public:
176+
MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name)
177+
: array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {}
178+
179+
Array<ObjectRef> array;
180+
MetadataTypeIndex type_index;
181+
const char* struct_name;
182+
static constexpr const char* _type_key = "metadata.MetadataArrayNode";
183+
TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
184+
};
185+
186+
/*! \brief Reference class for MetadataArray. */
187+
class MetadataArray : public MetadataBase {
188+
public:
189+
MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name);
190+
191+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
192+
};
193+
194+
} // namespace metadata
195+
} // namespace runtime
196+
} // namespace tvm
197+
198+
#endif // TVM_RUNTIME_METADATA_BASE_H_

0 commit comments

Comments
 (0)