@@ -136,6 +136,7 @@ def get_tile_descriptions(math_inst):
136136 TileDescription ([128 , 256 , 32 ], 3 , [2 , 4 , 1 ], math_inst , min_cc , max_cc ),
137137 TileDescription ([256 , 64 , 32 ], 4 , [4 , 1 , 1 ], math_inst , min_cc , max_cc ),
138138 TileDescription ([64 , 256 , 32 ], 4 , [1 , 4 , 1 ], math_inst , min_cc , max_cc ),
139+ TileDescription ([128 , 128 , 32 ], 2 , [2 , 2 , 1 ], math_inst , min_cc , max_cc ),
139140 TileDescription ([128 , 128 , 32 ], 3 , [2 , 2 , 1 ], math_inst , min_cc , max_cc ),
140141 TileDescription ([128 , 128 , 32 ], 4 , [2 , 2 , 1 ], math_inst , min_cc , max_cc ),
141142 TileDescription ([128 , 128 , 32 ], 5 , [2 , 2 , 1 ], math_inst , min_cc , max_cc ),
@@ -152,9 +153,11 @@ def get_tile_descriptions(math_inst):
152153 TileDescription ([64 , 64 , 64 ], 5 , [2 , 2 , 1 ], math_inst , min_cc , max_cc ),
153154 ]
154155
155- return generate_tensor_op_common (
156+ sm75_kernels = generate_sm75_tensor_op_1688 (out_dtype , op_creator )
157+ sm80_kernels = generate_tensor_op_common (
156158 math_instructions , alignment_constraints , get_tile_descriptions , op_creator
157159 )
160+ return sm75_kernels + sm80_kernels
158161
159162
160163class ProfilerEngine :
0 commit comments