Skip to content

Commit 29534b7

Browse files
authored
[SVE] Check for SVE target in VectorizeLoop (#16893)
Check that we are compiling for an SVE-enabled target when the extent of a loop marked for vectorizing is a vscale-dependent expression. The extent of a loop should be either a positive integer or a vscale-dependent expression; in the latter case we expect the target to have the `has_sve` feature.
1 parent 57316da commit 29534b7

File tree

5 files changed

+125
-62
lines changed

5 files changed

+125
-62
lines changed

src/arith/analyzer.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,17 +235,14 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
235235
// SVE, we can make some assumptions about the value of vscale and iterate over a
236236
// space of pre-defined values to attempt to prove the expression.
237237
if (tir::CheckContains::ExprContains(expr, IsVScaleCall)) {
238-
Target curr_target = tvm::Target::Current();
239-
if (curr_target.defined() && curr_target->features.defined() &&
240-
(curr_target->features.find("has_sve") != curr_target->features.end()) &&
241-
curr_target->GetFeature<Bool>("has_sve").value_or(Bool(false)).operator bool()) {
238+
if (TargetHasSVE()) {
242239
return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
243240
}
244241
LOG(WARNING)
245242
<< "The expression contains scalable values. An attempt to prove by substituting "
246243
"with known values of vscale was not performed. This proof currently only supports "
247244
"AArch64 SVE targets, but the target was "
248-
<< curr_target;
245+
<< Target::Current();
249246
}
250247
return false;
251248
}

src/arith/scalable_expression.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,14 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
8888
return can_prove_expr;
8989
}
9090

91+
bool TargetHasSVE() {
92+
Target current_target = Target::Current();
93+
bool has_sve{false};
94+
if (current_target.defined()) {
95+
has_sve = current_target->GetFeature<Bool>("has_sve").value_or(Bool(false));
96+
}
97+
return has_sve;
98+
}
99+
91100
} // namespace arith
92101
} // namespace tvm

src/arith/scalable_expression.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);
7171
bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
7272
const std::vector<unsigned int>& vscale_values);
7373

74+
/*!
75+
* \brief Check whether the compilation target supports SVE
76+
* \return Whether SVE is supported
77+
*/
78+
bool TargetHasSVE();
79+
7480
} // namespace arith
7581
} // namespace tvm
7682

src/tir/transforms/vectorize_loop.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
#include <unordered_map>
3535
#include <vector>
3636

