@@ -514,6 +514,29 @@ def test_buffer_load_predicate_elements_invalid_type():
514514 tvm .tir .BufferLoad (b , [index ], predicate )
515515
516516
517+ def test_buffer_store_predicate_invalid_scalability ():
518+ b = tvm .tir .decl_buffer ((24 ,), "int32" )
519+ index = tvm .tir .expr .Ramp (0 , 1 , 4 * tvm .tir .vscale ())
520+ predicate = tvm .tir .expr .Broadcast (tvm .tir .IntImm ("int1" , 1 ), 4 )
521+
522+ err_msg = "Predicate mask dtype and load indices must both be scalable."
523+ with pytest .raises (tvm .TVMError , match = err_msg ):
524+ tvm .tir .BufferLoad (b , [index ], predicate )
525+
526+
527+ def test_buffer_store_predicate_invalid_lanes ():
528+ b = tvm .tir .decl_buffer ((24 ,), "int32" )
529+ index = tvm .tir .expr .Ramp (0 , 1 , 4 * tvm .tir .vscale ())
530+ predicate = tvm .tir .expr .Broadcast (tvm .tir .IntImm ("int1" , 1 ), 8 * tvm .tir .vscale ())
531+
532+ err_msg = (
533+ "Got a predicate mask with 8 lanes, but trying to load a "
534+ "vector with 4 lanes. The number of lanes must match."
535+ )
536+ with pytest .raises (tvm .TVMError , match = err_msg ):
537+ tvm .tir .BufferLoad (b , [index ], predicate )
538+
539+
517540def test_scalable_vec_cast ():
518541 b = tvm .tir .decl_buffer ((24 ,), "float32" )
519542 value = tvm .tir .expr .Broadcast (1 , 12 * tvm .tir .vscale ()).astype ("float32xvscalex12" )
0 commit comments