Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][M3a] SpaceGenerator #9079

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_SPACE_GENERATOR_H_
#define TVM_META_SCHEDULE_SPACE_GENERATOR_H_

#include <tvm/ir/module.h>
#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

// Forward declaration
class TuneContext;

/*! \brief The abstract class for design space generation. */
class SpaceGeneratorNode : public Object {
public:
/*! \brief Default destructor */
virtual ~SpaceGeneratorNode() = default;

/*!
* \brief Initialize the design space generator with tuning context.
* \param tune_context The tuning context for initialization.
*/
virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0;

/*!
* \brief Generate design spaces given a module.
* \param mod The module used for design space generation.
* \return The generated design spaces, i.e., schedules.
*/
virtual Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) = 0;

static constexpr const char* _type_key = "meta_schedule.SpaceGenerator";
TVM_DECLARE_BASE_OBJECT_INFO(SpaceGeneratorNode, Object);
};

/*! \brief The design space generator with customized methods on the python-side. */
class PySpaceGeneratorNode : public SpaceGeneratorNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
* \brief The function type of `GenerateDesignSpace` method.
* \param mod The module used for design space generation.
* \return The generated design spaces, i.e., schedules.
*/
using FGenerateDesignSpace = runtime::TypedPackedFunc<Array<tir::Schedule>(const IRModule&)>;

/*! \brief The packed function to the `InitializeWithTuneContext` funcion. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `GenerateDesignSpace` function. */
FGenerateDesignSpace f_generate_design_space;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_generate_design_space` is not visited
}

void InitializeWithTuneContext(const TuneContext& tune_context) final {
f_initialize_with_tune_context(tune_context);
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
return f_generate_design_space(mod);
}

static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator";
TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode);
};

/*!
* \brief Managed reference to SpaceGeneratorNode.
* \sa SpaceGeneratorNode
*/
class SpaceGenerator : public ObjectRef {
protected:
SpaceGenerator() = default;

public:
/*!
* \brief Create a design space generator with customized methods on the python-side.
* \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`.
* \param generate_design_space_func The packed function of `GenerateDesignSpace`.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PySpaceGenerator(
PySpaceGeneratorNode::FInitializeWithTuneContext initialize_with_tune_context_func,
PySpaceGeneratorNode::FGenerateDesignSpace generate_design_space_func);

