Skip to content

Commit 4d4f050

Browse files
ekaldalhutton1Neil Hickey
authored
[SVE] Support scalable vectors in LoopVectorizer (#16782)
This patch add support for turning loops marked for vectorizing into scalable vectors if the extent of the loop is a vscale dependent expression in a correct form. The testing for both scalable and fixed length vectors in test_tir_transform.py has been extended and most of the tests have been converted to TVMScript based testing against expected output. Co-authored-by: Luke Hutton <[email protected]> Co-authored-by: Neil Hickey <[email protected]>
1 parent a309b6b commit 4d4f050

File tree

5 files changed

+428
-148
lines changed

5 files changed

+428
-148
lines changed

include/tvm/runtime/data_type.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ class DataType {
111111
return -lanes_as_int;
112112
}
113113
/*! \return get vscale factor or lanes depending on scalability of the vector. */
114-
int get_lanes_or_vscale_factor() { return is_scalable_vector() ? vscale_factor() : lanes(); }
114+
int get_lanes_or_vscale_factor() const {
115+
return is_scalable_vector() ? vscale_factor() : lanes();
116+
}
115117
/*! \return whether type is a scalar type. */
116118
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
117119
/*! \return whether type is a scalar type. */

include/tvm/tir/op.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <tvm/ir/expr.h>
3232
#include <tvm/ir/op.h>
3333
#include <tvm/ir/type.h>
34+
#include <tvm/tir/builtin.h>
3435
#include <tvm/tir/expr.h>
3536
#include <tvm/tir/stmt.h>
3637

@@ -959,10 +960,16 @@ inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
959960

960961
template <typename ValueType, typename>
961962
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
962-
if (t.lanes() == 1) {
963+
if (t.is_scalar()) {
963964
return MakeConstScalar(t, value, span);
964965
} else {
965-
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
966+
if (t.is_fixed_length_vector()) {
967+
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
968+
} else {
969+
PrimExpr lanes =
970+
tir::Mul(tir::Call(DataType::Int(32), tir::builtin::vscale(), {}), t.vscale_factor());
971+
return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span);
972+
}
966973
}
967974
}
968975

src/tir/ir/expr.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ TVM_REGISTER_NODE_TYPE(StringImmNode);
196196
// Cast
197197
Cast::Cast(DataType t, PrimExpr value, Span span) {
198198
ICHECK(value.defined());
199-
ICHECK_EQ(t.lanes(), value.dtype().lanes());
199+
ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor());
200+
ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector());
200201
ObjectPtr<CastNode> node = make_object<CastNode>();
201202
node->dtype = t;
202203
node->value = std::move(value);
@@ -354,7 +355,8 @@ And::And(PrimExpr a, PrimExpr b, Span span) {
354355
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
355356

356357
ObjectPtr<AndNode> node = make_object<AndNode>();
357-
node->dtype = DataType::Bool(a.dtype().lanes());
358+
node->dtype =
359+
DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector());
358360
node->a = std::move(a);
359361
node->b = std::move(b);
360362
node->span = std::move(span);
@@ -376,7 +378,8 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) {
376378
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
377379

378380
ObjectPtr<OrNode> node = make_object<OrNode>();
379-
node->dtype = DataType::Bool(a.dtype().lanes());
381+
node->dtype =
382+
DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector());
380383
node->a = std::move(a);
381384
node->b = std::move(b);
382385
node->span = std::move(span);
@@ -412,7 +415,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp
412415
ICHECK(true_value.defined()) << "ValueError: true_value is undefined";
413416
ICHECK(false_value.defined()) << "ValueError: true_value is undefined";
414417
ICHECK(condition.dtype().is_bool());
415-
ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1);
418+
ICHECK(condition.dtype().get_lanes_or_vscale_factor() ==
419+
true_value.dtype().get_lanes_or_vscale_factor() ||
420+
condition.dtype().is_scalar());
416421
ICHECK(false_value.dtype() == true_value.dtype())
417422
<< "TypeError: mismatched types. "
418423
<< "False type: " << false_value.dtype() << "; True type: " << true_value.dtype();

0 commit comments

Comments
 (0)