Skip to content

Commit b16a64d

Browse files
authored
[MetaSchedule] Refactor ScheduleRule Attributes (#13195)
1 parent ce777fd commit b16a64d

File tree

61 files changed

+1966
-1852
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+1966
-1852
lines changed

include/tvm/meta_schedule/schedule/cpu/.gitignore

Whitespace-only changes.

src/meta_schedule/schedule_rule/auto_bind.h renamed to include/tvm/meta_schedule/schedule/cuda/thread_bind.h

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,37 +16,53 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_
20-
#define TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_
19+
#ifndef TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
20+
#define TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
2121

22-
#include "../utils.h"
22+
#include <tvm/tir/schedule/schedule.h>
23+
24+
#include <algorithm>
25+
#include <limits>
26+
#include <utility>
2327

2428
namespace tvm {
2529
namespace meta_schedule {
2630

2731
/*!
28-
* \brief Bind the given block if it is not bound to blockIdx or threadIdx.
32+
* \brief Given candidates of thread_extents, make a sampler that use `sch->SampleCategorical`
33+
* to return a random thread extent.
34+
* \param sch The schedule
35+
* \param thread_extents The candidate thread extents.
36+
* \return A sampler that returns a random thread extent.
37+
*/
38+
std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
39+
Array<Integer> thread_extents);
40+
41+
/*!
42+
* \brief Bind blockIdx.x and threadIdx.x to the given loop
2943
* \param sch The schedule.
30-
* \param block The block to be bound.
44+
* \param loop The loop to be bound.
3145
* \param max_threadblocks The maximum number of threadblocks allowed.
32-
* \param max_threads The maximum number of threads allowed.
46+
* \param max_threads_per_block The maximum number of threads allowed.
3347
* \param get_factor A function that returns the tiling factor.
3448
*/
35-
void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block,
36-
int64_t max_threadblocks, int64_t max_threads_per_block,
37-
std::function<tir::ExprRV(int64_t max_extent)> get_factor);
49+
Array<tir::LoopRV> BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, //
50+
int64_t max_threadblocks, int64_t max_threads_per_block,
51+
std::function<tir::ExprRV(int64_t)> get_factor = nullptr);
3852

3953
/*!
40-
* \brief Given candidates of thread_extents, make a sampler that use `sch->SampleCategorical`
41-
* to return a random thread extent.
42-
* \param sch The schedule
43-
* \param thread_extents The candidate thread extents.
44-
* \return A sampler that returns a random thread extent.
54+
* \brief Bind the given block if it is not bound to blockIdx or threadIdx.
55+
* \param sch The schedule.
56+
* \param block The block to be bound.
57+
* \param max_threadblocks The maximum number of threadblocks allowed.
58+
* \param max_threads_per_block The maximum number of threads allowed.
59+
* \param get_factor A function that returns the tiling factor.
4560
*/
46-
std::function<tir::ExprRV(int64_t max_extent)> MakeFactorSampler(tir::Schedule sch,
47-
Array<Integer> thread_extents);
61+
void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block, //
62+
int64_t max_threadblocks, int64_t max_threads_per_block,
63+
std::function<tir::ExprRV(int64_t max_extent)> get_factor = nullptr);
4864

4965
} // namespace meta_schedule
5066
} // namespace tvm
5167

52-
#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_
68+
#endif // TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_
20+
#define TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_
21+
22+
#include <tvm/tir/schedule/schedule.h>
23+
24+
namespace tvm {
25+
namespace meta_schedule {
26+
27+
/*!
28+
* \brief Get the producer block of a given block.
29+
* If there is a constant winograd transform matrix, inline it.
30+
* \return The only producer block.
31+
*/
32+
tir::BlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::BlockRV block);
33+
34+
} // namespace meta_schedule
35+
} // namespace tvm
36+
37+
#endif // TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_

include/tvm/meta_schedule/schedule/x86/.gitignore

Whitespace-only changes.

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,14 @@ class ScheduleRule : public runtime::ObjectRef {
9999
* \return The cloned schedule rule.
100100
*/
101101
using FClone = runtime::TypedPackedFunc<ScheduleRule()>;
102+
/*!
103+
* \brief Create a rule that applies customized rules registered using block attribute
104+
* `schedule_rule`. The rule will be dispatched according to target keys.
105+
* \return The created schedule rule.
106+
*/
107+
TVM_DLL static ScheduleRule ApplyCustomRule();
108+
/*! \brief Check if the rule is `ApplyCustomRule` */
109+
TVM_DLL static bool IsApplyCustomRule(const ScheduleRule& rule);
102110
/*!
103111
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
104112
* \param into_producer If allows to inline a block into its producer

python/tvm/meta_schedule/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
postproc,
2727
relay_integration,
2828
runner,
29+
schedule,
2930
schedule_rule,
3031
search_strategy,
3132
space_generator,
@@ -41,10 +42,7 @@
4142
from .mutator import Mutator
4243
from .postproc import Postproc
4344
from .profiler import Profiler
44-
from .relay_integration import (
45-
is_meta_schedule_dispatch_enabled,
46-
is_meta_schedule_enabled,
47-
)
45+
from .relay_integration import is_meta_schedule_enabled
4846
from .runner import Runner
4947
from .schedule_rule import ScheduleRule
5048
from .search_strategy import MeasureCandidate, SearchStrategy

python/tvm/meta_schedule/relay_integration.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def compile_relay(
377377
mod, target, params, pass_config, executor = _normalize_params(
378378
mod, target, params, pass_config, executor
379379
)
380-
pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", target.kind.name != "cuda")
380+
pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", True)
381381
with Profiler.timeit("PostTuningCompilation"):
382382
with target, _autotvm_silencer(), database:
383383
with transform.PassContext(
@@ -404,17 +404,3 @@ def is_meta_schedule_enabled() -> bool:
404404
"relay.backend.use_meta_schedule",
405405
False,
406406
)
407-
408-
409-
def is_meta_schedule_dispatch_enabled() -> bool:
410-
"""Return whether the meta-schedule dispatch is enabled.
411-
412-
Returns
413-
-------
414-
enabled: bool
415-
Whether the meta schedule is enabled
416-
"""
417-
return transform.PassContext.current().config.get(
418-
"relay.backend.use_meta_schedule_dispatch",
419-
False,
420-
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Per-block schedule rules in MetaSchedule"""
18+
from . import cpu, cuda, generic, x86
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Per-block schedule rules in MetaSchedule for target key 'cpu'"""
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Per-block schedule rules in MetaSchedule for target key 'cuda'"""

0 commit comments

Comments
 (0)