@@ -139,6 +139,7 @@ def __init__(
         accum_dtype: str = "float16",
         a_transposed: bool = False,
         b_transposed: bool = False,
+        e_transposed: bool = False,
         block_row_warps: int = 2,
         block_col_warps: int = 2,
         warp_row_tiles: int = 8,
@@ -155,6 +156,7 @@ def __init__(
         self.accum_dtype = accum_dtype
         self.a_transposed = a_transposed
         self.b_transposed = b_transposed
+        self.e_transposed = e_transposed
         # Hint Information
         self.block_row_warps = block_row_warps
         self.block_col_warps = block_col_warps
@@ -362,6 +364,7 @@ def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer, ki: PrimExpr, rk
         local_size_e = self.local_size_e
         a_dtype = self.a_dtype
         e_dtype = self.e_dtype
+        trans = self.e_transposed
         # ldmatrix cannot be used for int8 + trans case.
         # include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h
         ldmatrix_available = False  # TODO: use ldmatrix when possible
@@ -413,7 +416,7 @@ def _warp_ldmatrix_e(
                     rk * warp_k + ki * micro_size_k) // self.e_factor
                 for j in T.serial(local_size_e):
                     mi, mk = mma_load_layout(tx, j)
-                    E_local_buf[i * local_size_e + j] = E_shared_buf[wi + mi, wk + mk]
+                    E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk]

        return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk)
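A quick note on what the new flag does: `e_transposed` only changes the index order used when copying the sparse-metadata (E) tile from shared memory into the per-thread local buffer. The sketch below is a minimal, stand-alone illustration in plain NumPy with hypothetical names (`load_e_element` is not part of the emitter's API); it shows the same row/column swap that the changed assignment performs.

```python
import numpy as np

def load_e_element(E_shared, wi, mi, wk, mk, e_transposed=False):
    # Illustrative only: mirrors the indexing change in the diff above.
    # With e_transposed=True the metadata tile is assumed to be stored
    # with K as the leading dimension, so row/column indices swap.
    if e_transposed:
        return E_shared[wk + mk, wi + mi]
    return E_shared[wi + mi, wk + mk]

# Tiny check with a dummy 4x4 metadata tile.
E = np.arange(16).reshape(4, 4)
assert load_e_element(E, 1, 0, 2, 1) == E[1, 3]
assert load_e_element(E, 1, 0, 2, 1, e_transposed=True) == E[3, 1]
```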