Skip to content

Commit a2b4a94

Browse files
zxybazhSiyuan FengspectrometerHBHjinhongyiiMasterJH5574
committed
Squashed commit
[Meta Schedule][M3c] Schedule Rules, Mutator & Postprocs (#485) [Meta Schedule][M3c] PostOrderApply (#486) Fix Post Order Apply (#490) [MetaSchedule] Relay Integration (#489) [M3c][Meta Schedule] Add Trace Correctness Test for PostOrderApply (#492) Fix replay trace. (#493) [M3c][Meta Schedule] Implement the Replay Func class. (#495) [PR] Test script for meta-schedule task extraction. Interface to load… (#494) [Meta Schedule Refactor] Get child blocks (#500) Read-at && Write-at (#497) [M3c][Meta Schedule] Measure Callbacks (#498) [Bug] Fix Infinite Loop Caused When Calling Methods Not Overrided In PyClass (#496) [MetaSchedule] Sample-Perfect-Tile (#501) Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Sunghyun Park <[email protected]>
1 parent b8fb438 commit a2b4a94

File tree

84 files changed

+5343
-246
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+5343
-246
lines changed

include/tvm/meta_schedule/builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class PyBuilderNode : public BuilderNode {
137137
}
138138

139139
Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) final {
140+
ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!";
140141
return f_build(build_inputs);
141142
}
142143

include/tvm/meta_schedule/database.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,18 +230,29 @@ class PyDatabaseNode : public DatabaseNode {
230230
// `f_size` is not visited
231231
}
232232

233-
static constexpr const char* _type_key = "meta_schedule.PyDatabase";
234-
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);
235-
236-
Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); }
233+
Workload CommitWorkload(const IRModule& mod) final {
234+
ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
235+
return f_commit_workload(mod);
236+
}
237237

238-
void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); }
238+
void CommitTuningRecord(const TuningRecord& record) final {
239+
ICHECK(f_commit_tuning_record != nullptr)
240+
<< "PyDatabase's CommitTuningRecord method not implemented!";
241+
f_commit_tuning_record(record);
242+
}
239243

240244
Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
245+
ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
241246
return f_get_top_k(workload, top_k);
242247
}
243248

244-
int64_t Size() final { return f_size(); }
249+
int64_t Size() final {
250+
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
251+
return f_size();
252+
}
253+
254+
static constexpr const char* _type_key = "meta_schedule.PyDatabase";
255+
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);
245256
};
246257

