Skip to content

Commit af0c038

Browse files
lhutton1ekaldaNeil Hickey
authored
[SVE] Add codegen support for scalable buffer accesses (#16696)
This commit adds support for generating code for scalable loads and stores. It also adds support for the creation of scalable broadcast operations. Co-authored-by: Elen Kalda <[email protected]> Co-authored-by: Neil Hickey <[email protected]>
1 parent 695f958 commit af0c038

File tree

10 files changed

+249
-39
lines changed

10 files changed

+249
-39
lines changed

include/tvm/runtime/data_type.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ class DataType {
110110
}
111111
return -lanes_as_int;
112112
}
113+
/*! \return get vscale factor or lanes depending on scalability of the vector. */
114+
int get_lanes_or_vscale_factor() { return is_scalable_vector() ? vscale_factor() : lanes(); }
113115
/*! \return whether type is a scalar type. */
114116
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
115117
/*! \return whether type is a scalar type. */
@@ -211,10 +213,13 @@ class DataType {
211213
/*!
212214
* \brief Construct an uint type.
213215
* \param bits The number of bits in the type.
214-
* \param lanes The number of lanes
216+
* \param lanes The number of lanes.
217+
* \param is_scalable Whether the data type is scalable.
215218
* \return The constructed data type.
216219
*/
217-
static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
220+
static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
221+
return DataType(kDLUInt, bits, lanes, is_scalable);
222+
}
218223
/*!
219224
* \brief Construct an float type.
220225
* \param bits The number of bits in the type.
@@ -243,10 +248,13 @@ class DataType {
243248
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
244249
/*!
245250
* \brief Construct a bool type.
246-
* \param lanes The number of lanes
251+
* \param lanes The number of lanes.
252+
* \param is_scalable Whether the data type is scalable.
247253
* \return The constructed data type.
248254
*/
249-
static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); }
255+
static DataType Bool(int lanes = 1, bool is_scalable = false) {
256+
return DataType::UInt(1, lanes, is_scalable);
257+
}
250258
/*!
251259
* \brief Construct a handle type.
252260
* \param bits The number of bits in the type.

python/tvm/testing/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,13 @@ def _has_cpu_feat(features):
10451045
)
10461046

10471047

1048+
requires_aarch64_sve = Feature(
1049+
"arm_sve",
1050+
"AArch64 SVE",
1051+
run_time_check=lambda: _has_cpu_feat("sve"),
1052+
)
1053+
1054+
10481055
requires_x86_vnni = Feature(
10491056
"x86_vnni",
10501057
"x86 VNNI Extensions",

src/target/llvm/codegen_llvm.cc

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -587,10 +587,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
587587
LOG(FATAL) << "do not support " << dtype;
588588
}
589589
}
590-
if (dtype.lanes() != 1) {
590+
if (!dtype.is_scalar()) {
591591
#if TVM_LLVM_VERSION >= 110
592-
return llvm::FixedVectorType::get(etype, dtype.lanes());
592+
if (dtype.is_scalable_vector()) {
593+
return llvm::VectorType::get(etype, dtype.vscale_factor(), true);
594+
} else {
595+
return llvm::FixedVectorType::get(etype, dtype.lanes());
596+
}
593597
#else
598+
ICHECK(!dtype.is_scalable_vector())
599+
<< "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
600+
"version.";
594601
return llvm::VectorType::get(etype, dtype.lanes());
595602
#endif
596603
} else {
@@ -749,26 +756,6 @@ std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Modul
749756
return debug_info;
750757
}
751758

752-
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
753-
#if TVM_LLVM_VERSION >= 110
754-
llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
755-
#else
756-
llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
757-
#endif
758-
llvm::Constant* undef = llvm::UndefValue::get(type);
759-
llvm::Constant* zero = ConstInt32(0);
760-
value = builder_->CreateInsertElement(undef, value, zero);
761-
#if TVM_LLVM_VERSION >= 120
762-
llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero);
763-
#elif TVM_LLVM_VERSION >= 110
764-
llvm::Constant* mask =
765-
llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero);
766-
#else
767-
llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
768-
#endif
769-
return builder_->CreateShuffleVector(value, undef, mask);
770-
}
771-
772759
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
773760
int num_elems = GetVectorNumElements(vec);
774761
if (extent == num_elems && begin == 0) return vec;
@@ -1693,7 +1680,8 @@ void CodeGenLLVM::BufferAccessHelper(
16931680
}
16941681

16951682
PrimExpr last_index = indices[indices.size() - 1];
1696-
ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());
1683+
ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(),
1684+
last_index.dtype().get_lanes_or_vscale_factor() * buffer_element_dtype.lanes());
16971685

16981686
// Record index and elemtype in original form used for alias info
16991687
PrimExpr last_index_origin = last_index;
@@ -1736,8 +1724,6 @@ void CodeGenLLVM::BufferAccessHelper(
17361724
llvm::Value* last_index_value;
17371725
int subelement_i = i;
17381726
if (const RampNode* ramp = last_index.as<RampNode>()) {
1739-
// TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
1740-
ICHECK(!last_index.dtype().is_scalable_vector());
17411727
PrimExpr offset = ramp->base + (ramp->stride * i);
17421728
last_index_value = MakeValue(offset);
17431729
} else if (last_index.dtype().lanes() > 1) {
@@ -1754,8 +1740,13 @@ void CodeGenLLVM::BufferAccessHelper(
17541740
all_index_values.push_back(last_index_value);
17551741

17561742
TypedPointer buffer_ptr =
1757-
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
1758-
value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
1743+
value_dtype.is_scalable_vector()
1744+
? CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
1745+
value_dtype.with_scalable_vscale_factor(value_dtype.vscale_factor() /
1746+
last_index.dtype().lanes()))
1747+
: CreateBufferPtr(
1748+
MakeValue(buffer->data), buffer_element_dtype, all_index_values,
1749+
value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
17591750
auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
17601751
AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin);
17611752
}
@@ -1870,10 +1861,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
18701861
}
18711862

18721863
llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
1873-
// TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
1874-
ICHECK(!op->dtype.is_scalable_vector());
1875-
int lanes = op->dtype.lanes();
1876-
return CreateBroadcast(MakeValue(op->value), lanes);
1864+
DataType dtype = op->dtype;
1865+
llvm::Value* value = MakeValue(op->value);
1866+
llvm::Type* type = DTypeToLLVMType(dtype);
1867+
llvm::Constant* undef = llvm::UndefValue::get(type);
1868+
llvm::Constant* zero = ConstInt32(0);
1869+
value = builder_->CreateInsertElement(undef, value, zero);
1870+
#if TVM_LLVM_VERSION >= 110
1871+
llvm::ElementCount ec =
1872+
llvm::ElementCount::get(dtype.get_lanes_or_vscale_factor(), dtype.is_scalable_vector());
1873+
llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero);
1874+
#else
1875+
ICHECK(!dtype.is_scalable_vector())
1876+
<< "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
1877+
"version.";
1878+
llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero);
1879+
#endif
1880+
return builder_->CreateShuffleVector(value, undef, mask);
18771881
}
18781882

18791883
void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {

src/target/llvm/codegen_llvm.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
468468
llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b);
469469
llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
470470
llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
471-
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
472471
virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
473472
llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype);
474473
// Vector concatenation.

src/tir/ir/data_type_rewriter.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
451451

452452
Buffer new_buffer = GetRemappedBuffer(op->buffer);
453453
auto value = this->VisitExpr(op->value);
454-
if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
454+
if (new_buffer->dtype != value->dtype && value->dtype.is_scalar()) {
455455
value = cast(new_buffer->dtype, value);
456456
}
457457
auto indices = VisitIndices(op->indices);

src/tir/ir/expr.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ namespace tir {
5858
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
5959
<< b.dtype() << "\n"; \
6060
ObjectPtr<T> node = make_object<T>(); \
61-
node->dtype = DataType::Bool(a.dtype().lanes()); \
61+
DataType a_dtype = a.dtype(); \
62+
node->dtype = \
63+
DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \
6264
node->a = std::move(a); \
6365
node->b = std::move(b); \
6466
node->span = std::move(span); \
@@ -393,7 +395,8 @@ Not::Not(PrimExpr a, Span span) {
393395
ICHECK(a.dtype().is_bool());
394396

395397
ObjectPtr<NotNode> node = make_object<NotNode>();
396-
node->dtype = DataType::Bool(a.dtype().lanes());
398+
DataType a_dtype = a.dtype();
399+
node->dtype = DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector());
397400
node->a = std::move(a);
398401
node->span = std::move(span);
399402
data_ = std::move(node);

src/tir/transforms/storage_rewrite.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,13 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
12751275
auto it = info_map_.find(buffer);
12761276
ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer
12771277
<< ") occurred before its declaration.";
1278+
1279+
if (value_dtype.is_scalable_vector()) {
1280+
// Scalable types are not currently supported in storage_rewrite. Scalable buffer
1281+
// accesses are not currently checked and therefore are not rewritten.
1282+
return;
1283+
}
1284+
12781285
BufferVarInfo& var_info = it->second;
12791286

12801287
if (value_dtype.element_of() == DataType::Bool()) {

tests/cpp/tir_scalable_datatype.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,22 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) {
162162
tvm::InternalError);
163163
}
164164

165+
TEST(ScalableDataType, TestScalableBool) {
166+
tvm::DataType scalable_type = tvm::DataType::Bool(4, true);
167+
ASSERT_EQ(scalable_type.code(), kDLUInt);
168+
ASSERT_EQ(scalable_type.bits(), 1);
169+
ASSERT_EQ(scalable_type.vscale_factor(), 4);
170+
ASSERT_TRUE(scalable_type.is_scalable_vector());
171+
}
172+
173+
TEST(ScalableDataType, TestScalableUInt) {
174+
tvm::DataType scalable_type = tvm::DataType::UInt(1, 4, true);
175+
ASSERT_EQ(scalable_type.code(), kDLUInt);
176+
ASSERT_EQ(scalable_type.bits(), 1);
177+
ASSERT_EQ(scalable_type.vscale_factor(), 4);
178+
ASSERT_TRUE(scalable_type.is_scalable_vector());
179+
}
180+
165181
// -----------
166182
// Integration
167183
// -----------

tests/python/codegen/test_target_codegen_aarch64.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,5 +492,46 @@ def main(A: T.Buffer((5,), "int32")):
492492
assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."
493493

494494

495+
@pytest.mark.skipif(
496+
llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
497+
)
498+
def test_scalable_buffer_load_store():
499+
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
500+
501+
@T.prim_func
502+
def my_func(a: T.handle, b: T.handle):
503+
A = T.match_buffer(a, (128,), "float32")
504+
B = T.match_buffer(b, (128,), "float32")
505+
T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
506+
B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
507+
508+
mod = tvm.build(my_func, target=target)
509+
llvm = mod.get_source("ll")
510+
511+
assert re.findall(r"load <vscale x 4 x float>", llvm), "No scalable load in generated LLVM."
512+
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
513+
514+
515+
@pytest.mark.skipif(
516+
llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
517+
)
518+
def test_scalable_broadcast():
519+
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
520+
521+
@T.prim_func
522+
def my_func(a: T.handle):
523+
A = T.match_buffer(a, (128,), "float32")
524+
T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
525+
A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
526+
527+
mod = tvm.build(my_func, target=target)
528+
llvm = mod.get_source("ll")
529+
530+
assert re.findall(
531+
r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x float>", llvm
532+
), "No scalable broadcast in generated LLVM."
533+
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
534+
535+
495536
if __name__ == "__main__":
496537
tvm.testing.main()

0 commit comments

Comments
 (0)