37+
#include "../../src/arith/scalable_expression.h"
38+
#include "../../tir/analysis/check_contains.h"
39+
3740
namespace tvm {
3841
namespace tir {
3942

@@ -727,6 +730,14 @@ class LoopVectorizer : public StmtMutator {
727730
public:
728731
Stmt VisitStmt_(const ForNode* op) final {
729732
if (op->kind == ForKind::kVectorized) {
733+
auto* extent_as_int = op->extent.as<IntImmNode>();
734+
735+
if (!extent_as_int || extent_as_int->value < 1) {
736+
bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
737+
ICHECK(is_scalable_expr && arith::TargetHasSVE())
738+
<< "Failed to vectorize loop with extent " << op->extent << " for target "
739+
<< Target::Current();
740+
}
730741
ICHECK(is_zero(op->min));
731742
return Vectorizer(op->loop_var, op->extent)(op->body);
732743
} else {
@@ -735,8 +746,6 @@ class LoopVectorizer : public StmtMutator {
735746
}
736747
};
737748

738-
Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); }
739-
740749
class VectorizeSkipper : public StmtMutator {
741750
public:
742751
Stmt VisitStmt_(const ForNode* op) final {

tests/python/tir-transform/test_tir_transform_vectorize.py

Lines changed: 97 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@
2222
import pytest
2323

2424

25-
@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
26-
def test_vectorize_loop(extent):
25+
simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu")
26+
sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve")
27+
28+
29+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
30+
def test_vectorize_loop(extent, target):
2731
@I.ir_module
2832
class Before:
2933
@T.prim_func
@@ -37,8 +41,9 @@ class After:
3741
def main(A: T.Buffer((16,), "float32")):
3842
A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent)
3943

40-
mod = tvm.tir.transform.VectorizeLoop()(Before)
41-
tvm.ir.assert_structural_equal(mod, After)
44+
with tvm.target.Target(target):
45+
mod = tvm.tir.transform.VectorizeLoop()(Before)
46+
tvm.ir.assert_structural_equal(mod, After)
4247

4348

4449
def test_vectorize_vector():
@@ -70,8 +75,9 @@ def main(A: T.Buffer((25,), "float32")):
7075
A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4)
7176

7277
error_msg = f"Creating scalable vectors from existing vectors is not supported."
73-
with pytest.raises(tvm.error.InternalError, match=error_msg):
74-
tvm.tir.transform.VectorizeLoop()(Module)
78+
with tvm.target.Target(sve_target):
79+
with pytest.raises(tvm.error.InternalError, match=error_msg):
80+
tvm.tir.transform.VectorizeLoop()(Module)
7581

7682

7783
def test_vectorize_vector_scalable_error2():
@@ -99,7 +105,8 @@ def main(A: T.Buffer((25,), "float32")):
99105

100106
error_msg = f"Vectorizing over existing scalable vectors is not supported."
101107
with pytest.raises(tvm.error.InternalError, match=error_msg):
102-
tvm.tir.transform.VectorizeLoop()(Module)
108+
with tvm.target.Target(sve_target):
109+
tvm.tir.transform.VectorizeLoop()(Module)
103110

104111

105112
def test_vectorize_vector_scalable_error4():
@@ -114,11 +121,12 @@ def main(A: T.Buffer((25,), "float32")):
114121

115122
error_msg = f"Creating scalable vectors from existing vectors is not supported."
116123
with pytest.raises(tvm.error.InternalError, match=error_msg):
117-
tvm.tir.transform.VectorizeLoop()(Module)
124+
with tvm.target.Target(sve_target):
125+
tvm.tir.transform.VectorizeLoop()(Module)
118126

119127

120-
@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
121-
def test_vectorize_with_if(extent):
128+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
129+
def test_vectorize_with_if(extent, target):
122130
@I.ir_module
123131
class Before:
124132
@T.prim_func
@@ -143,8 +151,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
143151
if i_s < n:
144152
A[i_s] = T.float32(2)
145153

146-
mod = tvm.tir.transform.VectorizeLoop()(Before)
147-
tvm.ir.assert_structural_equal(mod, After)
154+
with tvm.target.Target(target):
155+
mod = tvm.tir.transform.VectorizeLoop()(Before)
156+
tvm.ir.assert_structural_equal(mod, After)
148157

149158

150159
def test_vectorize_with_if_cond_int64():
@@ -157,8 +166,8 @@ def test_vectorize_with_if_cond_int64():
157166
f = tvm.build(s, [A, B], "llvm")
158167

159168

160-
@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
161-
def test_vectorize_let(extent):
169+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
170+
def test_vectorize_let(extent, target):
162171
@I.ir_module
163172
class Before:
164173
@T.prim_func
@@ -174,12 +183,13 @@ def main(A: T.Buffer((25,), "float32")):
174183
v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent)
175184
A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent)
176185

177-
mod = tvm.tir.transform.VectorizeLoop()(Before)
178-
tvm.ir.assert_structural_equal(mod, After)
186+
with tvm.target.Target(target):
187+
mod = tvm.tir.transform.VectorizeLoop()(Before)
188+
tvm.ir.assert_structural_equal(mod, After)
179189

180190

181-
@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4))
182-
def test_vectorize_with_le_cond(extent):
191+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)])
192+
def test_vectorize_with_le_cond(extent, target):
183193
n = te.var("n")
184194
ib = tvm.tir.ir_builder.create()
185195
A = ib.pointer("float32", name="A")
@@ -189,14 +199,16 @@ def test_vectorize_with_le_cond(extent):
189199
stmt = ib.get()
190200

