
Commit 12839e6

merrymercy authored and tqchen committed
[AUTOTVM] Decouple build and run in measurement (#1661)
1 parent 38203a8 commit 12839e6

16 files changed: +880 -760 lines changed


docs/api/python/autotvm.rst

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,11 @@ tvm.autotvm.measure
 
 .. autofunction:: tvm.autotvm.measure.create_measure_batch
 
+.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder
+
+.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner
+
+.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner
 
 tvm.autotvm.tuner
 ~~~~~~~~~~~~~~~~~

python/tvm/autotvm/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@
 from . import tophub
 
 # some shortcuts
-from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
+from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
+    LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
 from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
     register_topi_compute, register_topi_schedule, \
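
With these shortcuts re-exported at the package level, tuning scripts can reach the new measurement classes directly from tvm.autotvm. A minimal sketch of what the re-export provides (the assert and the timeout value are purely illustrative; timeout=10 is simply the Builder default introduced in this commit):

from tvm import autotvm
from tvm.autotvm.measure.measure_methods import LocalBuilder

# the package-level shortcut and the fully qualified class are the same object
assert autotvm.LocalBuilder is LocalBuilder

# construct a builder through the shortcut; timeout is in seconds
builder = autotvm.LocalBuilder(timeout=10)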

python/tvm/autotvm/measure/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 """Distributed executor infrastructure to scale up the tuning"""
 
-from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
-from .measure_methods import request_remote, check_remote, create_measure_batch, rpc
-
+from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
+    create_measure_batch
+from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
+from .executor import Executor
 from .local_executor import LocalExecutor
-from .executor import Future, Executor

python/tvm/autotvm/measure/local_executor.py

Lines changed: 5 additions & 9 deletions
@@ -37,25 +37,21 @@ def _execute_func(func, queue, args, kwargs):
         res = exc
     queue.put(res)
 
-def timeout_monitor(queue, timeout, func, args, kwargs):
+
+def call_with_timeout(queue, timeout, func, args, kwargs):
     """A wrapper to support timeout of a function call"""
 
     # start a new process for timeout (cannot use thread because we have c function)
     p = Process(target=_execute_func, args=(func, queue, args, kwargs))
     p.start()
     p.join(timeout=timeout)
 
-    alive = p.is_alive()
+    queue.put(executor.TimeoutError())
+
     kill_child_processes(p.pid)
     p.terminate()
     p.join()
 
-    if alive:
-        queue.put(executor.TimeoutError())
-    else:
-        if queue.empty():
-            queue.put(executor.ExecutionError("Fatal error in local executor"))
-
 
 class LocalFuture(executor.Future):
     """Local wrapper for the future
@@ -134,7 +130,7 @@ def submit(self, func, *args, **kwargs):
             return LocalFutureNoFork(func(*args, **kwargs))
 
         queue = Queue(2)
-        process = Process(target=timeout_monitor,
+        process = Process(target=call_with_timeout,
                           args=(queue, self.timeout, func, args, kwargs))
         process.start()
         return LocalFuture(process, queue)
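
The rewritten helper drops the alive/queue.empty() bookkeeping: after join() it always puts a TimeoutError on the queue, so the consumer simply takes the first item, which is the child's real result when the call finished in time and the timeout marker otherwise. Below is a self-contained sketch of that pattern, not the TVM code itself: TimeoutError_, call_with_timeout_sketch, _fast and _slow are illustrative stand-ins for executor.TimeoutError and the real helpers, and kill_child_processes is omitted.

import time
from multiprocessing import Process, Queue


class TimeoutError_(Exception):
    """Stand-in for autotvm's executor.TimeoutError."""


def _run(func, queue, args):
    # child process: compute the result and push it into the queue
    queue.put(func(*args))


def call_with_timeout_sketch(timeout, func, *args):
    queue = Queue(2)
    p = Process(target=_run, args=(func, queue, args))
    p.start()
    p.join(timeout=timeout)
    # always enqueue a timeout marker after join(); if the child finished in
    # time its result is already ahead of it, so queue.get() returns that first
    queue.put(TimeoutError_())
    p.terminate()
    p.join()
    return queue.get()


def _fast():
    return 42


def _slow():
    time.sleep(5)
    return 0


if __name__ == "__main__":
    print(repr(call_with_timeout_sketch(1.0, _fast)))   # 42
    print(repr(call_with_timeout_sketch(0.1, _slow)))   # TimeoutError_()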

python/tvm/autotvm/measure/measure.py

Lines changed: 173 additions & 78 deletions
@@ -1,5 +1,6 @@
 # pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
 """User facing API for specifying how to measure the generated code"""
+import multiprocessing
 from collections import namedtuple
 
 class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
@@ -16,15 +17,16 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
         Specific configuration.
     """
 
+
 class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
     """
     Stores all the results of a measurement
 
     Parameters
     ----------
     costs: Array of float or Array of Exception
-        If no error occurs for this measurement, it is an array of measured running times.
-        If some error occurs during the measurement, it is an array of the exception objections.
+        If no error occurs during measurement, it is an array of measured running times.
+        If an error occurs during measurement, it is an array of the exception objects.
     error_no: int
         Denote error type, defined by MeasureErrorNo
     all_cost: float
@@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
 class MeasureErrorNo(object):
     """Error type for MeasureResult"""
     NO_ERROR = 0              # no error
-    INSTANTIATION_ERROR = 1   # error when calling template function
+    INSTANTIATION_ERROR = 1   # actively detected error in instantiating a template with a config
     COMPILE_HOST = 2          # error when compiling code on host (e.g. tvm.build)
-    COMPILE_DEVICE = 3        # error when compiling code on device (e.g. opencl JIT on device)
+    COMPILE_DEVICE = 3        # error when compiling code on device (e.g. OpenCL JIT on the device)
     RUNTIME_DEVICE = 4        # error when run program on device
     WRONG_ANSWER = 5          # answer is wrong when compared to a golden output
-    FLEET_ERROR = 6           # error of measure infrastructure
+    BUILD_TIMEOUT = 6         # timeout during compilation
+    RUN_TIMEOUT = 7           # timeout during run
+    UNKNOWN_ERROR = 8         # unknown error
+
 
+class Builder(object):
+    """Builder that builds programs in tuning
 
-def measure_option(measure_func,
-                   number=1,
-                   repeat=1,
-                   timeout=60,
-                   n_parallel=1,
-                   do_fork=True,
-                   build_func='default',
-                   check_correctness=False,
-                   replay_db=None):
-    """Configure how to do measurement
+    Parameters
+    ----------
+    timeout: float, optional
+        The timeout of a build task
+    n_parallel: int, optional
+        The number of tasks submitted in parallel
+        By default it will use all cpu cores
+    """
+    def __init__(self, timeout=10, n_parallel=None):
+        self.timeout = timeout
+        self.n_parallel = n_parallel or multiprocessing.cpu_count()
+        self.build_kwargs = {}
+        self.task = None
+
+    def set_task(self, task, build_kwargs=None):
+        """
+        Initialize for a new tuning task
+
+        Parameters
+        ----------
+        task: Task
+            The tuning task
+        build_kwargs: dict, optional
+            The additional kwargs for build function
+        """
+        self.task = task
+        self.build_kwargs = build_kwargs
+
+    def build(self, measure_inputs):
+        """Build programs
+
+        Parameters
+        ----------
+        measure_inputs: List of MeasureInput
+            The measure input
+
+        Returns
+        -------
+        build_results: List of BuildResult
+            The build result.
+        """
+        raise NotImplementedError()
+
+
+class Runner(object):
+    """Runner that runs and measures the time cost of a generated program in tuning
 
     Parameters
     ----------
-    measure_func: str or callable
-        'local': use the local device for measurement. The tuner will start a tracker
-        and a RPC server silently for the user.
-
-        callable: It is a callable function for measurement.
-        See the return value of measure/measure_methods.py::rpc for example.
-    number : int, optional
-        Number of times to do the measurement for average
-    repeat : int, optional
-        Number of times to repeat the measurement.
-        In total, the generated code will be run (1 + number x repeat) times,
-        where the first one is warm up. The returned result contains `repeat` costs,
-        each of which is the average of `number` test run.
-    timeout: int, optional
-        Timeout for a whole batch. TimeoutError will be returned as the result if a
-        task timeouts.
+    timeout: float, optional
+        The timeout of a build task
     n_parallel: int, optional
-        The number of measurement task that can run in parallel.
-        Set this according to the number of cpu cores (for compilation) and
-        the number of devices you have (for measuring generate code).
-    do_fork: bool, optional
-        Whether use multiprocessing (based on fork) for running measure jobs in parallel.
-        Set this to False if you want to debug (see trackback) or using fork is not suitable.
-        NOTE: If this is False, parallel and timeout do not work.
-    build_func: str or callable, optional
-        'default': call default builder. This works for normal target (llvm, cuda)
-
-        'ndk': use Android NDK to create shared library. Use this for android target.
-
-        callable: customized build function for other backends (e.g. VTA).
-        See measure/measure_methods.py::default_build_func for example.
-    check_correctness: bool, optional
-        Whether check correctness after measurement. This will use llvm cpu target to generate
-        reference output.
-    replay_db : Database, optional
-        The database that we retrieve saved MeasureResult from.
+        The number of tasks submitted in parallel
+        By default it will use all cpu cores
+    """
+    def __init__(self, timeout=5, n_parallel=None):
+        self.timeout = timeout
+        self.n_parallel = n_parallel or multiprocessing.cpu_count()
+        self.task = None
+
+    def set_task(self, task):
+        """
+        Initialize for a new tuning task
+
+        Parameters
+        ----------
+        task: Task
+            The tuning task
+        """
+        self.task = task
+
+    def get_build_kwargs(self):
+        """
+        Get device specific build arguments (e.g. maximum shared memory size)
+
+        Returns
+        ----------
+        kwargs: dict
+            The additional keyword arguments
+        """
+        raise NotImplementedError()
+
+    def run(self, measure_inputs, build_results):
+        """Run and measure built programs
+
+        Parameters
+        ----------
+        measure_inputs: List of MeasureInput
+            The raw measure input
+        build_results: List of BuildResults
+            The build results
+
+        Returns
+        -------
+        measure_results: List of MeasureResult
+            The final results of measurement
+        """
+        raise NotImplementedError()
+
+
+def measure_option(builder, runner):
+    """
+    Set options for measure. To measure a config, we will build it and run it.
+    So we have to set options for these two steps.
+    They have their own options on timeout, parallel, etc.
+
+    Parameters
+    ----------
+    builder: Builder
+        Specify how to build programs
+    runner: Runner
+        Specify how to run programs
+    """
+    from .measure_methods import LocalBuilder, LocalRunner
+
+    if isinstance(builder, str):
+        if builder == 'local':
+            builder = LocalBuilder()
+        else:
+            raise ValueError("Invalid builder: " + builder)
+
+    if isinstance(runner, str):
+        if runner == 'local':
+            runner = LocalRunner()
+        else:
+            raise ValueError("Invalid runner: " + runner)
+
+    opt = {
+        'builder': builder,
+        'runner': runner,
+    }
+
+    return opt
+
+
+def create_measure_batch(task, option):
+    """Get a standard measure_batch function.
+
+    Parameters
+    ----------
+    task: tvm.autotvm.task.Task
+        The tuning task
+    option: dict
+        The option for measuring generated code.
+        You should use the return value of function :any:`measure_option` for this argument.
 
     Returns
     -------
-    options: dict
-        A dict to store all options
-
-    Note
-    ----
-    To support customized measure, you can pass callable `measure_func` or
-    `build_func` in. The `measure_func` will call `build_func` to build binary library
-    and handle the logic of measurement.
-
-    Signature:
-    * measure_func (see the return value of measure/measure_methods.py::rpc for example)
-    def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
-        return measure_results
-
-    * build_func (see measure/measure_methods.py::default_build_func for example)
-    def build_func(inp, tmp_dir, **kwargs):
-        return func, args, filename
+    measure_batch: callable
+        a callback function to measure a batch of configs
     """
-    return {
-        'measure_func': measure_func,
-        'number': number,
-        'repeat': repeat,
-        'timeout': timeout,
-        'n_parallel': n_parallel,
-        'do_fork': do_fork,
-        'build_func': build_func,
-        'check_correctness': check_correctness,
-        'replay_db': replay_db,
-    }
+    builder = option['builder']
+    runner = option['runner']
+
+    attach_objects = runner.set_task(task)
+
+    # feed device related information from runner to builder
+    # (e.g. max shared memory for validity checking)
+    build_kwargs = runner.get_build_kwargs()
+    builder.set_task(task, build_kwargs)
+
+    def measure_batch(measure_inputs):
+        build_results = builder.build(measure_inputs)
+        results = runner.run(measure_inputs, build_results)
+        return results
+
+    measure_batch.n_parallel = builder.n_parallel
+    measure_batch.attach_objects = attach_objects
+    return measure_batch
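
With build and run decoupled, a tuning script now composes a Builder with a Runner and hands the resulting dict to a tuner. A hedged sketch of the new flow, assuming the package-level shortcuts added by this commit; the runner arguments shown (number=5, timeout=4) and the RPC device key 'rk3399' are illustrative placeholders, not values taken from this diff:

from tvm import autotvm

# build on the local machine and run on the local machine
measure_opt = autotvm.measure_option(
    builder=autotvm.LocalBuilder(timeout=10),
    runner=autotvm.LocalRunner(number=5, timeout=4))

# or build locally and run on devices registered to an RPC tracker, e.g.:
# measure_opt = autotvm.measure_option(
#     builder=autotvm.LocalBuilder(),
#     runner=autotvm.RPCRunner('rk3399', host='localhost', port=9190, number=5))

# a tuner consumes the option dict; internally it calls
# create_measure_batch(task, measure_opt), which first asks the runner for
# device-specific build_kwargs, feeds them to the builder, and then returns a
# measure_batch callable that builds and runs each batch of configs
# tuner = autotvm.tuner.XGBTuner(task)
# tuner.tune(n_trial=20, measure_option=measure_opt)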
