@@ -336,7 +336,7 @@ def silu_and_mul_masked_post_quant_fwd(
     scale_k = ceil_div(k, quant_group_size)
     m_padded = align(m, alignment)
     scale_k_padded = align(scale_k, alignment)
-    output_scale = torch.zeros((g, scale_k_padded // 4, m_padded),
+    output_scale = torch.empty((g, scale_k_padded // 4, m_padded),
                                dtype=torch.int32,
                                device='cuda')

@@ -458,6 +458,7 @@ def per_token_quant_and_transform(
     input: torch.Tensor,
     quant_group_size: int = 128,
     scale_ue8m0: bool = True,
+    swap_ab: bool = False,
 ):
     """
     input shape [g, m, k]
@@ -477,18 +478,21 @@ def per_token_quant_and_transform(
     fp8_min = -fp8_max

     m, k = input.shape
+    m_padded = m if not swap_ab else align(m, 8)

     # Create output
-    output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda")
+    output = torch.empty((m_padded, k),
+                         dtype=torch.float8_e4m3fn,
+                         device=input.device)

     # Create output scale
     alignment = 4
     scale_k = ceil_div(k, quant_group_size)
-    m_padded = align(m, alignment)
+    m_aligned = align(m_padded, alignment)
     scale_k_padded = align(scale_k, alignment)
-    output_scale = torch.zeros((scale_k_padded // 4, m_padded),
+    output_scale = torch.empty((scale_k_padded // 4, m_aligned),
                                dtype=torch.int32,
-                               device='cuda')
+                               device=input.device)

     # Get block/grid/stage/warp
     BLOCK_NUM_PER_EXPERT = 64
@@ -518,13 +522,56 @@ def per_token_quant_and_transform(
         num_warps=num_warps,
         SCALE_UE8M0=scale_ue8m0,
     )
-    output_scale = output_scale.transpose(0, 1)[:m, :]
+    output_scale = output_scale.transpose(0, 1)[:m_padded, :]
     check_sf_layout(
         output_scale,
-        m,
+        m_padded,
         k,
         (1, 128),
         num_groups=None,
         tma_stride_check=True,
     )
     return output, output_scale
+
+
+@triton.jit
+def _transpose_kernel(input_ptr, output_ptr, M, N, stride_in_m, stride_in_n,
+                      stride_out_m, stride_out_n, BLOCK_SIZE: tl.constexpr):
+    row_block = tl.program_id(0)
+    col_block = tl.program_id(1)
+
+    row = row_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    col = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+
+    mask_row = row < M
+    mask_col = col < N
+    mask = mask_row[:, None] & mask_col[None, :]
+
+    input_idx = row[:, None] * stride_in_m + col[None, :] * stride_in_n
+    data = tl.load(input_ptr + input_idx, mask=mask, other=0)
+
+    output_idx = row[:, None] * stride_out_n + col[None, :] * stride_out_m
+    tl.store(output_ptr + output_idx, data, mask=mask)
+
+
+def masked_transpose(input: torch.Tensor, n_available: int) -> torch.Tensor:
+    M, N = input.shape
+    BLOCK_SIZE = 32
+    output = torch.empty((n_available, M),
+                         dtype=input.dtype,
+                         device=input.device)
+
+    grid = ((M + BLOCK_SIZE - 1) // BLOCK_SIZE,
+            (n_available + BLOCK_SIZE - 1) // BLOCK_SIZE)
+    _transpose_kernel[grid](
+        input,
+        output,
+        M,
+        n_available,
+        input.stride(0),
+        input.stride(1),
+        output.stride(0),
+        output.stride(1),
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return output
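
For reference, a minimal sanity check of the new masked_transpose helper could look like the sketch below. This is illustrative only: it assumes a CUDA device and an arbitrary float16 input, and the names x / n_avail are made up. It just confirms that the helper returns the transpose of the first n_available columns of a 2-D tensor.

# Illustrative sanity check (not part of the change); assumes a CUDA device
# and that masked_transpose above is importable from this module.
import torch

x = torch.randn(10, 64, device="cuda", dtype=torch.float16)  # shape (M, N)
n_avail = 40  # only the first 40 columns are considered valid

y = masked_transpose(x, n_avail)  # expected shape: (n_avail, M)
assert y.shape == (n_avail, 10)

# The kernel writes output[col, row] = input[row, col] for col < n_avail,
# so the result should match a plain slice-then-transpose.
torch.testing.assert_close(y, x[:, :n_avail].t())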