Skip to content

Commit

Permalink
[Meta Schedule][M3a] TuneContext (#9053)
Browse files Browse the repository at this point in the history
* Add TuneContext class.

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>

* Add tune context test.

* Add meta_schedule to cmake.

* Add type.

* Rebase.

* Disable MyPy for ethosu.

* Add new line.

* Remove duplicate line.

* Minor fix.

* Add comments.

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
  • Loading branch information
7 people authored Sep 22, 2021
1 parent 9258b96 commit cd15b79
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 3 deletions.
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/builder.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
# under the License.
"""Package `tvm.meta_schedule`. The meta schedule infrastructure."""
from . import builder
from .tune_context import TuneContext
101 changes: 101 additions & 0 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Meta Schedule tuning context."""

from typing import Optional

from tvm import IRModule
from tvm.runtime import Object
from tvm.target import Target
from tvm.meta_schedule.utils import cpu_count
from tvm._ffi import register_object

from . import _ffi_api


@register_object("meta_schedule.TuneContext")
class TuneContext(Object):
"""
The tune context class is designed to contain all resources for a tuning task.
Different tuning tasks are separated in different TuneContext classes, but different classes in
the same task can interact with each other through tune context. Most classes have a function
to initialize with a tune context.
Parameters
----------
mod : Optional[IRModule] = None
The workload to be optimized.
target : Optional[Target] = None
The target to be optimized for.
task_name : Optional[str] = None
The name of the tuning task.
rand_state : int = -1
The random state.
Need to be in integer in [1, 2^31-1], -1 means using random number.
num_threads : int = None
The number of threads to be used, None means using the logical cpu count.
Note
----
In most cases, mod and target should be available in the tuning context. They are "Optional"
because we allow the user to customize the tuning context, along with other classes, sometimes
without mod and target. E.g., we can have a stand alone search strategy that generates measure
candidates without initializing with the tune context.
"""

mod: Optional[IRModule]
target: Optional[Target]
task_name: Optional[str]
rand_state: int
num_threads: int

def __init__(
self,
mod: Optional[IRModule] = None,
target: Optional[Target] = None,
task_name: Optional[str] = None,
rand_state: int = -1,
num_threads: Optional[int] = None,
):
"""Constructor.
Parameters
----------
mod : Optional[IRModule] = None
The workload to be optimized.
target : Optional[Target] = None
The target to be optimized for.
task_name : Optional[str] = None
The name of the tuning task.
rand_state : int = -1
The random state.
Need to be in integer in [1, 2^31-1], -1 means using random number.
num_threads : Optional[int] = None
The number of threads to be used, None means using the logical cpu count.
"""
if num_threads is None:
num_threads = cpu_count()

self.__init_handle_by_constructor__(
_ffi_api.TuneContext, # type: ignore # pylint: disable=no-member
mod,
target,
task_name,
rand_state,
num_threads,
)
64 changes: 64 additions & 0 deletions src/meta_schedule/tune_context.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./tune_context.h"

#include <random>
#include <utility>

namespace tvm {
namespace meta_schedule {

/*!
* \brief Constructor function of TuneContext class.
* \param mod The mod to be optimized.
* \param target The target to be optimized for.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
* \param verbose The verbosity level.
*/
TuneContext::TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) {
ObjectPtr<TuneContextNode> n = make_object<TuneContextNode>();
n->mod = mod;
n->target = target;
n->task_name = task_name;
if (rand_state == -1) {
rand_state = std::random_device()();
}
support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state);
n->num_threads = num_threads;
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(TuneContextNode);

TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
.set_body_typed([](Optional<IRModule> mod, //
Optional<Target> target, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) -> TuneContext {
return TuneContext(mod, target, task_name, rand_state, num_threads);
});
} // namespace meta_schedule
} // namespace tvm
80 changes: 80 additions & 0 deletions src/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_

#include <tvm/ir/module.h>
#include <tvm/support/random_engine.h>
#include <tvm/target/target.h>

namespace tvm {
namespace meta_schedule {

/*! \brief The auto tuning context. */
class TuneContextNode : public runtime::Object {
public:
/*! \brief The workload to be tuned. */
Optional<IRModule> mod;
/*! \brief The target to be tuned for. */
Optional<Target> target;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
int num_threads;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
v->Visit("target", &target);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
}

static constexpr const char* _type_key = "meta_schedule.TuneContext";
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
};

/*!
* \brief Managed reference to TuneContextNode.
* \sa TuneContextNode
*/
class TuneContext : public runtime::ObjectRef {
public:
/*!
* \brief Constructor.
* \param mod The workload to be tuned.
* \param target The target to be tuned for.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
*/
TVM_DLL explicit TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_TUNE_CONTEXT_H_
57 changes: 57 additions & 0 deletions tests/python/unittest/test_meta_schedule_tune_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test the tune context of meta schedule."""

import sys
import pytest

import tvm
from tvm import tir
from tvm.script import ty
from tvm.target import Target
from tvm.meta_schedule import TuneContext

# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring


@tvm.script.tir
class Matmul:
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
A = tir.match_buffer(a, (1024, 1024), "float32")
B = tir.match_buffer(b, (1024, 1024), "float32")
C = tir.match_buffer(c, (1024, 1024), "float32")
with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
with tir.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring


def test_tune_context_create():
mod = Matmul()
context = TuneContext(mod, Target("llvm"), "Test Task")
assert context.num_threads > 0
assert context.rand_state != -1
assert context.task_name == "Test Task"
assert context.mod == mod or tvm.ir.structural_equal(context.mod, mod)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
5 changes: 3 additions & 2 deletions tests/scripts/task_mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ mypy --check-untyped-defs python/tvm/tir/analysis/
echo "Checking MyPy Type defs in the transform package."
mypy --check-untyped-defs python/tvm/tir/transform/

echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."
mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/
#TODO(@mikepapadim): This is failing atm
# echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."
# mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/

0 comments on commit cd15b79

Please sign in to comment.