@@ -74,15 +74,13 @@ def neon_4x4_i8i8i32_impl(
7474
7575 multiply_low = T .call_llvm_pure_intrin (
7676 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.smull.v8i16" ),
77- T .uint32 (2 ),
7877 vec_a ,
7978 vec_b_low ,
8079 dtype = "int16x8" ,
8180 )
8281
8382 pairwise_reduction_low = T .call_llvm_pure_intrin (
8483 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.saddlp.v4i32.v8i16" ),
85- T .uint32 (1 ),
8684 multiply_low ,
8785 dtype = "int32x4" ,
8886 )
@@ -91,22 +89,19 @@ def neon_4x4_i8i8i32_impl(
9189
9290 multiply_high = T .call_llvm_pure_intrin (
9391 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.smull.v8i16" ),
94- T .uint32 (2 ),
9592 vec_a ,
9693 vec_b_high ,
9794 dtype = "int16x8" ,
9895 )
9996
10097 pairwise_reduction_high = T .call_llvm_pure_intrin (
10198 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.saddlp.v4i32.v8i16" ),
102- T .uint32 (1 ),
10399 multiply_high ,
104100 dtype = "int32x4" ,
105101 )
106102
107103 C [T .ramp (T .int32 (0 ), 1 , 4 )] += T .call_llvm_pure_intrin (
108104 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.addp.v4i32" ),
109- T .uint32 (2 ),
110105 pairwise_reduction_low ,
111106 pairwise_reduction_high ,
112107 dtype = "int32x4" ,
@@ -159,7 +154,6 @@ def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
159154
160155 C [T .ramp (T .int32 (0 ), 1 , 4 )] = T .call_llvm_pure_intrin (
161156 T .llvm_lookup_intrinsic_id (f"llvm.aarch64.neon.{ instr } " ),
162- T .uint32 (3 ),
163157 vec_c ,
164158 vec_a ,
165159 vec_b ,
@@ -311,7 +305,6 @@ def impl():
311305 T .call_llvm_intrin (
312306 "void" ,
313307 "llvm.aarch64.sme.ld1w.horiz" ,
314- T .uint32 (4 ),
315308 predicate ,
316309 input_ptr ,
317310 sub_tile ,
@@ -335,7 +328,6 @@ def impl():
335328 T .call_llvm_intrin (
336329 "void" ,
337330 "llvm.aarch64.sme.st1w.vert" ,
338- T .uint32 (4 ),
339331 predicate ,
340332 output_ptr ,
341333 sub_tile ,
@@ -438,7 +430,6 @@ def impl():
438430 T .call_llvm_intrin (
439431 "void" ,
440432 "llvm.aarch64.sme.ld1h.horiz" ,
441- T .uint32 (4 ),
442433 ptrue_fp16 ,
443434 input_ptr ,
444435 sub_tile_idx ,
@@ -450,7 +441,6 @@ def impl():
450441 T .call_llvm_intrin (
451442 "void" ,
452443 "llvm.aarch64.sme.ld1h.horiz" ,
453- T .uint32 (4 ),
454444 ptrue_fp16 ,
455445 input_ptr ,
456446 sub_tile_idx ,
@@ -467,7 +457,6 @@ def impl():
467457 T .call_llvm_intrin (
468458 "void" ,
469459 "llvm.aarch64.sme.st1w.vert" ,
470- T .uint32 (4 ),
471460 ptrue_fp32 ,
472461 output_ptr ,
473462 sub_tile_idx ,
@@ -479,7 +468,6 @@ def impl():
479468 T .call_llvm_intrin (
480469 "void" ,
481470 "llvm.aarch64.sme.st1w.vert" ,
482- T .uint32 (4 ),
483471 ptrue_fp32 ,
484472 output_ptr ,
485473 sub_tile_idx + 2 ,
@@ -692,7 +680,6 @@ def impl():
692680 T .call_llvm_intrin (
693681 "void" ,
694682 fmopa_intrin ,
695- T .uint32 (5 ),
696683 sub_tile ,
697684 input_1 [1 ],
698685 input_2 [1 ],
@@ -713,7 +700,6 @@ def impl():
713700 T .call_llvm_intrin (
714701 "void" ,
715702 "llvm.aarch64.sme.st1w.horiz" ,
716- T .uint32 (4 ),
717703 _create_active_lane_mask (
718704 C , (vert_offset + slice_idx , horiz_offset ), M
719705 ),
@@ -752,9 +738,7 @@ def impl(c: T.handle) -> None:
752738 T .reads ()
753739 T .writes (C [0 :SVF2 , 0 :SVF2 ])
754740 clear_all_tiles = T .int32 (255 )
755- T .evaluate (
756- T .call_llvm_intrin ("void" , "llvm.aarch64.sme.zero" , T .uint32 (1 ), clear_all_tiles )
757- )
741+ T .evaluate (T .call_llvm_intrin ("void" , "llvm.aarch64.sme.zero" , clear_all_tiles ))
758742
759743 return desc , impl
760744
0 commit comments