Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -449,16 +449,18 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) {
int lanes = static_cast<int>(lanes_as_int->value);
ICHECK_GT(lanes, 1);
node->dtype = base.dtype().with_lanes(lanes);
// Stick to int32 lanes for fixed length vectors
node->lanes = lanes;
} else { /* scalable vector */
std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes;

node->dtype = base.dtype().with_scalable_vscale_factor(vscale_factor.value());
lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value());
node->lanes = lanes;
}
node->base = base;
node->stride = stride;
node->lanes = lanes;
node->span = std::move(span);
data_ = std::move(node);
}
Expand All @@ -481,15 +483,17 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) {
int lanes = static_cast<int>(lanes_int->value);
ICHECK_GT(lanes, 1);
node->dtype = value.dtype().with_lanes(lanes);
// Stick to int32 lanes for fixed length vectors
node->lanes = lanes;
} else { /* scalable vector */
std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes;

node->dtype = value.dtype().with_scalable_vscale_factor(vscale_factor.value());
lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value());
node->lanes = lanes;
}
node->value = std::move(value);
node->lanes = lanes;
node->span = std::move(span);
data_ = node;
}
Expand Down
15 changes: 15 additions & 0 deletions tests/python/arith/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_simplify(self, test_case):

class TestVector(BaseCompare):
x, y, z = te.var("x"), te.var("y"), te.var("z")
x64 = te.var("x", dtype="int64")
vx = te.var("vx", dtype="int32x2")
vc = te.var("vc", dtype="uint1")
test_case = tvm.testing.parameter(
Expand All @@ -88,6 +89,20 @@ class TestVector(BaseCompare):
),
TestCase(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2")),
TestCase(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4)),
# int64 lanes
TestCase(
tvm.tir.Broadcast(x, 4) + tvm.tir.Ramp(0, 1, tvm.tir.IntImm(dtype="int64", value=4)),
tvm.tir.Ramp(x, 1, 4),
),
TestCase(
tvm.tir.Broadcast(x, tvm.tir.IntImm(dtype="int64", value=4)) + tvm.tir.Ramp(0, 1, 4),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually meant the case where the value being broadcast and the ramp base are i64, but the lanes remain i32; this would be a more common case in our setting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay, sorry for the confusion. I'll update the test cases to check that as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen I've updated the test case, let me know if this looks good.

tvm.tir.Ramp(x, 1, 4),
),
# int64 iterators with int32 lanes
TestCase(
tvm.tir.Broadcast(x64, 4) + tvm.tir.Ramp(tvm.tir.IntImm(dtype="int64", value=0), 1, 4),
tvm.tir.Ramp(x64, 1, 4),
),
TestCase(
tvm.tir.Broadcast(0, tir.vscale() * 8) + y, tvm.tir.Broadcast(y, tir.vscale() * 8)
),
Expand Down
10 changes: 10 additions & 0 deletions tests/python/tir-base/test_tir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,16 @@ def _create_broadcast(lanes):
return tvm.tir.Broadcast(0, lanes)


@pytest.mark.parametrize("lanes", [(tvm.tir.IntImm(dtype="int64", value=11))])
@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
def test_lane_types(lanes, node_func):
    """Lanes passed as an int64 IntImm must be stored as int32 with the same value."""
    # Build the node (Ramp or Broadcast, depending on the parametrization)
    # and inspect the lanes expression it ends up holding.
    stored_lanes = node_func(lanes).lanes

    assert stored_lanes.dtype == "int32"
    assert stored_lanes == 11


@pytest.mark.parametrize("lanes", [(11 * tvm.tir.vscale()), (tvm.tir.vscale() * 11)])
@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
def test_scalable_vec(lanes, node_func):
Expand Down