6 changes: 2 additions & 4 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -137,19 +137,17 @@ class ScheduleRule : public runtime::ObjectRef {
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - NullOpt on CPU
* - [blockIdx.x, vthread.x, threadIdx.x] on GPU
* \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_max_len The length of vector lane in vectorized cooperative fetching.
* \param vector_load_lens The lengths of vector lanes used in vectorized cooperative fetching.
* NullOpt means vectorization is disabled
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
Optional<Array<String>> tile_binds, //
bool use_tensor_core, //
Optional<Integer> max_innermost_factor, //
Optional<Integer> vector_load_max_len, //
Optional<Array<Integer>> vector_load_lens, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);
/*!
14 changes: 14 additions & 0 deletions include/tvm/tir/stmt.h
@@ -1364,6 +1364,20 @@ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

/*!
* \brief Mark that the loop should be further split and bound to environment threads to enable
* cooperative fetching.
*/
constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";

/*! \brief The lower bound (inclusive) of the allowed thread extent in thread bindings */
constexpr const char* meta_schedule_thread_extent_low_inclusive =
"meta_schedule.thread_extent_low_inclusive";

/*! \brief The upper bound (inclusive) of the allowed thread extent in thread bindings */
constexpr const char* meta_schedule_thread_extent_high_inclusive =
"meta_schedule.thread_extent_high_inclusive";

/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */
constexpr const char* meta_schedule_random_compute_producer =
"meta_schedule.random_compute_producer";
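These annotation keys are plain string constants that get attached to loops and blocks so that later postprocessing can act on them. Below is a minimal illustrative sketch, assuming `tir.Schedule.annotate` is available in this TVM build; the PrimFunc, the attach points, and the values are made up for illustration and are not taken from this PR (in meta schedule the tiling rule normally inserts these annotations itself).

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def copy(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (1024,), "float32")
    B = T.match_buffer(b, (1024,), "float32")
    for i in T.serial(0, 1024):
        with T.block("B"):
            vi = T.axis.spatial(1024, i)
            B[vi] = A[vi]


sch = tvm.tir.Schedule(copy)
block = sch.get_block("B")
(loop,) = sch.get_loops(block)
# Ask a later pass to split this loop and bind it to thread axes for cooperative fetching.
sch.annotate(loop, "meta_schedule.cooperative_fetch", 1)
# Constrain the thread extents that a later thread-binding decision may choose (hypothetical values).
sch.annotate(block, "meta_schedule.thread_extent_low_inclusive", 32)
sch.annotate(block, "meta_schedule.thread_extent_high_inclusive", 1024)
```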
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -19,6 +19,7 @@
from .add_rfactor import AddRFactor
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import MultiLevelTiling, ReuseType
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
84 changes: 84 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Multi-level tiling with reuse."""
from typing import Any, Dict, List, NamedTuple, Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


class ReuseType(NamedTuple):
"""Reuse type."""

req: str
levels: List[int]
scope: str

def as_dict(self) -> Dict[str, Any]:
"""Return the dict representation of the reuse type."""
return {
"req": self.req,
"levels": self.levels,
"scope": self.scope,
}


@register_object("meta_schedule.MultiLevelTiling")
class MultiLevelTiling(ScheduleRule):
"""Multi-level tiling with reuse.

Parameters
----------
structure : str
The tiling structure. Recommended:
- 'SSRSRS' on CPU
- 'SSSRRSRS' on GPU
tile_binds : Optional[List[str]]
For each level of tiles, which thread axis it is bound to. Recommended:
- None on CPU
- [blockIdx.x, vthread.x, threadIdx.x] on GPU
max_innermost_factor : Optional[int]
The maximum size of the innermost factor. None means no limit
vector_load_lens : Optional[List[int]]
The lengths of vector lanes used in vectorized cooperative fetching.
None means vectorization is disabled
reuse_read : Optional[ReuseType]
Data reuse configuration for reading. None means no reuse.
reuse_write : Optional[ReuseType]
Data reuse configuration for writing. None means no reuse.
"""

def __init__(
self,
structure: str,
tile_binds: Optional[List[str]] = None,
max_innermost_factor: Optional[int] = None,
vector_load_lens: Optional[List[int]] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member
structure,
tile_binds,
max_innermost_factor,
vector_load_lens,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)
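As a quick orientation for the wrapper above, here is a hedged usage sketch. It only uses the two classes defined in this file and mirrors the CPU defaults that the testing helper below registers; the tuning pipeline that actually consumes the rule is outside this diff.

```python
from tvm.meta_schedule.schedule_rule import MultiLevelTiling, ReuseType

# ReuseType is a plain NamedTuple; as_dict() produces the FFI-friendly form.
write_reuse = ReuseType(req="may", levels=[1, 2], scope="global")
assert write_reuse.as_dict() == {"req": "may", "levels": [1, 2], "scope": "global"}

# CPU-flavoured rule: no thread binding, no vectorized cooperative fetching.
rule = MultiLevelTiling(
    structure="SSRSRS",
    tile_binds=None,
    max_innermost_factor=64,
    vector_load_lens=None,
    reuse_read=None,
    reuse_write=write_reuse,
)
```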
37 changes: 37 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -19,8 +19,10 @@
AddRFactor,
AutoInline,
CrossThreadReduction,
MultiLevelTiling,
ParallelizeVectorizeUnroll,
RandomComputeLocation,
ReuseType,
ScheduleRule,
)
from tvm.target import Target
@@ -65,6 +67,41 @@ def cross_thread_reduction(target: Target) -> ScheduleRule:
raise NotImplementedError(f"{target.kind.name} is not supported")


def multi_level_tiling(target: Target) -> ScheduleRule:
"""Default schedule rules for with multi-level tiling and reuse"""
if target.kind.name == "llvm":
return MultiLevelTiling(
structure="SSRSRS",
tile_binds=None,
max_innermost_factor=64,
vector_load_lens=None,
reuse_read=None,
reuse_write=ReuseType(
req="may",
levels=[1, 2],
scope="global",
),
)
if target.kind.name == "cuda":
return MultiLevelTiling(
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
max_innermost_factor=64,
vector_load_lens=[1, 2, 3, 4],
reuse_read=ReuseType(
req="must",
levels=[4],
scope="shared",
),
reuse_write=ReuseType(
req="must",
levels=[3],
scope="local",
),
)
raise NotImplementedError(f"{target.kind.name} is not supported")


def random_compute_location(target: Target) -> ScheduleRule:
"""Default schedule rules for with random-compute-location"""
if target.kind.name == "llvm":
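To close the loop, a small sketch of how these target-specific defaults could be requested. It relies only on the multi_level_tiling helper added above and tvm.target.Target; the target strings are just examples.

```python
from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling
from tvm.target import Target

cpu_rule = multi_level_tiling(Target("llvm"))  # SSRSRS, "may" write reuse in global scope
gpu_rule = multi_level_tiling(Target("cuda"))  # SSSRRSRS, "must" reuse in shared/local scopes
# Any other target kind raises NotImplementedError, per the function above.
```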