diff --git a/include/tvm/auto_scheduler/feature.h b/include/tvm/auto_scheduler/feature.h
old mode 100755
new mode 100644
index a1782f1871d0..71d00f249210
--- a/include/tvm/auto_scheduler/feature.h
+++ b/include/tvm/auto_scheduler/feature.h
@@ -33,6 +33,7 @@
 #include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/auto_scheduler/measure.h>
+#include <tvm/tir/function.h>

 #include <string>
 #include <vector>
@@ -41,14 +42,15 @@ namespace tvm {
 namespace auto_scheduler {

 /*!
- * \brief Get per-store feature from a TIR Stmt
- * \param stmt The input lowered TIR statement
+ * \brief Get per-store features from a TIR PrimFunc
+ * \param func The input lowered TIR PrimFunc
  * \param cache_line_size The size of cache line in bytes
  * \param max_n_bufs The maximum number of extracted buffers for one statement
  * \param ret The returned feature vector
+ * \param log_scale Should the outputs be scaled by log2(1 + x).
  */
-void GetPerStoreFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs,
-                        std::vector<float>* ret);
+void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs,
+                        std::vector<float>* ret, bool log_scale = true);

 /*
  * \brief Get the names of elements in the feature vector. Use this for debug and inspection.
diff --git a/python/tvm/auto_scheduler/feature.py b/python/tvm/auto_scheduler/feature.py
index ec7cf6334f98..09d54a92fd64
--- a/python/tvm/auto_scheduler/feature.py
+++ b/python/tvm/auto_scheduler/feature.py
@@ -26,7 +26,7 @@
 The feature specification is defined by `src/auto_scheduler/feature.cc::FeatureSet`
 """

-from typing import List, Tuple, Union, Optional
+from typing import List, Tuple, Union, Optional, Dict
 import struct

 import numpy as np
@@ -34,6 +34,7 @@
 from .loop_state import State, StateObject
 from .measure import MeasureInput, MeasureResult
 from . import _ffi_api
+from ..tir import PrimFunc

 # The maximum number of extracted buffers for one statement
 DEFAULT_MAX_N_BUFS = 5
@@ -252,3 +253,78 @@ def get_per_store_feature_names(max_n_bufs: Optional[int] = None) -> List[str]:
         The names of elements in the flattened feature vector
     """
     return _ffi_api.GetPerStoreFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)
+
+
+def features_from_primfunc(
+    func: PrimFunc,
+    cache_line_bytes: int = 64,
+    max_n_bufs: Optional[int] = None,
+    log_scale: bool = False,
+) -> np.ndarray:
+    """Extract performance features from a PrimFunc.
+
+    Parameters
+    ----------
+    func: PrimFunc
+        PrimFunc from which features will be extracted. Each store operation
+        to a unique buffer in the function will result in one row of features
+        in the output.
+
+    cache_line_bytes: int, optional
+        Size of a cache line in bytes. Defaults to 64, the size on most x86
+        processors.
+
+    max_n_bufs: int, optional
+        Maximum number of buffers in the generated features. This determines
+        the length of the resulting feature vector.
+
+    log_scale: bool
+        Should entries in the feature vector be scaled by log2(x + 1).
+        Defaults to False. Use True when feeding the features to a cost model.
+
+    Returns
+    -------
+    np.ndarray
+        Output features, one row per store to a unique buffer in `func`.
+    """
+    return _ffi_api.FeaturesFromPrimFunc(
+        func, cache_line_bytes, max_n_bufs or DEFAULT_MAX_N_BUFS, log_scale
+    ).numpy()
+
+
+def named_features_from_primfunc(
+    func: PrimFunc,
+    cache_line_bytes: int = 64,
+    max_n_bufs: Optional[int] = None,
+    log_scale: bool = False,
+) -> Dict[str, np.ndarray]:
+    """Extract performance features and associated names from a PrimFunc.
+
+    Parameters
+    ----------
+    func: PrimFunc
+        PrimFunc from which features will be extracted. Each store operation
+        to a unique buffer in the function will result in one row of features
+        in the output.
+
+    cache_line_bytes: int, optional
+        Size of a cache line in bytes. Defaults to 64, the size on most x86
+        processors.
+
+    max_n_bufs: int, optional
+        Maximum number of buffers in the generated features. This determines
+        the length of the resulting feature vector.
+
+    log_scale: bool
+        Should entries in the feature vector be scaled by log2(x + 1).
+        Defaults to False. Use True when feeding the features to a cost model.
+
+    Returns
+    -------
+    Dict[str, np.ndarray]
+        Mapping from feature name to features, one entry per store to a
+        unique buffer in `func`.
+    """
+    features = features_from_primfunc(func, cache_line_bytes, max_n_bufs, log_scale)
+    names = get_per_store_feature_names(max_n_bufs)
+    return {name: features[:, i] for i, name in enumerate(names)}
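
Together, these two entry points make feature extraction available without constructing an auto-scheduler SearchTask. A minimal usage sketch (the small `te` workload and the expected count below are illustrative, not part of this patch):

```python
import tvm
from tvm import te, auto_scheduler

# Lower a small workload to obtain a PrimFunc.
A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda i, j: A[i, j] + 1.0, name="B")
s = te.create_schedule(B.op)
func = tvm.lower(s, [A, B])["main"]

named = auto_scheduler.feature.named_features_from_primfunc(func)
# One row per stored-to buffer; only B is written here, and each of the
# 128 * 128 elements needs one float addition.
assert named["float_addsub"].shape == (1,)
assert named["float_addsub"][0] == 128 * 128
```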
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
old mode 100755
new mode 100644
index 5809888543c6..1beb1ced6345
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -53,7 +53,7 @@ using arith::Analyzer;
 using arith::ConstIntBound;

 template <class T>
-using BufferMap = std::unordered_map<Buffer, T, ObjectHash, ObjectEqual>;
+using BufferMap = std::unordered_map<Var, T, ObjectHash, ObjectEqual>;

 // The number of samples to extract for arithmetic intensity curves
 static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10;
@@ -249,9 +249,9 @@ class MathOpCounter : public StmtExprVisitor {
 #define VisitBinary(Type, float_ct, int_ct)                         \
   void VisitExpr_(const Type* op) final {                           \
     if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) {  \
-      float_ct++;                                                   \
+      float_ct += op->a.dtype().lanes();                            \
     } else {                                                        \
-      int_ct++;                                                     \
+      int_ct += op->a.dtype().lanes();                              \
     }                                                               \
     StmtExprVisitor::VisitExpr_(op);                                \
   }
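
Previously each binary op counted as a single operation regardless of vector width; counting lanes keeps scalar and vectorized schedules comparable. A hypothetical sanity check against the new Python API (not part of this patch's tests) could look like:

```python
import tvm
from tvm import te, auto_scheduler

A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] * A[i], name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)
s[B].vectorize(xi)
func = tvm.lower(s, [A, B])["main"]

feats = auto_scheduler.feature.named_features_from_primfunc(func)
# 32 iterations of a 4-lane vector multiply = 128 scalar multiplies.
assert feats["float_mul"][0] == 128
```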
@@ -340,14 +340,19 @@ class BufferAccessExtractor : public StmtExprVisitor {
  public:
   void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); }

-  void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) {
+  void InsertAccess(const Var& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) {
     BufferAccess& acc = buf_accesses[buf];
     acc.acc_type = acc_type;
     acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
   }

   void VisitExpr_(const BufferLoadNode* op) final {
-    BufferAccess& acc = buf_accesses[op->buffer];
+    AddAccess(op->buffer->data, op->indices);
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void AddAccess(const Var& buffer, const Array<PrimExpr>& indices) {
+    BufferAccess& acc = buf_accesses[buffer];
     switch (acc.acc_type) {
       case BufferAccessType::kRead:
         break;
@@ -366,10 +371,8 @@ class BufferAccessExtractor : public StmtExprVisitor {
       // If a buffer is both read and written, in the tvm DSL, it must be an update,
       // so the indices should be the same. Then we can skip appending indices for it.
       // Otherwise we do the following.
-      buf_accesses[op->buffer].indices.push_back(
-          std::vector<PrimExpr>(op->indices.begin(), op->indices.end()));
+      buf_accesses[buffer].indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
     }
-    StmtExprVisitor::VisitExpr_(op);
   }

   BufferMap<BufferAccess> buf_accesses;
@@ -492,7 +495,7 @@ void ComputeRegion(const std::vector<std::vector<PrimExpr>>& indices, arith::Ana
 // Compute reuse distance and reuse ratio for accesses to a buffer
 // return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct
 std::tuple<ReuseType, float, float, float> ComputeReuse(
-    const Buffer& buf, const std::vector<std::vector<PrimExpr>>& indices,
+    const Var& buf, const std::vector<std::vector<PrimExpr>>& indices,
     const std::vector<const ForNode*>& for_loop_stack,
     const std::unordered_map<const ForNode*,
                              BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>&
         for_touch_regions) {
@@ -572,7 +575,17 @@ std::tuple<ReuseType, float, float, float> ComputeReuse(
 // Extract features for every BufferStore statement
 class PerStoreFeatureExtractor : public StmtExprVisitor {
  public:
-  explicit PerStoreFeatureExtractor(int cache_line_size) : cache_line_size_(cache_line_size) {}
+  explicit PerStoreFeatureExtractor(int cache_line_size, const Map<Var, Buffer>& existing_buffers)
+      : cache_line_size_(cache_line_size) {
+    for (const auto& buffer : existing_buffers) {
+      buffer_shapes[buffer.first] = buffer.second->shape;
+      buffer_dtypes[buffer.first] = buffer.second->dtype;
+      // Also add an entry for the buffer's internal data variable. This is
+      // usually how buffers are referenced within the body of a PrimFunc.
+      buffer_shapes[buffer.second->data] = buffer.second->shape;
+      buffer_dtypes[buffer.second->data] = buffer.second->dtype;
+    }
+  }

   void VisitStmt_(const AttrStmtNode* node) final {
     if (node->attr_key == tir::attr::thread_extent ||
         node->attr_key == tir::attr::virtual_thread) {
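
The constructor seeds shapes and dtypes from the PrimFunc's `buffer_map` because, after lowering, loads and stores in the body refer to a buffer through its backing `data` Var rather than through the `Buffer` object itself. A quick way to inspect that indirection (using the `tir_matmul` function added to the tests below):

```python
# Each parameter Var maps to a Buffer carrying shape/dtype; the body
# addresses that storage through buf.data, which is itself a Var.
for param, buf in tir_matmul.buffer_map.items():
    print(buf.name, buf.data, buf.shape, buf.dtype)
```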
@@ -659,7 +672,18 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
     }
   }

+  void VisitExpr_(const BufferLoadNode* node) final {
+    // Record the buffer's shape/dtype. It may already be recorded.
+    buffer_shapes[node->buffer->data] = node->buffer->shape;
+    buffer_dtypes[node->buffer->data] = node->buffer->dtype;
+    StmtExprVisitor::VisitExpr_(node);
+  }
+
   void VisitStmt_(const BufferStoreNode* node) final {
+    // Record the buffer's shape/dtype. It may already be recorded.
+    buffer_shapes[node->buffer->data] = node->buffer->shape;
+    buffer_dtypes[node->buffer->data] = node->buffer->dtype;
+
     MathOpCounter math_op_counter;
     math_op_counter(node->value);
     std::vector<float> mem_bytes_list;
@@ -667,20 +691,33 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
     std::vector<float> compute_ops_list;
     double cur_compute_ops;

     // Group 1: Computation related features
-    ExtractComputationFeature(node, math_op_counter);
+    ExtractComputationFeature(node->buffer->data, node->indices, math_op_counter);

     // Group 2: Buffer access related features (per buffer)
-    ExtractBufferAccessFeature(node, math_op_counter, &cur_compute_ops, &compute_ops_list,
-                               &mem_bytes_list);
+    ExtractBufferAccessFeature(node->buffer->data, node->indices, node->value, math_op_counter,
+                               &cur_compute_ops, &compute_ops_list, &mem_bytes_list);

     // Group 3: Arithmetic intensity related features
-    ExtractArithmeticIntensityFeature(node, cur_compute_ops, compute_ops_list, mem_bytes_list);
+    ExtractArithmeticIntensityFeature(node->buffer->data, cur_compute_ops, compute_ops_list,
+                                      mem_bytes_list);

-    // Group 5: Outer scope related features
-    ExtractOuterScopeFeature(node);
+    // Group 5: Outer scope related features
+    ExtractOuterScopeFeature(node->buffer->data);
   }

   void VisitStmt_(const BufferRealizeNode* node) final {
+    // Record the buffer's shape/dtype. It may already be recorded.
+    buffer_shapes[node->buffer->data] = node->buffer->shape;
+    buffer_dtypes[node->buffer->data] = node->buffer->dtype;
+    StmtExprVisitor::VisitStmt_(node);
+
+    // Group 4: Allocation related features
+    ExtractAllocationFeature(node);
+  }
+
+  void VisitStmt_(const AllocateNode* node) final {
+    buffer_dtypes[node->buffer_var] = node->dtype;
+    buffer_shapes[node->buffer_var] = node->extents;
     StmtExprVisitor::VisitStmt_(node);

     // Group 4: Allocation related features
     ExtractAllocationFeature(node);
   }

   // Extract computation related features (group 1)
-  void ExtractComputationFeature(const BufferStoreNode* node,
+  void ExtractComputationFeature(const Var& buffer, const Array<PrimExpr>& indices,
                                  const MathOpCounter& math_op_counter) {
-    FeatureSet& fea = buffer_features[node->buffer];
+    FeatureSet& fea = buffer_features[buffer];

     // Computation related features
     fea.float_mad = outer_loop_prod_ * math_op_counter.float_mad;
@@ -762,16 +799,17 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
   }

   // Extract buffer access related features (group 2)
-  void ExtractBufferAccessFeature(const BufferStoreNode* node, const MathOpCounter& math_op_counter,
+  void ExtractBufferAccessFeature(const Var& buffer, const Array<PrimExpr>& indices,
+                                  const PrimExpr& value, const MathOpCounter& math_op_counter,
                                   double* cur_compute_ops, std::vector<float>* compute_ops_list,
                                   std::vector<float>* mem_bytes_list) {
-    FeatureSet& fea = buffer_features[node->buffer];
+    FeatureSet& fea = buffer_features[buffer];

     // Extract all buffer accesses
     std::vector<BufferAccessFeature> acc_feas;
     BufferAccessExtractor buf_extractor;
-    buf_extractor.InsertAccess(node->buffer, BufferAccessType::kWrite, node->indices);
-    buf_extractor.ExtractReads(node->value);
+    buf_extractor.InsertAccess(buffer, BufferAccessType::kWrite, indices);
+    buf_extractor.ExtractReads(value);

     // Compute touched region for all outer loops
     for (auto x : for_loop_stack_) {
@@ -801,14 +839,14 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {

     int64_t mem_bytes = 0;
     for (const auto& x : buf_extractor.buf_accesses) {
-      const Buffer& t = x.first;
+      const Var& t = x.first;
       const BufferAccess& acc = x.second;

       ComputeRegion(acc.indices, &ana_, &tmp_region);
       int64_t touched_size = ElementProduct(tmp_region);
       buffer_regions_map[t].push_back(
-          std::make_tuple(acc.acc_type, touched_size, t->dtype.bytes()));
-      mem_bytes += touched_size * t->dtype.bytes();
+          std::make_tuple(acc.acc_type, touched_size, buffer_dtypes.at(t).bytes()));
+      mem_bytes += touched_size * buffer_dtypes.at(t).bytes();
     }

     mem_bytes_list->push_back(std::log2(mem_bytes));
@@ -818,15 +856,15 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {

     // Buffer access related features (per buffer)
     for (const auto& x : buf_extractor.buf_accesses) {
-      const Buffer& t = x.first;
+      const Var& t = x.first;
       const BufferAccess& acc = x.second;

       std::vector<int> int_shape;
-      for (const auto& dim : t->shape) {
+      for (const auto& dim : buffer_shapes.at(t)) {
         int_shape.push_back(GetIntImm(dim));
       }

-      size_t ele_bytes = t->dtype.bytes();
+      size_t ele_bytes = buffer_dtypes.at(t).bytes();

       // calculate bytes
       float bytes = outer_loop_prod_ * ele_bytes;
@@ -886,7 +924,8 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
       acc_feas.emplace_back();
       BufferAccessFeature& acc_fea = acc_feas.back();

-      acc_fea.buffer_name = t->name;
+      // TODO(tkonolige): save buffer names and use those instead?
+      acc_fea.buffer_name = t->name_hint;
       acc_fea.acc_type = acc.acc_type;
       acc_fea.stride = stride;
       acc_fea.bytes = bytes;
@@ -915,10 +954,10 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
   }

   // Extract arithmetic intensity related feature (group 3)
-  void ExtractArithmeticIntensityFeature(const BufferStoreNode* node, double cur_compute_ops,
+  void ExtractArithmeticIntensityFeature(const Var& buffer, double cur_compute_ops,
                                          const std::vector<float>& compute_ops_list,
                                          const std::vector<float>& mem_bytes_list) {
-    FeatureSet& fea = buffer_features[node->buffer];
+    FeatureSet& fea = buffer_features[buffer];

     // Compute arithmetic intensity curve (y axis : arithmetic intensity, x axis : flops).
     // We use piecewise linear interpolation to fit this curve.
@@ -951,7 +990,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {

   // Extract allocation related features (group 4)
   void ExtractAllocationFeature(const BufferRealizeNode* node) {
-    FeatureSet& fea = buffer_features[node->buffer];
+    FeatureSet& fea = buffer_features[node->buffer->data];

     float allocation_size = 1.0f;
     for (const auto& x : node->bounds) {
@@ -964,9 +1003,24 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
     fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_;
   }

+  void ExtractAllocationFeature(const AllocateNode* node) {
+    FeatureSet& fea = buffer_features[node->buffer_var];
+
+    float allocation_size = 1.0f;
+    for (const auto& x : node->extents) {
+      // TODO(tkonolige): this will not handle dynamic shapes
+      allocation_size *= GetIntImm(x);
+    }
+    // allocation feature
+    fea.alloc_size = allocation_size * node->dtype.bytes();
+    fea.alloc_prod = allocation_size * outer_loop_prod_;
+    fea.alloc_outer_prod = outer_loop_prod_;
+    fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_;
+  }
+
   // Extract outer scope related features (group 5)
-  void ExtractOuterScopeFeature(const BufferStoreNode* node) {
-    FeatureSet& fea = buffer_features[node->buffer];
+  void ExtractOuterScopeFeature(const Var& buffer) {
+    FeatureSet& fea = buffer_features[buffer];

     fea.outer_prod = outer_loop_prod_;
     fea.num_loops = for_loop_stack_.size();
@@ -1009,15 +1063,22 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {

   // The default cache line size in bytes
   const int cache_line_size_ = 64;
+
+  // Storage for buffer shape and dtype information. Needed because Load/Store
+  // nodes do not contain this information.
+  BufferMap<Array<PrimExpr>> buffer_shapes;
+  BufferMap<DataType> buffer_dtypes;
 };

-// shifted log to incorporate the property that slog(0) = 0
-inline float slog(float x) { return x < 0 ? -std::log2(-x + 1) : std::log2(x + 1); }
+// shifted log to incorporate the property that log2p(0) = 0
+inline float log2p(float x) { return x < 0 ? -std::log2(-x + 1) : std::log2(x + 1); }

-void GetPerStoreFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs,
-                        std::vector<float>* ret) {
-  PerStoreFeatureExtractor extractor(cache_line_size);
-  extractor(stmt);
+void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs,
+                        std::vector<float>* ret, bool log_scale) {
+  PerStoreFeatureExtractor extractor(cache_line_size, func->buffer_map);
+  extractor(func->body);
+
+  auto slog = log_scale ? log2p : [](float x) { return x; };

   ret->push_back(extractor.buffer_features.size());
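
For reference, the renamed `log2p` helper is the same shifted, sign-preserving log as the old `slog`; a direct Python transcription:

```python
import numpy as np

def log2p(x):
    # Shifted log: log2p(0) == 0, and the sign of x is preserved.
    return -np.log2(-x + 1) if x < 0 else np.log2(x + 1)

assert log2p(0) == 0
assert log2p(3) == 2    # log2(3 + 1)
assert log2p(-3) == -2  # -log2(3 + 1)
```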
@@ -1308,8 +1369,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
       tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()});
   mod = optimize(std::move(mod));
   PrimFunc prim_func = Downcast<PrimFunc>(mod->Lookup(name));
-  GetPerStoreFeature(prim_func->body, task->hardware_params->cache_line_bytes, max_n_bufs,
-                     feature);
+  GetPerStoreFeature(prim_func, task->hardware_params->cache_line_bytes, max_n_bufs, feature);
 } catch (Error& e) {
   (*error_ct)++;
 }
@@ -1636,5 +1696,18 @@ TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeatureNames")
       *ret = arr;
     });

+TVM_REGISTER_GLOBAL("auto_scheduler.FeaturesFromPrimFunc")
+    .set_body_typed([](const PrimFunc& func, int cache_line_size, int max_n_bufs, bool log_scale) {
+      std::vector<float> vec;
+      GetPerStoreFeature(func, cache_line_size, max_n_bufs, &vec, log_scale);
+      int64_t num_feature_rows = vec[0];  // first element is the number of rows
+      int64_t row_length = (vec.size() - 1) / num_feature_rows;
+      auto ary =
+          runtime::NDArray::Empty({num_feature_rows, row_length}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+      // NDArray is row-major by default
+      ary.CopyFromBytes(vec.data() + 1, sizeof(float) * num_feature_rows * row_length);
+      return ary;
+    });
+
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py
index 96090e328328..a092afe28b93
--- a/tests/python/unittest/test_auto_scheduler_feature.py
+++ b/tests/python/unittest/test_auto_scheduler_feature.py
@@ -22,6 +22,7 @@

 import tvm
 from tvm import te, auto_scheduler
+from tvm.script import tir as T

 from tvm.testing.auto_scheduler import matmul_auto_scheduler_test

@@ -200,6 +201,33 @@ def test_gpu_feature():
     assert fequal(fea_dicts[0]["is_gpu"], 1.0)


+@T.prim_func
+def tir_matmul(
+    A: T.Buffer[(16384,), "float32"],
+    B: T.Buffer[(16384,), "float32"],
+    C: T.Buffer[(16384,), "float32"],
+) -> None:
+    # function attr dict
+    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+    T.preflattened_buffer(A, [128, 128], dtype="float32", data=A.data)
+    T.preflattened_buffer(B, [128, 128], dtype="float32", data=B.data)
+    T.preflattened_buffer(C, [128, 128], dtype="float32", data=C.data)
+    # body
+    for x, y in T.grid(128, 128):
+        C[x * 128 + y] = T.float32(0)
+        for k in T.serial(128):
+            C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k]
+
+
+def test_primfunc():
+    features = auto_scheduler.feature.named_features_from_primfunc(tir_matmul)
+    assert features["float_mad"].shape == (1,)
+    # featurization does not handle multiply-add right now, so adds and muls are counted separately
+    assert abs(features["float_addsub"][0] - 128 * 128 * 128) < 10
+    assert abs(features["float_mul"][0] - 128 * 128 * 128) < 10
+    assert abs(features["B0.unique_bytes"][0] - 128 * 128 * 4) < 10  # 4 bytes per float32
+
+
 if __name__ == "__main__":
     test_cpu_matmul()
     test_cpu_fusion()
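
A companion check one could add (hypothetical, not in this patch) to pin down the shape contract between the raw and named variants:

```python
def test_primfunc_shapes():
    # The raw matrix has one row per stored-to buffer and one column per
    # feature name reported by get_per_store_feature_names().
    features = auto_scheduler.feature.features_from_primfunc(tir_matmul)
    names = auto_scheduler.feature.get_per_store_feature_names()
    assert features.shape == (1, len(names))
```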