Skip to content

Commit 7e49f53

Browse files
author
Tristan Konolige
authored
[AUTO_SCHEDULER] Add feature extraction directly from PrimFunc (#10455)
* [AUTO_SCHEDULER] Add feature extraction directly from PrimFunc Allow users to directly extract features from a PrimFunc. Extracted features can be used to get an estimate of flops, memory load size, or arithmetic intensity from a PrimFunc. Also fix feature extraction to correctly measure the number of arithmetic operations width vector datatypes. * fix param name * log scale in cc instead of python * rename functions, remove load/store * forgot rename in tests * forgot to commit rename
1 parent e2211a2 commit 7e49f53

File tree

4 files changed

+225
-46
lines changed

4 files changed

+225
-46
lines changed

include/tvm/auto_scheduler/feature.h

100755100644
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include <tvm/auto_scheduler/compute_dag.h>
3535
#include <tvm/auto_scheduler/measure.h>
36+
#include <tvm/tir/function.h>
3637

3738
#include <string>
3839
#include <vector>
@@ -41,14 +42,15 @@ namespace tvm {
4142
namespace auto_scheduler {
4243

4344
/*!
44-
* \brief Get per-store feature from a TIR Stmt
45-
* \param stmt The input lowered TIR statement
45+
* \brief Get per-store features from a TIR PrimFunc
46+
* \param func The input lowered TIR PrimFunc
4647
* \param cache_line_size The size of cache line in bytes
4748
* \param max_n_bufs The maximum number of extracted buffers for one statement
4849
* \param ret The returned feature vector
50+
* \param log_scale Should the outputs be scaled by log2(1+x).
4951
*/
50-
void GetPerStoreFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs,
51-
std::vector<float>* ret);
52+
void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs,
53+
std::vector<float>* ret, bool log_scale = true);
5254

5355
/*
5456
* \brief Get the names of elements in the feature vector. Use this for debug and inspection.

python/tvm/auto_scheduler/feature.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626
The feature specification is defined by `src/auto_scheduler/feature.cc::FeatureSet`
2727
"""
2828

29-
from typing import List, Tuple, Union, Optional
29+
from typing import List, Tuple, Union, Optional, Dict
3030
import struct
3131

3232
import numpy as np
3333

3434
from .loop_state import State, StateObject
3535
from .measure import MeasureInput, MeasureResult
3636
from . import _ffi_api
37+
from ..tir import PrimFunc
3738

3839
# The maximum number of extracted buffers for one statement
3940
DEFAULT_MAX_N_BUFS = 5
@@ -252,3 +253,78 @@ def get_per_store_feature_names(max_n_bufs: Optional[int] = None) -> List[str]:
252253
The names of elements in the flatten feature vector
253254
"""
254255
return _ffi_api.GetPerStoreFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)
256+
257+
258+
def features_from_primfunc(
259+
func: PrimFunc,
260+
cache_line_bytes: int = 64,
261+
max_n_bufs: Optional[int] = None,
262+
log_scale: bool = False,
263+
) -> np.ndarray:
264+
"""Extract performance features from a PrimFunc.
265+
266+
Parameters
267+
----------
268+
func: PrimFunc
269+
PrimFunc from which features will be extracted. Each store operation to
270+
a unique buffer in the function will result in one row of features in
271+
the output.
272+
273+
cache_line_bytes: int, optional
274+
Size of a cache line in bytes. Defaults to 64 which is the size for
275+
most x86 processors.
276+
277+
max_n_bufs: int, optional
278+
Maximum number of buffers in generated features. This determines the
279+
length of the resulting feature vector.
280+
281+
log_scale: bool
282+
Should entries in the feature vector be scaled by log2(x + 1). Defaults
283+
to False. Use True if using features with a cost model.
284+
285+
Returns
286+
-------
287+
np.ndarray
288+
Output features, one row per store into a unique buffer statement in `func`.
289+
"""
290+
return _ffi_api.FeaturesFromPrimFunc(
291+
func, cache_line_bytes, max_n_bufs or DEFAULT_MAX_N_BUFS, log_scale
292+
).numpy()
293+
294+
295+
def named_features_from_primfunc(
296+
func: PrimFunc,
297+
cache_line_bytes: int = 64,
298+
max_n_bufs: Optional[int] = None,
299+
log_scale: bool = False,
300+
) -> Dict[str, np.ndarray]:
301+
"""Extract performance features and associated names from a PrimFunc.
302+
303+
Parameters
304+
----------
305+
func: PrimFunc
306+
PrimFunc from which features will be extracted. Each store operation to
307+
a unique buffer in the function will result in one row of features in
308+
the output.
309+
310+
cache_line_bytes: int, optional
311+
Size of a cache line in bytes. Defaults to 64 which is the size for
312+
most x86 processors.
313+
314+
max_n_bufs: int, optional
315+
Maximum number of buffers in generated features. This determines the
316+
length of the resulting feature vector.
317+
318+
log_scale: bool
319+
Should entries in the feature vector be scaled by log2(x + 1). Defaults
320+
to False. Use True if using features with a cost model.
321+
322+
Returns
323+
-------
324+
Dict[str, np.ndarray]
325+
Mapping from feature name to features. One element per store into a
326+
unique buffer statement in `func`.
327+
"""
328+
features = features_from_primfunc(func, cache_line_bytes, max_n_bufs, log_scale)
329+
names = get_per_store_feature_names(max_n_bufs)
330+
return {name: features[:, i] for i, name in enumerate(names)}

0 commit comments

Comments
 (0)