191201
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
192-
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
193202

194-
# Check that the loop was't vectorised
195-
assert isinstance(stmt, tvm.tir.For)
203+
with tvm.target.Target(target):
204+
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
205+
206+
# Check that the loop was't vectorised
207+
assert isinstance(stmt, tvm.tir.For)
196208

197209

198-
@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4))
199-
def test_vectorize_with_ge_cond(extent):
210+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)])
211+
def test_vectorize_with_ge_cond(extent, target):
200212
n = te.var("n")
201213
ib = tvm.tir.ir_builder.create()
202214
A = ib.pointer("float32", name="A")
@@ -206,14 +218,16 @@ def test_vectorize_with_ge_cond(extent):
206218
stmt = ib.get()
207219

208220
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
209-
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
210221

211-
# Check that the loop wasn't vectorised
212-
assert isinstance(stmt, tvm.tir.For)
222+
with tvm.target.Target(target):
223+
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
213224

225+
# Check that the loop wasn't vectorised
226+
assert isinstance(stmt, tvm.tir.For)
214227

215-
@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
216-
def test_vectorize_if_then_else_scalarize(extent):
228+
229+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
230+
def test_vectorize_if_then_else_scalarize(extent, target):
217231
@I.ir_module
218232
class Before:
219233
@T.prim_func
@@ -228,12 +242,13 @@ def main(A: T.Buffer((25,), "float32")):
228242
for i_s in range(extent):
229243
A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s])
230244

231-
mod = tvm.tir.transform.VectorizeLoop()(Before)
232-
tvm.ir.assert_structural_equal(mod, After)
245+
with tvm.target.Target(target):
246+
mod = tvm.tir.transform.VectorizeLoop()(Before)
247+
tvm.ir.assert_structural_equal(mod, After)
233248

234249

235-
@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
236-
def test_vectorize_if_then_else_vector(extent):
250+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
251+
def test_vectorize_if_then_else_vector(extent, target):
237252
@I.ir_module
238253
class Before:
239254
@T.prim_func
@@ -251,8 +266,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32):
251266
i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent)
252267
)
253268

254-
mod = tvm.tir.transform.VectorizeLoop()(Before)
255-
tvm.ir.assert_structural_equal(mod, After)
269+
with tvm.target.Target(target):
270+
mod = tvm.tir.transform.VectorizeLoop()(Before)
271+
tvm.ir.assert_structural_equal(mod, After)
256272

257273

258274
def test_vectorize_while_fail():
@@ -311,9 +327,10 @@ def test_vectorize_dtype_mismatch():
311327

312328

313329
@pytest.mark.parametrize(
314-
"extent, vec_str", [(16, "float32x16"), (T.vscale() * 8, "float32xvscalex8")]
330+
"extent, vec_str, target",
331+
[(16, "float32x16", simple_target), (T.vscale() * 8, "float32xvscalex8", sve_target)],
315332
)
316-
def test_vectorize_with_reinterpret(extent, vec_str):
333+
def test_vectorize_with_reinterpret(extent, vec_str, target):
317334
@I.ir_module
318335
class Before:
319336
@T.prim_func
@@ -327,11 +344,12 @@ class After:
327344
def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
328345
B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)])
329346

330-
mod = tvm.tir.transform.VectorizeLoop()(Before)
331-
tvm.ir.assert_structural_equal(mod, After)
347+
with tvm.target.Target(target):
348+
mod = tvm.tir.transform.VectorizeLoop()(Before)
349+
tvm.ir.assert_structural_equal(mod, After)
332350

333351

334-
@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
352+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
335353
@pytest.mark.parametrize(
336354
"op",
337355
(
@@ -352,7 +370,7 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
352370
T.NE,
353371
),
354372
)
355-
def test_vectorize_binary(op, extent):
373+
def test_vectorize_binary(op, extent, target):
356374
@I.ir_module
357375
class Before:
358376
@T.prim_func
@@ -366,13 +384,14 @@ class After:
366384
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
367385
A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)])
368386

