Skip to content

Commit c56ba7c

Browse files
quic-sanirudh authored and thaisacs committed
[TIR] Ramp and Broadcast lanes fixed to int32 dtype (apache#16795)
* [TIR] Ramp and Broadcast lanes fixed to int32 dtype

  When Ramp and Broadcast nodes are created with fixed-length lanes, their lanes are fixed to int32 dtype, since DLDataType supports only uint16 lanes.

* Add test cases for int64 type lanes

* Update test case with int64 iterators
1 parent 5920dca commit c56ba7c

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

src/tir/ir/expr.cc

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -449,16 +449,18 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) {
449449
int lanes = static_cast<int>(lanes_as_int->value);
450450
ICHECK_GT(lanes, 1);
451451
node->dtype = base.dtype().with_lanes(lanes);
452+
// Stick to int32 lanes for fixed length vectors
453+
node->lanes = lanes;
452454
} else { /* scalable vector */
453455
std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
454456
ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes;
455457

456458
node->dtype = base.dtype().with_scalable_vscale_factor(vscale_factor.value());
457459
lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value());
460+
node->lanes = lanes;
458461
}
459462
node->base = base;
460463
node->stride = stride;
461-
node->lanes = lanes;
462464
node->span = std::move(span);
463465
data_ = std::move(node);
464466
}
@@ -481,15 +483,17 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) {
481483
int lanes = static_cast<int>(lanes_int->value);
482484
ICHECK_GT(lanes, 1);
483485
node->dtype = value.dtype().with_lanes(lanes);
486+
// Stick to int32 lanes for fixed length vectors
487+
node->lanes = lanes;
484488
} else { /* scalable vector */
485489
std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
486490
ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes;
487491

488492
node->dtype = value.dtype().with_scalable_vscale_factor(vscale_factor.value());
489493
lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value());
494+
node->lanes = lanes;
490495
}
491496
node->value = std::move(value);
492-
node->lanes = lanes;
493497
node->span = std::move(span);
494498
data_ = node;
495499
}

tests/python/arith/test_arith_rewrite_simplify.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,7 @@ def test_simplify(self, test_case):
7575

7676
class TestVector(BaseCompare):
7777
x, y, z = te.var("x"), te.var("y"), te.var("z")
78+
x64 = te.var("x", dtype="int64")
7879
vx = te.var("vx", dtype="int32x2")
7980
vc = te.var("vc", dtype="uint1")
8081
test_case = tvm.testing.parameter(
@@ -88,6 +89,20 @@ class TestVector(BaseCompare):
8889
),
8990
TestCase(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2")),
9091
TestCase(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4)),
92+
# int64 lanes
93+
TestCase(
94+
tvm.tir.Broadcast(x, 4) + tvm.tir.Ramp(0, 1, tvm.tir.IntImm(dtype="int64", value=4)),
95+
tvm.tir.Ramp(x, 1, 4),
96+
),
97+
TestCase(
98+
tvm.tir.Broadcast(x, tvm.tir.IntImm(dtype="int64", value=4)) + tvm.tir.Ramp(0, 1, 4),
99+
tvm.tir.Ramp(x, 1, 4),
100+
),
101+
# int64 iterators with int32 lanes
102+
TestCase(
103+
tvm.tir.Broadcast(x64, 4) + tvm.tir.Ramp(tvm.tir.IntImm(dtype="int64", value=0), 1, 4),
104+
tvm.tir.Ramp(x64, 1, 4),
105+
),
91106
TestCase(
92107
tvm.tir.Broadcast(0, tir.vscale() * 8) + y, tvm.tir.Broadcast(y, tir.vscale() * 8)
93108
),

tests/python/tir-base/test_tir_nodes.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -409,6 +409,16 @@ def _create_broadcast(lanes):
409409
return tvm.tir.Broadcast(0, lanes)
410410

411411

412+
@pytest.mark.parametrize("lanes", [(tvm.tir.IntImm(dtype="int64", value=11))])
413+
@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
414+
def test_lane_types(lanes, node_func):
415+
def _check_dtype(node):
416+
assert node.lanes.dtype == "int32"
417+
assert node.lanes == 11
418+
419+
_check_dtype(node_func(lanes))
420+
421+
412422
@pytest.mark.parametrize("lanes", [(11 * tvm.tir.vscale()), (tvm.tir.vscale() * 11)])
413423
@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
414424
def test_scalable_vec(lanes, node_func):

0 commit comments

Comments
 (0)