247258
/*!
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef TVM_META_SCHEDULE_INTEGRATION_H_
20+
#define TVM_META_SCHEDULE_INTEGRATION_H_
21+
22+
#include <tvm/meta_schedule/database.h>
23+
#include <tvm/support/with.h>
24+
25+
#include <unordered_set>
26+
27+
namespace tvm {
28+
namespace meta_schedule {
29+
30+
/**************** ExtractedTask ****************/
31+
32+
/*!
33+
* \brief A tuning task extracted from the high-level IR
34+
*/
35+
class ExtractedTaskNode : public runtime::Object {
36+
public:
37+
/*! \brief The name of the task extracted */
38+
String task_name;
39+
/*! \brief The high-level IR */
40+
IRModule mod;
41+
/*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */
42+
Array<IRModule> dispatched;
43+
44+
void VisitAttrs(AttrVisitor* v) {
45+
v->Visit("task_name", &task_name);
46+
v->Visit("mod", &mod);
47+
v->Visit("dispatched", &dispatched);
48+
}
49+
50+
static constexpr const char* _type_key = "meta_schedule.ExtractedTask";
51+
TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object);
52+
};
53+
54+
/*!
55+
* \brief Managed reference to ExtractedTaskNode
56+
* \sa ExtractedTaskNode
57+
*/
58+
class ExtractedTask : public runtime::ObjectRef {
59+
public:
60+
/*!
61+
* \brief Constructor. The name of the task extracted
62+
* \brief The high-level IR
63+
* \brief A list of low-level IRs that the high-level IR could potentially dispatch to
64+
*/
65+
explicit ExtractedTask(String task_name, IRModule mod, Array<IRModule> dispatched);
66+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode);
67+
};
68+
69+
/**************** IntegrationContext ****************/
70+
71+
/*!
72+
* \brief A context manager interface for the integration
73+
*/
74+
class IntegrationContextNode : public runtime::Object {
75+
public:
76+
/*! \brief Default destructor */
77+
virtual ~IntegrationContextNode() = default;
78+
/*!
79+
* \brief The entry point of the integration
80+
* \param task_name The name of the task
81+
* \param mod The high-level IR
82+
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
83+
* NullOpt means the dispatch needs to be done in the context.
84+
* \return There are different types of the output
85+
* 1) NullOpt if there is no feedback hint
86+
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
87+
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
88+
* 4) IRModule for unified dispatch
89+
*/
90+
virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
91+
Optional<Array<IRModule>> dispatched) = 0;
92+
93+
static constexpr const char* _type_key = "meta_schedule.IntegrationContext";
94+
TVM_DECLARE_BASE_OBJECT_INFO(IntegrationContextNode, runtime::Object);
95+
};
96+
97+
/*!
98+
* \brief Managed reference to IntegrationContextNode
99+
* \sa IntegrationContextNode
100+
*/
101+
class IntegrationContext : public runtime::ObjectRef {
102+
friend class IntegrationContextInternal;
103+
friend class With<IntegrationContext>;
104+
105+
public:
106+
/*! \brief Default destructor */
107+
virtual ~IntegrationContext() = default;
108+
/*!
109+
* \brief The context manager in the current scope
110+
* \return The IntegrationContext in the current scope. NullOpt if it's currently not under any
111+
* IntegrationContext.
112+
*/
113+
static Optional<IntegrationContext> Current();
114+
/*!
115+
* \brief The entry point of the integration workflow. The compilation process of the high-level
116+
* IR should call this method for task extraction and for feedback hints
117+
* \param task_name The name of the task
118+
* \param mod The high-level IR
119+
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
120+
* \return There are different types of the output
121+
* 1) NullOpt if there is no feedback hint
122+
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
123+
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
124+
* 4) IRModule for unified dispatch
125+
*/
126+
static Optional<ObjectRef> EntryPoint(runtime::String task_name, IRModule mod,
127+
Optional<Array<IRModule>> dispatched);
128+
129+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IntegrationContext, runtime::ObjectRef,
130+
IntegrationContextNode);
131+
132+
protected:
133+
/*! \brief Default constructor */
134+
IntegrationContext() = default;
135+
/*! \brief Entering the scope of the context manager */
136+
void EnterWithScope();
137+
/*! \brief Exiting the scope of the context manager */
138+
void ExitWithScope();
139+
};
140+
141+
/**************** TaskExtraction ****************/
142+
143+
/*!
144+
* \brief An integration context for task extraction
145+
*/
146+
class TaskExtractionNode : public IntegrationContextNode {
147+
public:
148+
/*! \brief The extracted tasks */
149+
Array<ExtractedTask> tasks{nullptr};
150+
151+
void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); }
152+
153+
// Inherited from base class
154+
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
155+
Optional<Array<IRModule>> dispatched) final;
156+
157+
static constexpr const char* _type_key = "meta_schedule.TaskExtraction";
158+
TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, IntegrationContextNode);
159+
};
160+
161+
/*!
162+
* \brief Managed reference to TaskExtractionNode
163+
* \sa TaskExtractionNode
164+
*/
165+
class TaskExtraction : public IntegrationContext {
166+
public:
167+
/*! \brief The path to a cache file storing extracted tasks */
168+
TaskExtraction();
169+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, IntegrationContext,
170+
TaskExtractionNode);
171+
};
172+
173+
/**************** ApplyHistoryBest ****************/
174+
175+
/*!
176+
* \brief An integration context that allows application of historically best records from a
177+
* database
178+
*/
179+
class ApplyHistoryBestNode : public IntegrationContextNode {
180+
public:
181+
/*! \brief The database to be queried from */
182+
Database database{nullptr};
183+
184+
void VisitAttrs(AttrVisitor* v) {
185+
v->Visit("database", &database); //
186+
}
187+
188+
// Inherited from base class
189+
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
190+
Optional<Array<IRModule>> dispatched) final;
191+
192+
static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
193+
TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, IntegrationContextNode);
194+
};
195+
196+
/*!
197+
* \brief Managed reference to ApplyHistoryBestNode
198+
* \sa ApplyHistoryBestNode
199+
*/
200+
class ApplyHistoryBest : public IntegrationContext {
201+
public:
202+
/*!
203+
* \brief Constructor
204+
* \param database The database to be queried from
205+
*/
206+
explicit ApplyHistoryBest(Database database);
207+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, IntegrationContext,
208+
ApplyHistoryBestNode);
209+
};
210+
211+
} // namespace meta_schedule
212+
} // namespace tvm
213+
214+
#endif // TVM_META_SCHEDULE_INTEGRATION_H_
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
21+
#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
22+
23+
#include <tvm/meta_schedule/builder.h>
24+
#include <tvm/meta_schedule/runner.h>
25+
#include <tvm/meta_schedule/search_strategy.h>
26+
#include <tvm/meta_schedule/tune_context.h>
27+
28+
namespace tvm {
29+
namespace meta_schedule {
30+
31+
class TaskScheduler;
32+
33+
/*! \brief Rules to apply after measure results is available. */
34+
class MeasureCallbackNode : public runtime::Object {
35+
public:
36+
/*! \brief Virtual destructor. */
37+
virtual ~MeasureCallbackNode() = default;
38+
39+
void VisitAttrs(tvm::AttrVisitor* v) {}
40+
41+
/*!
42+
* \brief Apply a measure callback rule with given arguments.
43+
* \param task_scheduler The task scheduler.
44+
* \param tasks The list of tune context to process.
45+
* \param measure_candidates The measure candidates.
46+
* \param builds The builder results by building the measure candidates.
47+
* \param results The runner results by running the built measure candidates.
48+
* \return Whether the measure callback was successfully applied.
49+
*/
50+
virtual bool Apply(const TaskScheduler& task_scheduler, //
51+
const Array<TuneContext> tasks, //
52+
const Array<MeasureCandidate>& measure_candidates, //
53+
const Array<BuilderResult>& builds, //
54+
const Array<RunnerResult>& results) = 0;
55+
56+
static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
57+
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
58+
};
59+
60+
/*! \brief The measure callback with customized methods on the python-side. */
61+
class PyMeasureCallbackNode : public MeasureCallbackNode {
62+
public:
63+
/*!
64+
* \brief Apply a measure callback to the given schedule.
65+
* \param task_scheduler The task scheduler.
66+
* \param tasks The list of tune context to process.
67+
* \param measure_candidates The measure candidates.
68+
* \param builds The builder results by building the measure candidates.
69+
* \param results The runner results by running the built measure candidates.
70+
* \return Whether the measure callback was successfully applied.
71+
*/
72+
using FApply =
73+
runtime::TypedPackedFunc<bool(const TaskScheduler& task_scheduler, //
74+
const Array<TuneContext> tasks, //
75+
const Array<MeasureCandidate>& measure_candidates, //
76+
const Array<BuilderResult>& builds, //
77+
const Array<RunnerResult>& results)>;
78+
/*!
79+
* \brief Get the measure callback function as string with name.
80+
* \return The string of the measure callback function.
81+
*/
82+
using FAsString = runtime::TypedPackedFunc<String()>;
83+
84+
/*! \brief The packed function to the `Apply` funcion. */
85+
FApply f_apply;
86+
/*! \brief The packed function to the `AsString` funcion. */
87+
FAsString f_as_string;
88+
89+
void VisitAttrs(tvm::AttrVisitor* v) {
90+
// `f_apply` is not visited
91+
// `f_as_string` is not visited
92+
}
93+
94+
bool Apply(const TaskScheduler& task_scheduler, //
95+
const Array<TuneContext> tasks, //
96+
const Array<MeasureCandidate>& measure_candidates, //
97+
const Array<BuilderResult>& builds, //
98+
const Array<RunnerResult>& results) final {
99+
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
100+
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
101+
}
102+
103+
static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
104+
TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
105+
};
106+
107+
/*!
108+
* \brief Managed reference to MeasureCallbackNode
109+
* \sa MeasureCallbackNode
110+
*/
111+
class MeasureCallback : public runtime::ObjectRef {
112+
public:
113+
/*!
114+
* \brief Create a measure callback with customized methods on the python-side.
115+
* \param f_apply The packed function of `Apply`.
116+
* \return The measure callback created.
117+
*/
118+
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, //
119+
PyMeasureCallbackNode::FAsString f_as_string);
120+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
121+
};
122+
123+
} // namespace meta_schedule
124+
} // namespace tvm
125+
126+
#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_

0 commit comments

Comments
 (0)