/*!
* \brief Create a design space generator that is union of multiple design space generators.
* \param space_generators An array of design space generators to be unioned.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator, void> space_generators);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_SPACE_GENERATOR_H_
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_

#include <tvm/ir/module.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/support/random_engine.h>
#include <tvm/target/target.h>

Expand All @@ -33,6 +34,8 @@ class TuneContextNode : public runtime::Object {
Optional<IRModule> mod;
/*! \brief The target to be tuned for. */
Optional<Target> target;
/*! \brief The design space generator. */
Optional<SpaceGenerator> space_generator;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
Expand All @@ -43,6 +46,7 @@ class TuneContextNode : public runtime::Object {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
v->Visit("target", &target);
v->Visit("space_generator", &space_generator);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
Expand All @@ -62,12 +66,14 @@ class TuneContext : public runtime::ObjectRef {
* \brief Constructor.
* \param mod The workload to be tuned.
* \param target The target to be tuned for.
* \param space_generator The design space generator.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
*/
TVM_DLL explicit TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

#include <tvm/meta_schedule/arg_info.h>
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/meta_schedule/tune_context.h>

#include "../src/support/array.h"
#include "../../../src/support/array.h"

namespace tvm {
namespace meta_schedule {} // namespace meta_schedule
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
"""Package `tvm.meta_schedule`. The meta schedule infrastructure."""
from . import builder
from . import arg_info
from . import space_generator
from .tune_context import TuneContext
25 changes: 25 additions & 0 deletions python/tvm/meta_schedule/space_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The tvm.meta_schedule.space_generator package.
Meta Schedule design space generators that generates design
space for generation of measure candidates.
"""

from .space_generator import SpaceGenerator, PySpaceGenerator
from .space_generator_union import SpaceGeneratorUnion
from .schedule_fn import ScheduleFn
90 changes: 90 additions & 0 deletions python/tvm/meta_schedule/space_generator/schedule_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Meta schedule design space generators that generates design
space via a schedule function.
"""
from typing import TYPE_CHECKING, Callable, List, Union

from tvm.ir import IRModule
from tvm.ir.container import Array
from tvm.tir.schedule import Schedule

from .space_generator import PySpaceGenerator

if TYPE_CHECKING:
from ..tune_context import TuneContext


class ScheduleFn(PySpaceGenerator):
"""A design space generator with design spaces specified by a schedule function."""

# Multiple cases of schedule functions supported
SCH_FN_TYPE = Union[
Callable[[IRModule], None], # No output
Callable[[IRModule], Schedule], # Single output
Callable[[IRModule], List[Schedule]], # Multiple outputs
]

def __init__(self, sch_fn: SCH_FN_TYPE):
"""Constructor.

Parameters
----------
sch_fn : SCH_FN_TYPE
The schedule function.
"""
super().__init__()
self.sch_fn = sch_fn

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
"""Initialize the design space generator with tuning context.

Parameters
----------
tune_context : TuneContext
The tuning context for initializing the design space generator.
"""

def generate_design_space(self, mod: IRModule) -> List[Schedule]:
"""Generate design spaces given a module.

Parameters
----------
mod : IRModule
The module used for design space generation.

Returns
-------
design_spaces : List[Schedule]
The generated design spaces, i.e., schedules.
"""
sch = Schedule(mod) # Make sure the schedule is traced
result = self.sch_fn(sch) # Call the schedule function
if result is None: # Case 1. No output
return [sch]
if isinstance(result, Schedule): # Case 2. Single output
return [result]
if isinstance(result, (list, tuple, Array)): # Case 3. Multiple outputs
for ret in result: # enumerate the outputs
if not isinstance(ret, Schedule):
raise TypeError(
"Wrong type of element in the list, expected Schedule got "
+ f"'{type(ret)}': {ret}"
)
return result
raise TypeError(f"Unexpected return type {type(result)}: {result}")
93 changes: 93 additions & 0 deletions python/tvm/meta_schedule/space_generator/space_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Meta Schedule design space generators that generates design
space for generation of measure candidates.
"""

from typing import TYPE_CHECKING, List

from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import Object
from tvm.tir.schedule import Schedule

from .. import _ffi_api

if TYPE_CHECKING:
from ..tune_context import TuneContext


@register_object("meta_schedule.SpaceGenerator")
class SpaceGenerator(Object):
"""The abstract design space generator interface."""

def initialize_with_tune_context(
self,
tune_context: "TuneContext",
) -> None:
"""Initialize the design space generator with tuning context.

Parameters
----------
tune_context : TuneContext
The tuning context for initializing the design space generator.
"""
_ffi_api.SpaceGeneratorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
self, tune_context
)

def generate_design_space(self, mod: IRModule) -> List[Schedule]:
"""Generate design spaces given a module.

Parameters
----------
mod : IRModule
The module used for design space generation.

Returns
-------
design_spaces : List[Schedule]
The generated design spaces, i.e., schedules.
"""
return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.PySpaceGenerator")
class PySpaceGenerator(SpaceGenerator):
"""An abstract design space generator with customized methods on the python-side."""

def __init__(self):
"""Constructor."""

def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

def f_generate_design_space(mod: IRModule) -> List[Schedule]:
return self.generate_design_space(mod)

self.__init_handle_by_constructor__(
_ffi_api.SpaceGeneratorPySpaceGenerator, # type: ignore # pylint: disable=no-member
f_initialize_with_tune_context,
f_generate_design_space,
)

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
raise NotImplementedError

def generate_design_space(self, mod: IRModule) -> List[Schedule]:
raise NotImplementedError
Loading