@@ -86,14 +86,45 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
8686 return can_prove_expr;
8787}
8888
89- bool TargetHasSVE (Optional<Target> target) {
89+ bool TargetHasVLA (Optional<Target> target) {
9090 if (!target.defined ()) {
9191 target = Target::Current ();
9292 }
93+ bool has_vla{false };
9394 if (target.defined ()) {
94- return Downcast<Target>(target)->GetFeature <Bool>(" has_sve" ).value_or (Bool (false ));
95+ // aarch64
96+ has_vla = Downcast<Target>(target)->GetFeature <Bool>(" has_sve" ).value_or (Bool (false ));
97+ // riscv{32,64}
98+ static const PackedFunc* target_has_feature_fn_ptr =
99+ runtime::Registry::Get (" target.target_has_feature" );
100+ ICHECK (target_has_feature_fn_ptr != nullptr )
101+ << " The `target.target_has_feature` func is not in tvm registry." ;
102+ has_vla |= static_cast <bool >((*target_has_feature_fn_ptr)(" v" , target));
95103 }
96- return false ;
104+ return has_vla;
105+ }
106+
107+ const std::vector<unsigned int > GetVScaleValues (Optional<Target> target) {
108+ unsigned int vector_width = 0 ;
109+ std::vector<unsigned int > kVScaleValues ;
110+ if (!target.defined ()) {
111+ target = Target::Current ();
112+ }
113+ if (target.defined ()) {
114+ static const PackedFunc* llvm_get_vector_width_fn_ptr =
115+ runtime::Registry::Get (" target.llvm_get_vector_width" );
116+ ICHECK (llvm_get_vector_width_fn_ptr != nullptr )
117+ << " The `target.llvm_get_vector_width` func is not in tvm registry." ;
118+ vector_width = static_cast <int >((*llvm_get_vector_width_fn_ptr)(target));
119+ }
120+ // scale list with powers of two
121+ for (unsigned int i = 0 ;; ++i) {
122+ auto power = static_cast <unsigned int >(std::pow (2 , i));
123+ if (power > (vector_width / 8 )) break ;
124+ kVScaleValues .push_back (power);
125+ }
126+
127+ return kVScaleValues ;
97128}
98129
99130} // namespace arith
0 commit comments