369-
mod = tvm.tir.transform.VectorizeLoop()(Before)
370-
tvm.ir.assert_structural_equal(mod, After)
387+
with tvm.target.Target(target):
388+
mod = tvm.tir.transform.VectorizeLoop()(Before)
389+
tvm.ir.assert_structural_equal(mod, After)
371390

372391

373-
@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
392+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
374393
@pytest.mark.parametrize("op", (T.And, T.Or))
375-
def test_vectorize_logical(op, extent):
394+
def test_vectorize_logical(op, extent, target):
376395
@I.ir_module
377396
class Before:
378397
@T.prim_func
@@ -386,12 +405,13 @@ class After:
386405
def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
387406
A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)])
388407

389-
mod = tvm.tir.transform.VectorizeLoop()(Before)
390-
tvm.ir.assert_structural_equal(mod, After)
408+
with tvm.target.Target(target):
409+
mod = tvm.tir.transform.VectorizeLoop()(Before)
410+
tvm.ir.assert_structural_equal(mod, After)
391411

392412

393-
@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
394-
def test_vectorize_select(extent):
413+
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
414+
def test_vectorize_select(extent, target):
395415
@I.ir_module
396416
class Before:
397417
@T.prim_func
@@ -409,12 +429,16 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
409429
B[T.Ramp(0, 1, extent)],
410430
)
411431

412-
mod = tvm.tir.transform.VectorizeLoop()(Before)
413-
tvm.ir.assert_structural_equal(mod, After)
432+
with tvm.target.Target(target):
433+
mod = tvm.tir.transform.VectorizeLoop()(Before)
434+
tvm.ir.assert_structural_equal(mod, After)
414435

415436

416-
@pytest.mark.parametrize("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4, "int32xvscalex4")])
417-
def test_vectorize_cast(extent, vec_str):
437+
@pytest.mark.parametrize(
438+
"extent, vec_str, target",
439+
[(4, "int32x4", simple_target), (T.vscale() * 4, "int32xvscalex4", sve_target)],
440+
)
441+
def test_vectorize_cast(extent, vec_str, target):
418442
@I.ir_module
419443
class Before:
420444
@T.prim_func
@@ -428,8 +452,9 @@ class After:
428452
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
429453
A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)])
430454

431-
mod = tvm.tir.transform.VectorizeLoop()(Before)
432-
tvm.ir.assert_structural_equal(mod, After)
455+
with tvm.target.Target(target):
456+
mod = tvm.tir.transform.VectorizeLoop()(Before)
457+
tvm.ir.assert_structural_equal(mod, After)
433458

434459

435460
def test_illegal_extent():
@@ -441,10 +466,27 @@ def main(A: T.Buffer((25,), "int32")):
441466
for j in T.vectorized(n):
442467
A[j] = 3
443468

444-
error_msg = f"Invalid expression for scalable lanes n"
469+
error_msg = f"Failed to vectorize loop with extent n for target \\(nullptr\\)"
445470
with pytest.raises(tvm.error.InternalError, match=error_msg):
446471
tvm.tir.transform.VectorizeLoop()(Mod)
447472

448473

474+
def test_illegal_vscale_in_non_sve_compilation():
475+
@I.ir_module
476+
class Mod:
477+
@T.prim_func
478+
def main(A: T.Buffer((16,), "float32")):
479+
for j in T.vectorized(0, 4 * T.vscale()):
480+
A[j] = 13
481+
482+
msg = (
483+
f"Failed to vectorize loop with extent T.vscale\\(\\) \\* 4 for target "
484+
f"llvm -keys=cpu -mtriple=x86_64-linux-gnu"
485+
)
486+
with tvm.target.Target(simple_target):
487+
with pytest.raises(tvm.error.InternalError, match=msg):
488+
tvm.tir.transform.VectorizeLoop()(Mod)
489+
490+
449491
if __name__ == "__main__":
450492
tvm.testing.main()

0 commit comments

Comments
 (0)