Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions include/tvm/auto_scheduler/feature.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/measure.h>
#include <tvm/tir/function.h>

#include <string>
#include <vector>
Expand All @@ -41,14 +42,15 @@ namespace tvm {
namespace auto_scheduler {

/*!
* \brief Get per-store feature from a TIR Stmt
* \param stmt The input lowered TIR statement
* \brief Get per-store features from a TIR PrimFunc
* \param func The input lowered TIR PrimFunc
* \param cache_line_size The size of cache line in bytes
* \param max_n_bufs The maximum number of extracted buffers for one statement
* \param ret The returned feature vector
* \param log_scale Should the outputs be scaled by log2(1+x).
*/
void GetPerStoreFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs,
std::vector<float>* ret);
void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs,
std::vector<float>* ret, bool log_scale = true);

/*
* \brief Get the names of elements in the feature vector. Use this for debug and inspection.
Expand Down
78 changes: 77 additions & 1 deletion python/tvm/auto_scheduler/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@
The feature specification is defined by `src/auto_scheduler/feature.cc::FeatureSet`
"""

from typing import List, Tuple, Union, Optional
from typing import List, Tuple, Union, Optional, Dict
import struct

import numpy as np

from .loop_state import State, StateObject
from .measure import MeasureInput, MeasureResult
from . import _ffi_api
from ..tir import PrimFunc

# The maximum number of extracted buffers for one statement
DEFAULT_MAX_N_BUFS = 5
Expand Down Expand Up @@ -252,3 +253,78 @@ def get_per_store_feature_names(max_n_bufs: Optional[int] = None) -> List[str]:
The names of elements in the flatten feature vector
"""
return _ffi_api.GetPerStoreFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)


def primfunc_features(
func: PrimFunc,
cache_line_bytes: int = 64,
max_n_bufs: Optional[int] = None,
log_scale: bool = False,
) -> np.ndarray:
"""Extract performance features from a PrimFunc.

Parameters
----------
func: PrimFunc
PrimFunc from which features will be extracted. Each store operation to
a unique buffer in the function will result in one row of features in
the output.

cache_line_bytes: int, optional
Size of a cache line in bytes. Defaults to 64 which is the size for
most x86 processors.

max_n_bufs: int, optional
Maximum number of buffers in generated features. This determines the
length of the resulting feature vector.

log_scale: bool
Should entries in the feature vector be scaled by log2(x + 1). Defaults
to False. Use True if using features with a cost model.

Returns
-------
np.ndarray
Output features, one row per store into a unique buffer statement in `func`.
"""
return _ffi_api.FeaturesFromPrimFunc(
func, cache_line_bytes, max_n_bufs or DEFAULT_MAX_N_BUFS, log_scale
).numpy()


def named_primfunc_features(
func: PrimFunc,
cache_line_bytes: int = 64,
max_n_bufs: Optional[int] = None,
log_scale: bool = False,
) -> Dict[str, np.ndarray]:
"""Extract performance features and associated names from a PrimFunc.

Parameters
----------
func: PrimFunc
PrimFunc from which features will be extracted. Each store operation to
a unique buffer in the function will result in one row of features in
the output.

cache_line_bytes: int, optional
Size of a cache line in bytes. Defaults to 64 which is the size for
most x86 processors.

max_n_bufs: int, optional
Maximum number of buffers in generated features. This determines the
length of the resulting feature vector.

log_scale: bool
Should entries in the feature vector be scaled by log2(x + 1). Defaults
to False. Use True if using features with a cost model.

Returns
-------
Dict[str, np.ndarray]
Mapping from feature name to features. One element per store into a
unique buffer statement in `func`.
"""
features = primfunc_features(func, cache_line_bytes, max_n_bufs, log_scale)
names = get_per_store_feature_names(max_n_bufs)
return {name: features[:, i] for i, name in enumerate(names)}
Loading