Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][M3a] TuneContext #9053

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/builder.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
# under the License.
"""Package `tvm.meta_schedule`. The meta schedule infrastructure."""
from . import builder
from .tune_context import TuneContext
101 changes: 101 additions & 0 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Meta Schedule tuning context."""

from typing import Optional

from tvm import IRModule
from tvm.runtime import Object
from tvm.target import Target
from tvm.meta_schedule.utils import cpu_count
from tvm._ffi import register_object

from . import _ffi_api


@register_object("meta_schedule.TuneContext")
class TuneContext(Object):
"""
The tune context class is designed to contain all resources for a tuning task.

Different tuning tasks are separated in different TuneContext classes, but different classes in
the same task can interact with each other through tune context. Most classes have a function
to initialize with a tune context.

Parameters
----------
mod : Optional[IRModule] = None
The workload to be optimized.
target : Optional[Target] = None
The target to be optimized for.
task_name : Optional[str] = None
The name of the tuning task.
rand_state : int = -1
The random state.
Need to be in integer in [1, 2^31-1], -1 means using random number.
num_threads : int = None
The number of threads to be used, None means using the logical cpu count.

Note
----
In most cases, mod and target should be available in the tuning context. They are "Optional"
because we allow the user to customize the tuning context, along with other classes, sometimes
without mod and target. E.g., we can have a stand alone search strategy that generates measure
candidates without initializing with the tune context.
"""

mod: Optional[IRModule]
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
target: Optional[Target]
task_name: Optional[str]
rand_state: int
num_threads: int

def __init__(
self,
mod: Optional[IRModule] = None,
target: Optional[Target] = None,
task_name: Optional[str] = None,
rand_state: int = -1,
num_threads: Optional[int] = None,
):
"""Constructor.

Parameters
----------
mod : Optional[IRModule] = None
The workload to be optimized.
target : Optional[Target] = None
The target to be optimized for.
task_name : Optional[str] = None
The name of the tuning task.
rand_state : int = -1
The random state.
Need to be in integer in [1, 2^31-1], -1 means using random number.
num_threads : Optional[int] = None
The number of threads to be used, None means using the logical cpu count.
"""
if num_threads is None:
num_threads = cpu_count()

self.__init_handle_by_constructor__(
_ffi_api.TuneContext, # type: ignore # pylint: disable=no-member
mod,
target,
task_name,
rand_state,
num_threads,
)
64 changes: 64 additions & 0 deletions src/meta_schedule/tune_context.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./tune_context.h"

#include <random>
#include <utility>

namespace tvm {
namespace meta_schedule {

/*!
* \brief Constructor function of TuneContext class.
* \param mod The mod to be optimized.
* \param target The target to be optimized for.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
* \param verbose The verbosity level.
*/
TuneContext::TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) {
ObjectPtr<TuneContextNode> n = make_object<TuneContextNode>();
n->mod = mod;
n->target = target;
n->task_name = task_name;
if (rand_state == -1) {
rand_state = std::random_device()();
}
support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state);
n->num_threads = num_threads;
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(TuneContextNode);

TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
.set_body_typed([](Optional<IRModule> mod, //
Optional<Target> target, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) -> TuneContext {
return TuneContext(mod, target, task_name, rand_state, num_threads);
});
} // namespace meta_schedule
} // namespace tvm
80 changes: 80 additions & 0 deletions src/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_

#include <tvm/ir/module.h>
#include <tvm/support/random_engine.h>
#include <tvm/target/target.h>

namespace tvm {
namespace meta_schedule {

/*! \brief The auto tuning context. */
class TuneContextNode : public runtime::Object {
public:
/*! \brief The workload to be tuned. */
Optional<IRModule> mod;
/*! \brief The target to be tuned for. */
Optional<Target> target;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
int num_threads;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
v->Visit("target", &target);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
}

static constexpr const char* _type_key = "meta_schedule.TuneContext";
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
};

/*!
* \brief Managed reference to TuneContextNode.
* \sa TuneContextNode
*/
class TuneContext : public runtime::ObjectRef {
public:
/*!
* \brief Constructor.
* \param mod The workload to be tuned.
* \param target The target to be tuned for.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
*/
TVM_DLL explicit TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_TUNE_CONTEXT_H_
57 changes: 57 additions & 0 deletions tests/python/unittest/test_meta_schedule_tune_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test the tune context of meta schedule."""

import sys
import pytest

import tvm
from tvm import tir
from tvm.script import ty
from tvm.target import Target
from tvm.meta_schedule import TuneContext

# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring


@tvm.script.tir
class Matmul:
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
A = tir.match_buffer(a, (1024, 1024), "float32")
B = tir.match_buffer(b, (1024, 1024), "float32")
C = tir.match_buffer(c, (1024, 1024), "float32")
with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
with tir.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring


def test_tune_context_create():
mod = Matmul()
context = TuneContext(mod, Target("llvm"), "Test Task")
assert context.num_threads > 0
assert context.rand_state != -1
assert context.task_name == "Test Task"
assert context.mod == mod or tvm.ir.structural_equal(context.mod, mod)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
5 changes: 3 additions & 2 deletions tests/scripts/task_mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ mypy --check-untyped-defs python/tvm/tir/analysis/
echo "Checking MyPy Type defs in the transform package."
mypy --check-untyped-defs python/tvm/tir/transform/

echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."
mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/
#TODO(@mikepapadim): This is failing atm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please create an issue if things like this happens and possibly ping codeowners please, we were not aware that any of the checked in code was not tested until recently.

cc : @mikepapadim @junrushao1994

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I thought we did, but it turned out haven't...Please make sure to report in time :-)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh actually we reported in this thread: #9050. definitely should submit it as a separate PR though

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea but that PR was closed saying not needed. Thus, I was under the impression that was never merged

# echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."
# mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/