Skip to content

Commit d8795a0

Browse files
committed
Address comments
- Correct doc strings - Correct typo in error message - Add some additional checks for BufferLoad Change-Id: Ie25563d569c0ed729ac915a6ba3a724a9e191014
1 parent 9191ecd commit d8795a0

File tree

7 files changed

+45
-10
lines changed

7 files changed

+45
-10
lines changed

include/tvm/tir/buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ class Buffer : public ObjectRef {
210210
* \param begin The beginning index
211211
* \param dtype The data type to be loaded.
212212
* \param predicate A vector mask of boolean values indicating which lanes of a vector are to be
213-
* stored. The number lanes of the mask must be equal to the number of lanes in value.
213+
* loaded. The number of lanes of the mask must be equal to the number of lanes being loaded.
214214
*/
215215
TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype,
216216
Optional<PrimExpr> predicate = NullOpt) const;

python/tvm/tir/buffer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ def vload(self, begin, dtype=None, predicate=None):
115115
116116
predicate : Optional[PrimExpr]
117117
A vector mask of boolean values indicating which lanes of a vector are to be
118-
stored. The number lanes of the mask must be equal to the number of lanes in
119-
value.
118+
loaded. The number of lanes of the mask must be equal to the number of lanes being loaded.
120119
121120
Returns
122121
-------

python/tvm/tir/expr.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,8 +1100,7 @@ class BufferLoad(PrimExprWithOp):
11001100
11011101
predicate : Optional[PrimExpr]
11021102
A vector mask of boolean values indicating which lanes of a vector are to be
1103-
stored. The number lanes of the mask must be equal to the number of lanes in
1104-
value.
1103+
loaded. The number of lanes of the mask must be equal to the number of lanes being loaded.
11051104
"""
11061105

11071106
buffer: Buffer

src/target/llvm/codegen_llvm.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
331331
* \param indices The indices at which the buffer is being accessed.
332332
*
333333
* \param predicate A vector mask of boolean values indicating which lanes of a
334-
* vector are to be stored. The number lanes of the mask must be equal to the
335-
* number of lanes in value.
334+
* vector are to be accessed. The number of lanes of the mask must be equal to the
335+
* number of lanes being accessed.
336336
*
337337
* \param value_dtype The datatype to be read from (BufferLoad) or
338338
* written to (BufferStore) the buffer.

src/tir/ir/expr.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,21 @@ BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices, Optional<PrimExpr
780780
<< "-dimensional indices provided.";
781781

782782
if (predicate.defined()) {
783-
DataType predicate_element_dtype = predicate.value().dtype().element_of();
783+
DataType predicate_dtype = predicate.value().dtype();
784+
785+
bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector();
786+
bool is_predicate_scalable = predicate_dtype.is_scalable_vector();
787+
ICHECK_EQ(is_index_scalable, is_predicate_scalable)
788+
<< "Predicate mask dtype and load indices must both be scalable.";
789+
790+
int index_lanes = indices.empty() ? 1 : indices.back().dtype().get_lanes_or_vscale_factor();
791+
int predicate_lanes = predicate_dtype.get_lanes_or_vscale_factor();
792+
ICHECK_EQ(index_lanes, predicate_lanes)
793+
<< "Got a predicate mask with " << predicate_lanes
794+
<< " lanes, but trying to load a vector with " << index_lanes
795+
<< " lanes. The number of lanes must match.";
796+
797+
DataType predicate_element_dtype = predicate_dtype.element_of();
784798
ICHECK(predicate_element_dtype.is_bool())
785799
<< "Predicate mask elements must be boolean values, but got " << predicate_element_dtype
786800
<< ".";

src/tir/transforms/inject_rolling_buffer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,8 @@ class RollingBufferInjector : public StmtExprMutator {
257257
indices.push_back(index);
258258
}
259259
}
260-
ICHECK(!op->predicate.defined())
261-
<< "Predicated buffer store is not current supported in the inject rolling buffer pass.";
260+
ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in "
261+
"the inject rolling buffer pass.";
262262
Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->predicate, op->span);
263263
// Then wrap the BufferStores in some Ifs to avoid recomputing elements
264264
for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) {

tests/python/tir-base/test_tir_nodes.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,29 @@ def test_buffer_load_predicate_elements_invalid_type():
514514
tvm.tir.BufferLoad(b, [index], predicate)
515515

516516

517+
def test_buffer_store_predicate_invalid_scalability():
518+
b = tvm.tir.decl_buffer((24,), "int32")
519+
index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale())
520+
predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 4)
521+
522+
err_msg = "Predicate mask dtype and load indices must both be scalable."
523+
with pytest.raises(tvm.TVMError, match=err_msg):
524+
tvm.tir.BufferLoad(b, [index], predicate)
525+
526+
527+
def test_buffer_store_predicate_invalid_lanes():
528+
b = tvm.tir.decl_buffer((24,), "int32")
529+
index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale())
530+
predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 8 * tvm.tir.vscale())
531+
532+
err_msg = (
533+
"Got a predicate mask with 8 lanes, but trying to load a "
534+
"vector with 4 lanes. The number of lanes must match."
535+
)
536+
with pytest.raises(tvm.TVMError, match=err_msg):
537+
tvm.tir.BufferLoad(b, [index], predicate)
538+
539+
517540
def test_scalable_vec_cast():
518541
b = tvm.tir.decl_buffer((24,), "float32")
519542
value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32xvscalex12")

0 commit comments

Comments
 (0)