diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py index 663b494a91..8387a1aabe 100644 --- a/flashinfer/cute_dsl/blockscaled_gemm.py +++ b/flashinfer/cute_dsl/blockscaled_gemm.py @@ -44,21 +44,127 @@ from cutlass.cutlass_dsl import ( Int32, + Int64, + Uint32, + Uint8, + Uint64, + T, Integer, dsl_user_op, extract_mlir_values, new_from_mlir_values, ) +from cutlass._mlir.dialects import llvm from flashinfer.utils import get_compute_capability from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo +from flashinfer.utils import ceil_div from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr from typing import Callable, List +# DEBUG_EXIT_AFTER_FIRST_WAIT = True +DEBUG_EXIT_AFTER_FIRST_WAIT = False + + +sizeof_i32 = 4 +sizeof_u64 = 8 + +@dsl_user_op +def with_byte(obj: Uint64, index: Int32, value: Uint8, *, loc=None, ip=None) -> Uint64: + # assert index >= 0 and index < sizeof_u64 + obj &= ~(0xff << (index * 8)) + obj |= value << (index * 8) + assert isinstance(obj, Uint64), f"{obj=}" + return obj + + +@dsl_user_op +def read_byte(obj: Uint64, index: Int32, *, loc=None, ip=None) -> Uint8: + # assert index >= 0 and index < sizeof_u64 + return ((obj >> (index * 8)) & 0xFF).to(Uint8) + + +# TODO unify i32 (here) and the signal buffer (u32) +@dsl_user_op +def atomic_add_release_global(addr: Int64, value: Int32, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [ + addr.ir_value(loc=loc, ip=ip), + Int32(value).ir_value(loc=loc, ip=ip), + ], + "atom.add.release.gpu.global.s32 $0, [$1], $2;", + "=r,l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + +@dsl_user_op +def write_signal( + tile_sched_params: "MaskedSchedulerParams", + group_idx_and_m_block_idx, + *, loc=None, ip=None, +): + dst_signals = tile_sched_params.dst_signals + group_idx, m_block_idx = group_idx_and_m_block_idx + shape_m = tile_sched_params.c.shape[0] + block_m = tile_sched_params.c_tiler[0] + + offset = group_idx * ceil_div(shape_m, block_m) + m_block_idx + + atomic_add_release_global(dst_signals.toint() + sizeof_i32 * offset, value=1) + +# TODO unify i32 or u32 +# TODO only wait once per warp? +@cute.jit +def wait_signal(addr: Int64, expect_value: int, *, loc=None, ip=None): + # # TODO disable this time check + # repeat_count = Int64(0) + + ready = Int32(0) + + # early exiting / early return is not supported in cute dsl + while ready != expect_value: + ready = Int32( + llvm.inline_asm( + T.i32(), + [addr.ir_value(loc=loc, ip=ip)], + # TODO how to add `:"memory"` clobber? 
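+                # (Untested assumption: LLVM inline asm spells clobbers inside the
+                # constraint string, e.g. "=r,l,~{memory}", rather than as a separate
+                # GCC-style `:"memory"` clobber list.)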
+ "ld.acquire.gpu.global.s32 $0, [$1];", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + llvm.inline_asm( + None, + [], + "nanosleep.u32 20;", + "", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + # repeat_count += 1 + # if repeat_count % 1_000_000_000 == 0: + # tidx, _, _ = cute.arch.thread_idx() + # if tidx % 32 == 0: + # cute.printf("wait_signal STUCK addr={} tidx={} actual_value={}", addr, tidx, ready) + + class MaskedSchedulerParams: def __init__( self, masked_m: cute.Tensor, + src_signals: Optional[cute.Pointer], + src_signal_expect_value: int, + dst_signals: Optional[cute.Pointer], c: cute.Tensor, c_tiler: Tuple[int, int], cluster_shape_mnk: cute.Shape, @@ -72,6 +178,9 @@ def __init__( gc = cute.zipped_divide(c, tiler=c_tiler) problem_shape_ntile_mnl = gc[(0, (None, None, None))].shape self.masked_m = masked_m + self.src_signals = src_signals + self.src_signal_expect_value = src_signal_expect_value + self.dst_signals = dst_signals self.c = c self.c_tiler = c_tiler self.problem_shape_ntile_mnl = problem_shape_ntile_mnl @@ -92,6 +201,9 @@ def __extract_mlir_values__(self): values, self._values_pos = [], [] for obj in [ self.masked_m, + self.src_signals, + self.src_signal_expect_value, + self.dst_signals, self.c, self.c_tiler, self._cluster_shape_mnk, @@ -104,7 +216,7 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] for obj, n_items in zip( - [self.masked_m, self.c, self.c_tiler, self._cluster_shape_mnk], + [self.masked_m, self.src_signals, self.src_signal_expect_value, self.dst_signals, self.c, self.c_tiler, self._cluster_shape_mnk], self._values_pos, ): obj_list.append(new_from_mlir_values(obj, values[:n_items])) @@ -236,7 +348,10 @@ def get_grid_shape( def _get_current_work_for_linear_idx( self, current_work_linear_idx: Int32, - ) -> WorkTileInfo: + # dsm_pending_packed: Optional[Uint64], + # dsm_counter: Optional[Uint8], + # num_c_stage: Optional[int] = None, + ) -> WorkTileInfo: # Tuple[WorkTileInfo, Optional[Uint64]]: # is_valid = current_work_linear_idx < cute.size( # self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip # ) @@ -253,6 +368,16 @@ def _get_current_work_for_linear_idx( <= current_work_linear_idx and batch_idx < self.params.masked_m.shape[0] ): + # if cutlass.const_expr((dsm_pending_packed is not None) and (self.params.dst_signals is not None)): + # # TODO check off by one + # dsm_pending_packed = with_byte(dsm_pending_packed, index=batch_idx, value=dsm_counter + (num_c_stage - 1)) + if cutlass.const_expr(self.params.src_signals is not None): + if batch_idx < self.params.masked_m.shape[0] - 1: + wait_signal( + self.params.src_signals.toint() + sizeof_i32 * (batch_idx + 1), + expect_value=self.params.src_signal_expect_value, + ) + accum_tile_m += cute.ceil_div( self.params.masked_m[batch_idx], self.params.c_tiler[0] ) @@ -290,16 +415,27 @@ def _get_current_work_for_linear_idx( ) ) - return WorkTileInfo(cur_tile_coord, is_valid) + return WorkTileInfo(cur_tile_coord, is_valid) #, dsm_pending_packed @dsl_user_op - def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + def get_current_work( + self, + # dsm_pending_packed: Optional[Uint64] = None, + # dsm_counter: Optional[Uint8] = None, + # num_c_stage: Optional[int] = None, + *, loc=None, ip=None, + ) -> WorkTileInfo: # Tuple[WorkTileInfo, Optional[Uint64]]: return self._get_current_work_for_linear_idx( self._current_work_linear_idx, + # 
dsm_pending_packed=dsm_pending_packed, + # dsm_counter=dsm_counter, + # num_c_stage=num_c_stage, ) @dsl_user_op def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + # tile_info, _ = self.get_current_work(loc=loc, ip=ip) + # return tile_info return self.get_current_work(loc=loc, ip=ip) @dsl_user_op @@ -434,6 +570,7 @@ def __init__( mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], sm_version: str, + src_signal_expect_value: int, ): """Initializes the configuration for a Blackwell dense GEMM kernel. @@ -464,6 +601,7 @@ def __init__( self.cluster_shape_mn = cluster_shape_mn # K dimension is deferred in _setup_attributes self.mma_tiler = (*mma_tiler_mn, 1) + self.src_signal_expect_value = src_signal_expect_value self.cta_group = ( tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE @@ -641,6 +779,8 @@ def __call__( sfb_tensor: cute.Tensor, c_tensor: cute.Tensor, masked_m_tensor: cute.Tensor, + src_signals: Optional[cute.Pointer], + dst_signals: Optional[cute.Pointer], alpha_tensor: Optional[cute.Tensor], max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, @@ -804,6 +944,9 @@ def __call__( # Compute grid size self.tile_sched_params, grid = self._compute_grid( masked_m_tensor, # add masked layout + src_signals, + self.src_signal_expect_value, + dst_signals, c_tensor, self.cta_tile_shape_mnk, self.cluster_shape_mn, @@ -1177,6 +1320,16 @@ def kernel( barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta ) + # TODO may optimize if too slow + if cutlass.const_expr(tile_sched_params.src_signals is not None): + wait_signal( + tile_sched_params.src_signals.toint() + sizeof_i32 * 0, + expect_value=tile_sched_params.src_signal_expect_value, + ) + + if cutlass.const_expr(DEBUG_EXIT_AFTER_FIRST_WAIT): + return + # # Specialized TMA load warp # @@ -1281,6 +1434,7 @@ def kernel( # Advance to next tile # tile_sched.advance_to_next_work() + # work_tile, _ = tile_sched.get_current_work() work_tile = tile_sched.get_current_work() # @@ -1485,6 +1639,7 @@ def kernel( # Advance to next tile # tile_sched.advance_to_next_work() + # work_tile, _ = tile_sched.get_current_work() work_tile = tile_sched.get_current_work() # @@ -1568,6 +1723,15 @@ def kernel( producer_group=c_producer_group, ) + # if cutlass.const_expr(tile_sched_params.dst_signals is not None): + # assert self.num_c_stage < 256, "must be representable in 1 byte" + # num_experts = tile_sched_params.masked_m.shape[0] + # assert num_experts <= 8, "need to be packable into a u64" + # dsm_pending_packed = Uint64(0) + # dsm_pending_idx = Int32(0) + # dsm_counter = Uint8(0) + dsm_pending_group_idx_and_m_block_idx = (Int32(-1), Int32(-1)) + while work_tile.is_valid_tile: # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx @@ -1646,6 +1810,10 @@ def kernel( number_of_threads=epilog_threads, ) + if cutlass.const_expr(tile_sched_params.dst_signals is not None): + assert subtile_cnt >= self.num_c_stage - 1 + dsm_will_write_signals = (subtile_idx == self.num_c_stage - 2) and (dsm_pending_group_idx_and_m_block_idx[1] != -1) + # # TMA store C to global memory # @@ -1655,14 +1823,51 @@ def kernel( bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)], ) + # Fence and barrier to make sure shared memory store is visible to TMA store c_pipeline.producer_commit() - c_pipeline.producer_acquire() + + if cutlass.const_expr(tile_sched_params.dst_signals is not None): + # dsm_counter = (dsm_counter + 1).to(Uint8) + # dsm_will_write_signals = read_byte(dsm_pending_packed, 
dsm_pending_idx) == dsm_counter + + if dsm_will_write_signals: + # The original c_pipeline.producer_acquire() + # := PipelineTmaStore.producer_acquire() + # := TmaStoreFence.wait() + # := cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True) + cute.arch.cp_async_bulk_wait_group( + self.num_c_stage - 1, + # Change `read` from True to False to also wait writes + read=False, + ) + else: + c_pipeline.producer_acquire() + + else: + c_pipeline.producer_acquire() + cute.arch.barrier( barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads, ) + if cutlass.const_expr(tile_sched_params.dst_signals is not None): + lane_id = tidx % 32 + if warp_idx == self.epilog_warp_id[0] and lane_id == 0: + # while ( + # (dsm_pending_idx < num_experts) and + # (read_byte(dsm_pending_packed, dsm_pending_idx) == dsm_counter) + # ): + if dsm_will_write_signals: + # TODO unify w/ the other branch + # atomic_add_release_global( + # tile_sched_params.dst_signals.toint() + sizeof_i32 * dsm_pending_idx, + # value=1, + # ) + # dsm_pending_idx += 1 + write_signal(tile_sched_params, dsm_pending_group_idx_and_m_block_idx) + # # Async arrive accumulator buffer empty # @@ -1670,11 +1875,19 @@ def kernel( acc_pipeline.consumer_release(acc_consumer_state) acc_consumer_state.advance() + if cutlass.const_expr(tile_sched_params.dst_signals is not None): + dsm_pending_group_idx_and_m_block_idx = (work_tile.tile_idx[2], work_tile.tile_idx[0]) + # # Advance to next tile # tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() + # work_tile, dsm_pending_packed = tile_sched.get_current_work( + work_tile = tile_sched.get_current_work( + # dsm_pending_packed=dsm_pending_packed, + # dsm_counter=dsm_counter, + # num_c_stage=self.num_c_stage, + ) # # Dealloc the tensor memory buffer @@ -1697,7 +1910,31 @@ def kernel( # # Wait for C store complete # - c_pipeline.producer_tail() + if cutlass.const_expr(tile_sched_params.dst_signals is not None): + # The original c_pipeline.producer_tail() + # := PipelineTmaStore.producer_tail() + # := TmaStoreFence.tail() + # := cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.cp_async_bulk_wait_group( + 0, + # Change `read` from True to False to also wait writes + read=False, + ) + + lane_id = tidx % 32 + if warp_idx == self.epilog_warp_id[0] and lane_id == 0: + # while dsm_pending_idx < num_experts: + # # TODO unify w/ the other branch + # atomic_add_release_global( + # tile_sched_params.dst_signals.toint() + sizeof_i32 * dsm_pending_idx, + # value=1, + # ) + # dsm_pending_idx += 1 + if dsm_pending_group_idx_and_m_block_idx[1] != -1: + write_signal(tile_sched_params, dsm_pending_group_idx_and_m_block_idx) + + else: + c_pipeline.producer_tail() def mainloop_s2t_copy_and_partition( self, @@ -2009,6 +2246,9 @@ def _compute_stages( @staticmethod def _compute_grid( masked_m_tensor: cute.Tensor, + src_signals: Optional[cute.Pointer], + src_signals_expect_value: int, + dst_signals: Optional[cute.Pointer], c: cute.Tensor, cta_tile_shape_mnk: Tuple[int, int, int], cluster_shape_mn: Tuple[int, int], @@ -2034,7 +2274,7 @@ def _compute_grid( cluster_shape_mnl = (*cluster_shape_mn, 1) tile_sched_params = MaskedSchedulerParams( - masked_m_tensor, c, c_tiler, cluster_shape_mnl + masked_m_tensor, src_signals, src_signals_expect_value, dst_signals, c, c_tiler, cluster_shape_mnl ) grid = MaskedScheduler.get_grid_shape(tile_sched_params, max_active_clusters) @@ -2424,6 +2664,7 @@ def __init__( cluster_shape_mn: Tuple[int, int], sm_count: int, sm_version: str, + 
src_signal_expect_value: int, ): self._m = m self._n = n @@ -2439,6 +2680,7 @@ def __init__( self._sf_vec_size = sf_vec_size self._mma_tiler_mn = mma_tiler_mn self._cluster_shape_mn = cluster_shape_mn + self._src_signal_expect_value = src_signal_expect_value if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( ab_dtype, @@ -2478,6 +2720,8 @@ def __call__( sfb_ptr: cute.Pointer, c_ptr: cute.Pointer, masked_m_ptr: cute.Pointer, + src_signals_ptr: Optional[cute.Pointer], + dst_signals_ptr: Optional[cute.Pointer], alpha_ptr: cute.Pointer, current_stream: cuda.CUstream, ): @@ -2566,6 +2810,7 @@ def ceil_div(a, b): mma_tiler_mn=self._mma_tiler_mn, cluster_shape_mn=self._cluster_shape_mn, sm_version=self._sm_version, + src_signal_expect_value=self._src_signal_expect_value, )( a_tensor, b_tensor, @@ -2573,6 +2818,8 @@ def ceil_div(a, b): sfb_tensor, c_tensor, masked_m_tensor, + src_signals_ptr, + dst_signals_ptr, alpha_tensor, self._max_active_clusters, current_stream, @@ -2597,6 +2844,9 @@ def get_cute_dsl_compiled_masked_gemm_kernel( cluster_shape_mn: Tuple[int, int], sm_count: int, sm_version: str, + src_signal_expect_value: int, + enable_src_signals: bool, + enable_dst_signals: bool, ) -> Callable: def get_cute_pointers( input_tensors: Optional[List[torch.tensor]], @@ -2609,8 +2859,16 @@ def get_cute_pointers( sfb_data_ptr, c_data_ptr, masked_m_data_ptr, + src_signals_data_ptr, + dst_signals_data_ptr, alpha_data_ptr, - ) = [16 for _ in range(7)] + ) = [16 for _ in range(9)] + + if not enable_src_signals: + src_signals_data_ptr = None + if not enable_dst_signals: + dst_signals_data_ptr = None + else: ( a_tensor_gpu, @@ -2619,8 +2877,14 @@ def get_cute_pointers( sfb_tensor_gpu, c_tensor_gpu, masked_m_tensor_gpu, + src_signals_tensor_gpu, + dst_signals_tensor_gpu, alpha_tensor_gpu, ) = input_tensors + + assert enable_src_signals == (src_signals_tensor_gpu is not None) + assert enable_dst_signals == (dst_signals_tensor_gpu is not None) + ( a_data_ptr, b_data_ptr, @@ -2628,6 +2892,8 @@ def get_cute_pointers( sfb_data_ptr, c_data_ptr, masked_m_data_ptr, + src_signals_data_ptr, + dst_signals_data_ptr, alpha_data_ptr, ) = ( a_tensor_gpu.data_ptr(), @@ -2636,6 +2902,8 @@ def get_cute_pointers( sfb_tensor_gpu.data_ptr(), c_tensor_gpu.data_ptr(), masked_m_tensor_gpu.data_ptr(), + src_signals_tensor_gpu.data_ptr() if src_signals_tensor_gpu is not None else None, + dst_signals_tensor_gpu.data_ptr() if dst_signals_tensor_gpu is not None else None, alpha_tensor_gpu.data_ptr() if alpha_tensor_gpu is not None else None, ) @@ -2675,6 +2943,26 @@ def get_cute_pointers( cute.AddressSpace.gmem, assumed_align=16, ) + src_signals_ptr = ( + make_ptr( + cutlass.Uint32, + src_signals_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + if src_signals_data_ptr is not None + else None + ) + dst_signals_ptr = ( + make_ptr( + cutlass.Uint32, + dst_signals_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + if dst_signals_data_ptr is not None + else None + ) alpha_ptr = ( make_ptr( alpha_dtype, @@ -2686,7 +2974,7 @@ def get_cute_pointers( else None ) - return [a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, masked_m_ptr, alpha_ptr] + return [a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, masked_m_ptr, src_signals_ptr, dst_signals_ptr, alpha_ptr] kernel = cute.compile( MaskedBatchedMatmulCuteDSL( @@ -2706,6 +2994,7 @@ def get_cute_pointers( cluster_shape_mn=cluster_shape_mn, sm_count=sm_count, sm_version=sm_version, + src_signal_expect_value=src_signal_expect_value, ), *get_cute_pointers(None), 
cutlass_torch.current_stream(), @@ -2717,6 +3006,8 @@ def tensor_api( sfa_tensor_gpu: torch.Tensor, sfb_tensor_gpu: torch.Tensor, masked_m_tensor_gpu: torch.Tensor, + src_signals_tensor_gpu: torch.Tensor, + dst_signals_tensor_gpu: torch.Tensor, c_tensor_gpu: Optional[torch.Tensor] = None, alpha_tensor_gpu: Optional[torch.Tensor] = None, ): @@ -2741,6 +3032,8 @@ def tensor_api( sfb_tensor_gpu, c_tensor_gpu, masked_m_tensor_gpu, + src_signals_tensor_gpu, + dst_signals_tensor_gpu, alpha_tensor_gpu, ] ), @@ -2762,6 +3055,9 @@ def grouped_gemm_nt_masked( sf_dtype: str, c_dtype: str, sf_vec_size: int, + src_signals: Optional[torch.Tensor] = None, + src_signal_expect_value: int = 0, + dst_signals: Optional[torch.Tensor] = None, sm_count: Optional[int] = None, **kwargs, ): @@ -2825,6 +3121,11 @@ def grouped_gemm_nt_masked( if major == 11 and minor == 0: raise ValueError("SM110 is not supported for cute-dsl backend.") + if dst_signals is not None: + assert dst_signals.dtype == torch.int32 + assert dst_signals.shape == (l, ceil_div(m, mma_tiler_mn[0])), f"{dst_signals.shape=} {l=} {m=} {mma_tiler_mn=}" + assert dst_signals.is_contiguous() + return get_cute_dsl_compiled_masked_gemm_kernel( m=m, n=n, @@ -2842,6 +3143,9 @@ def grouped_gemm_nt_masked( cluster_shape_mn=cluster_shape_mn, sm_count=sm_count, sm_version=f"sm_{major}{minor}", + src_signal_expect_value=src_signal_expect_value, + enable_src_signals=src_signals is not None, + enable_dst_signals=dst_signals is not None, )( a_tensor_gpu=a_torch, b_tensor_gpu=b_torch, @@ -2849,5 +3153,7 @@ def grouped_gemm_nt_masked( sfb_tensor_gpu=sfb_torch, c_tensor_gpu=c_torch, masked_m_tensor_gpu=masked_m, + src_signals_tensor_gpu=src_signals, + dst_signals_tensor_gpu=dst_signals, alpha_tensor_gpu=alpha, ) diff --git a/tests/test_cute_dsl_blockscaled_gemm.py b/tests/test_cute_dsl_blockscaled_gemm.py index 8b5b69a46f..c2d07c96e0 100644 --- a/tests/test_cute_dsl_blockscaled_gemm.py +++ b/tests/test_cute_dsl_blockscaled_gemm.py @@ -21,6 +21,7 @@ get_cutlass_dtype, is_cute_dsl_available, ) +from flashinfer.utils import ceil_div @pytest.mark.skipif( @@ -59,6 +60,9 @@ @pytest.mark.parametrize("sm_count", [132, None]) @pytest.mark.parametrize("tolerance", [1e-01]) @pytest.mark.parametrize("iterations", [3]) +# TODO enable in tests +@pytest.mark.parametrize("enable_src_signals", [True]) +@pytest.mark.parametrize("enable_dst_signals", [False, True]) def test_blockscaled_gemm_python_interface( lm: Tuple[int, int], kn: Tuple[int, int], @@ -76,6 +80,8 @@ def test_blockscaled_gemm_python_interface( sm_count: int, tolerance: float, iterations: int, + enable_src_signals: int, + enable_dst_signals: int, ): torch.manual_seed(42) device = torch.device("cuda:0") @@ -107,6 +113,8 @@ def test_blockscaled_gemm_python_interface( pytest.skip( f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" ) + if enable_dst_signals and sf_vec_size != 16: + pytest.skip(f"Unsupported testcase {enable_dst_signals=} {sf_vec_size=}") if not (a_major == "k" and b_major == "k" and c_major == "n"): # not supported since we try to align deepgemm for now @@ -175,6 +183,9 @@ def test_blockscaled_gemm_python_interface( masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device) for _ in range(iterations): + # dst_signals = torch.zeros((l,), dtype=torch.uint32, device="cuda") if enable_dst_signals else None + dst_signals = torch.zeros((l, ceil_div(m, mma_tiler_mn[0])), 
dtype=torch.int32, device="cuda") if enable_dst_signals else None + # deepgemm-like python interface: fp4 packed, for DLFW integration grouped_gemm_nt_masked( (a_torch, sfa_torch), @@ -190,8 +201,17 @@ def test_blockscaled_gemm_python_interface( alpha=alpha_tensor, alpha_dtype=alpha_dtype, sm_count=sm_count, + src_signals=torch.ones((l,), dtype=torch.uint32, device="cuda") if enable_src_signals else None, + src_signal_expect_value=1 if enable_src_signals else 0, + dst_signals=dst_signals, ) + if enable_dst_signals: + expect_dst_signals = torch.zeros_like(dst_signals) + for expert_idx in range(l): + expect_dst_signals[expert_idx, :ceil_div(masked_m_tensor[expert_idx], mma_tiler_mn[0])] = ceil_div(n, mma_tiler_mn[1]) + assert torch.all(dst_signals == expect_dst_signals), f"{dst_signals=} {expect_dst_signals=} {masked_m_tensor=}" + # compute ref output if not fuse_alpha: alpha_tensor = torch.ones(l, dtype=torch.float32, device=device) @@ -244,8 +264,10 @@ def test_blockscaled_gemm_python_interface( if __name__ == "__main__": test_blockscaled_gemm_python_interface( - lm=(1, 1024), - kn=(7168, 4096), + # lm=(1, 1024), + # kn=(7168, 4096), + lm=(6, 1024), + kn=(2048, 7168), ab_dtype="float4_e2m1fn", sf_dtype="float8_e8m0fnu", sf_vec_size=16, @@ -260,4 +282,29 @@ def test_blockscaled_gemm_python_interface( tolerance=1e-01, iterations=3, sm_count=132, + enable_src_signals=False, + enable_dst_signals=True, ) + + # TODO can use this to reproduce illegal mem access + # test_blockscaled_gemm_python_interface( + # # TODO real value in masked_m is smaller than this m shape + # lm=(6, 36864), + # kn=(7168, 4096), + # ab_dtype="float4_e2m1fn", + # sf_dtype="float8_e8m0fnu", + # sf_vec_size=16, + # c_dtype="float16", + # a_major="k", + # b_major="k", + # c_major="n", + # fuse_alpha=False, + # alpha_dtype="float32", + # mma_tiler_mn=(128, 128), + # cluster_shape_mn=(2, 1), + # tolerance=1e-01, + # iterations=3, + # sm_count=132, + # enable_src_signals=False, + # enable_dst_signals=False, + # )
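
For reference, a minimal host-side sketch of the signal contract encoded by the test above (hypothetical sizes; block sizes follow mma_tiler_mn = (128, 128); assumes the epilogue bumps one dst counter per finished C tile, as the diff's write_signal does):

import torch


def ceil_div(a: int, b: int) -> int:
    # Same helper as flashinfer.utils.ceil_div, inlined to keep the sketch standalone.
    return (a + b - 1) // b


# Hypothetical problem sizes.
l, m, n = 4, 1024, 4096
block_m, block_n = 128, 128
masked_m = torch.tensor([300, 0, 1024, 17], dtype=torch.int32)

# src_signals: one slot per expert. A producer sets slot g to
# src_signal_expect_value once expert g's inputs are ready; the kernel spins in
# wait_signal until it observes that value before scheduling expert g's tiles.
src_signals = torch.ones((l,), dtype=torch.uint32, device="cuda")

# dst_signals: one int32 counter per (expert, m-block). write_signal atomically
# increments entry (g, m_block) each time a C tile of that row block is stored,
# so an active m-block ends at ceil_div(n, block_n), exactly what the test checks.
dst_signals = torch.zeros((l, ceil_div(m, block_m)), dtype=torch.int32, device="cuda")


# After grouped_gemm_nt_masked(..., src_signals=src_signals,
# src_signal_expect_value=1, dst_signals=dst_signals) has run, a consumer can
# treat row block (g, m_block) as ready once its counter reaches the n-tile count.
def m_block_ready(g: int, m_block: int) -> bool:
    return int(dst_signals[g, m_block]) >= ceil_div(n, block_n)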