11# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
22"""User facing API for specifying how to measure the generated code"""
3+ import multiprocessing
34from collections import namedtuple
45
56class MeasureInput (namedtuple ("MeasureInput" , ["target" , "task" , "config" ])):
@@ -16,15 +17,16 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
1617 Specific configuration.
1718 """
1819
20+
1921class MeasureResult (namedtuple ("MeasureResult" , ["costs" , "error_no" , "all_cost" , "timestamp" ])):
2022 """
2123 Stores all the results of a measurement
2224
2325 Parameters
2426 ----------
2527 costs: Array of float or Array of Exception
26- If no error occurs for this measurement, it is an array of measured running times.
27- If some error occurs during the measurement, it is an array of the exception objections.
28+ If no error occurs during measurement, it is an array of measured running times.
29+ If an error occurs during measurement, it is an array of the exception objections.
2830 error_no: int
2931 Denote error type, defined by MeasureErrorNo
3032 all_cost: float
@@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
3739class MeasureErrorNo (object ):
3840 """Error type for MeasureResult"""
3941 NO_ERROR = 0 # no error
40- INSTANTIATION_ERROR = 1 # error when calling template function
42+ INSTANTIATION_ERROR = 1 # actively detected error in instantiating a template with a config
4143 COMPILE_HOST = 2 # error when compiling code on host (e.g. tvm.build)
42- COMPILE_DEVICE = 3 # error when compiling code on device (e.g. opencl JIT on device)
44+ COMPILE_DEVICE = 3 # error when compiling code on device (e.g. OpenCL JIT on the device)
4345 RUNTIME_DEVICE = 4 # error when run program on device
4446 WRONG_ANSWER = 5 # answer is wrong when compared to a golden output
45- FLEET_ERROR = 6 # error of measure infrastructure
47+ BUILD_TIMEOUT = 6 # timeout during compilation
48+ RUN_TIMEOUT = 7 # timeout during run
49+ UNKNOWN_ERROR = 8 # unknown error
50+
4651
52+ class Builder (object ):
53+ """Builder that builds programs in tuning
4754
48- def measure_option (measure_func ,
49- number = 1 ,
50- repeat = 1 ,
51- timeout = 60 ,
52- n_parallel = 1 ,
53- do_fork = True ,
54- build_func = 'default' ,
55- check_correctness = False ,
56- replay_db = None ):
57- """Configure how to do measurement
55+ Parameters
56+ ----------
57+ timeout: float, optional
58+ The timeout of a build task
59+ n_parallel: int, optional
60+ The number of tasks submitted in parallel
61+ By default it will use all cpu cores
62+ """
63+ def __init__ (self , timeout = 10 , n_parallel = None ):
64+ self .timeout = timeout
65+ self .n_parallel = n_parallel or multiprocessing .cpu_count ()
66+ self .build_kwargs = {}
67+ self .task = None
68+
69+ def set_task (self , task , build_kwargs = None ):
70+ """
71+ Initialize for a new tuning task
72+
73+ Parameters
74+ ----------
75+ task: Task
76+ The tuning task
77+ build_kwargs: dict, optional
78+ The additional kwargs for build function
79+ """
80+ self .task = task
81+ self .build_kwargs = build_kwargs
82+
83+ def build (self , measure_inputs ):
84+ """Build programs
85+
86+ Parameters
87+ ----------
88+ measure_inputs: List of MeasureInput
89+ The measure input
90+
91+ Returns
92+ -------
93+ build_results: List of BuildResult
94+ The build result.
95+ """
96+ raise NotImplementedError ()
97+
98+
99+ class Runner (object ):
100+ """Runner that runs and measures the time cost of a generated program in tuning
58101
59102 Parameters
60103 ----------
61- measure_func: str or callable
62- 'local': use the local device for measurement. The tuner will start a tracker
63- and a RPC server silently for the user.
64-
65- callable: It is a callable function for measurement.
66- See the return value of measure/measure_methods.py::rpc for example.
67- number : int, optional
68- Number of times to do the measurement for average
69- repeat : int, optional
70- Number of times to repeat the measurement.
71- In total, the generated code will be run (1 + number x repeat) times,
72- where the first one is warm up. The returned result contains `repeat` costs,
73- each of which is the average of `number` test run.
74- timeout: int, optional
75- Timeout for a whole batch. TimeoutError will be returned as the result if a
76- task timeouts.
104+ timeout: float, optional
105+ The timeout of a build task
77106 n_parallel: int, optional
78- The number of measurement task that can run in parallel.
79- Set this according to the number of cpu cores (for compilation) and
80- the number of devices you have (for measuring generate code).
81- do_fork: bool, optional
82- Whether use multiprocessing (based on fork) for running measure jobs in parallel.
83- Set this to False if you want to debug (see trackback) or using fork is not suitable.
84- NOTE: If this is False, parallel and timeout do not work.
85- build_func: str or callable, optional
86- 'default': call default builder. This works for normal target (llvm, cuda)
87-
88- 'ndk': use Android NDK to create shared library. Use this for android target.
89-
90- callable: customized build function for other backends (e.g. VTA).
91- See measure/measure_methods.py::default_build_func for example.
92- check_correctness: bool, optional
93- Whether check correctness after measurement. This will use llvm cpu target to generate
94- reference output.
95- replay_db : Database, optional
96- The database that we retrieve saved MeasureResult from.
107+ The number of tasks submitted in parallel
108+ By default it will use all cpu cores
109+ """
110+ def __init__ (self , timeout = 5 , n_parallel = None ):
111+ self .timeout = timeout
112+ self .n_parallel = n_parallel or multiprocessing .cpu_count ()
113+ self .task = None
114+
115+ def set_task (self , task ):
116+ """
117+ Initialize for a new tuning task
118+
119+ Parameters
120+ ----------
121+ task: Task
122+ The tuning task
123+ """
124+ self .task = task
125+
126+ def get_build_kwargs (self ):
127+ """
128+ Get device specific build arguments (e.g. maximum shared memory size)
129+
130+ Returns
131+ ----------
132+ kwargs: dict
133+ The additional keyword arguments
134+ """
135+ raise NotImplementedError ()
136+
137+ def run (self , measure_inputs , build_results ):
138+ """Run amd measure built programs
139+
140+ Parameters
141+ ----------
142+ measure_inputs: List of MeasureInput
143+ The raw measure input
144+ build_results: List of BuildResults
145+ The build results
146+
147+ Returns
148+ -------
149+ measure_results: List of MeasureResult
150+ The final results of measurement
151+ """
152+ raise NotImplementedError ()
153+
154+
155+ def measure_option (builder , runner ):
156+ """
157+ Set options for measure. To measure a config, we will build it and run it.
158+ So we have to set options for these two steps.
159+ They have their own options on timeout, parallel, etc.
160+
161+ Parameters
162+ ----------
163+ builder: Builder
164+ Specify how to build programs
165+ runner: Runner
166+ Specify how to run programs
167+ """
168+ from .measure_methods import LocalBuilder , LocalRunner
169+
170+ if isinstance (builder , str ):
171+ if builder == 'local' :
172+ builder = LocalBuilder ()
173+ else :
174+ raise ValueError ("Invalid builder: " + builder )
175+
176+ if isinstance (runner , str ):
177+ if runner == 'local' :
178+ runner = LocalRunner ()
179+ else :
180+ raise ValueError ("Invalid runner: " + runner )
181+
182+ opt = {
183+ 'builder' : builder ,
184+ 'runner' : runner ,
185+ }
186+
187+ return opt
188+
189+
190+ def create_measure_batch (task , option ):
191+ """Get a standard measure_batch function.
192+
193+ Parameters
194+ ----------
195+ task: tvm.autotvm.task.Task
196+ The tuning task
197+ option: dict
198+ The option for measuring generated code.
199+ You should use the return value of function :any:`measure_option` for this argument.
97200
98201 Returns
99202 -------
100- options: dict
101- A dict to store all options
102-
103- Note
104- ----
105- To support customized measure, you can pass callable `measure_func` or
106- `build_func` in. The `measure_func` will call `build_func` to build binary library
107- and handle the logic of measurement.
108-
109- Signature:
110- * measure_func (see the return value of measure/measure_methods.py::rpc for example)
111- def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
112- return measure_results
113-
114- * build_func (see measure/measure_methods.py::default_build_func for example)
115- def build_func(inp, tmp_dir, **kwargs):
116- return func, args, filename
203+ measure_batch: callable
204+ a callback function to measure a batch of configs
117205 """
118- return {
119- 'measure_func' : measure_func ,
120- 'number' : number ,
121- 'repeat' : repeat ,
122- 'timeout' : timeout ,
123- 'n_parallel' : n_parallel ,
124- 'do_fork' : do_fork ,
125- 'build_func' : build_func ,
126- 'check_correctness' : check_correctness ,
127- 'replay_db' : replay_db ,
128- }
206+ builder = option ['builder' ]
207+ runner = option ['runner' ]
208+
209+ attach_objects = runner .set_task (task )
210+
211+ # feed device related information from runner to builder
212+ # (e.g. max shared memory for validity checking)
213+ build_kwargs = runner .get_build_kwargs ()
214+ builder .set_task (task , build_kwargs )
215+
216+ def measure_batch (measure_inputs ):
217+ build_results = builder .build (measure_inputs )
218+ results = runner .run (measure_inputs , build_results )
219+ return results
220+
221+ measure_batch .n_parallel = builder .n_parallel
222+ measure_batch .attach_objects = attach_objects
223+ return measure_batch
0 commit comments