@@ -180,6 +180,10 @@ def compute_num_stages(
     # that is not fully captured by the simple stage_size model above.
     if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
         smem_capacity -= 32 * 1024
+    if is_persistent and not has_native_mxfp and epilogue_reduction_n > 1:
+        # Hopper fused reductions materialize an additional reduced-N output
+        # tile in smem.
+        smem_capacity -= int(block_m * acc_block_n * out_itemsize)
     smem_capacity = max(smem_capacity, 0)
     max_stages = 5 if rhs_dtype == FP4 else 4  # maybe 5 everywhere; just haven't tested
     num_stages = min(smem_capacity // int(stage_size), max_stages)
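
The stage count then falls out of a straight capacity division: whatever shared memory survives the reservations is divided by the per-stage footprint and clamped to max_stages. A minimal sketch of that arithmetic, where SMEM_CAPACITY, the 228 KiB Hopper budget, and all argument values are illustrative assumptions rather than the kernel's real configuration:

```python
# Illustrative sketch of the capacity accounting above; every name and
# constant here is a stand-in, not the kernel's actual configuration.
SMEM_CAPACITY = 228 * 1024  # assumed Hopper per-SM shared-memory budget

def num_stages_sketch(stage_size, is_persistent, fp32_operand,
                      has_native_mxfp, epilogue_reduction_n,
                      block_m, acc_block_n, out_itemsize, max_stages=4):
    smem = SMEM_CAPACITY
    if is_persistent and fp32_operand:
        smem -= 32 * 1024
    if is_persistent and not has_native_mxfp and epilogue_reduction_n > 1:
        # Reserve room for the extra reduced-N output tile kept in smem.
        smem -= int(block_m * acc_block_n * out_itemsize)
    smem = max(smem, 0)
    return min(smem // int(stage_size), max_stages)

# A 128x128 fp16 output tile reserves 32 KiB here, dropping a pipeline with
# a 50 KiB per-stage footprint from 4 stages to 3.
print(num_stages_sketch(50 * 1024, True, False, False, 2, 128, 128, 2))  # -> 3
```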
@@ -104,13 +104,11 @@ def swizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         Target layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed]
         This matches the baseline TMA block shape [block_n, packed_block_k] after swapping.
         """
-        if data.ndim == 2:
-            data = data.unsqueeze(0)
-        if data.ndim != 3:
-            raise ValueError(f"Expected 2D or 3D canonical data, got {data.ndim}D")
-
         data = self._canonical_to_physical(data)
-        E, K_packed, N = data.shape
+        leading_shape = data.shape[:-2]
+        E = math.prod(leading_shape)
+        K_packed, N = data.shape[-2:]
+        data = data.reshape(E, K_packed, N)
         tile_k_packed, tile_n, padded_K_packed, padded_N, num_tiles_k, num_tiles_n = \
             self._compute_params(E, K_packed, N)

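As a side note, the target layout named in the docstring can be reproduced with a plain reshape/permute once the tensor is padded; the sketch below is hypothetical (the tile sizes are made up, and the real method additionally pads K_packed and N to tile multiples via _compute_params):

```python
import torch

# Hypothetical illustration of the docstring's target layout, assuming
# K_packed and N are already multiples of the (made-up) tile sizes.
E, K_packed, N = 2, 64, 128
tile_k_packed, tile_n = 16, 32
num_tiles_k, num_tiles_n = K_packed // tile_k_packed, N // tile_n

data = torch.arange(E * K_packed * N).reshape(E, K_packed, N)
tiled = (data
         .reshape(E, num_tiles_k, tile_k_packed, num_tiles_n, tile_n)
         .permute(0, 1, 3, 4, 2)  # -> [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed]
         .contiguous())
assert tiled.shape == (E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed)
```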
@@ -139,6 +137,7 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         Input layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed]
         """
         E = data.shape[0]
+        leading_shape = self.shape[:-2]
         # Recover original shape from self.shape (the logical shape passed to convert_layout)
         orig_K_packed = self.shape[-2] // 2 if self.is_fp4 else self.shape[-2]
         orig_N = self.shape[-1]
@@ -159,4 +158,6 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         # Trim padding back to original shape
         data = data[:, :orig_K_packed, :orig_N].contiguous()
         data = self._physical_to_canonical(data)
-        return data if len(self.shape) == 3 else data.squeeze(0)
+        if not leading_shape:
+            return data.squeeze(0)
+        return data.reshape(*leading_shape, data.shape[-2], data.shape[-1])
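
Taken together, these hunks adopt one pattern: flatten all leading dimensions into a single batch axis on the way in, then restore them (or squeeze for the original 2D case) on the way out. A self-contained sketch of that round trip, with helper names that are illustrative rather than part of the actual layout class:

```python
import math
import torch

# Standalone sketch of the flatten/restore pattern the diff adopts; the
# helper names are illustrative, not part of the real layout class.
def flatten_leading(data: torch.Tensor):
    leading_shape = data.shape[:-2]
    E = math.prod(leading_shape)  # math.prod(()) == 1, so 2D input still works
    K_packed, N = data.shape[-2:]
    return data.reshape(E, K_packed, N), leading_shape

def restore_leading(data: torch.Tensor, leading_shape):
    if not leading_shape:  # original input was 2D
        return data.squeeze(0)
    return data.reshape(*leading_shape, data.shape[-2], data.shape[-1])

x = torch.randn(3, 4, 8, 16)  # two leading dims: (3, 4)
flat, lead = flatten_leading(x)
assert flat.shape == (12, 8, 16)
assert torch.equal(restore_leading(flat, lead), x)

y = torch.randn(8, 16)  # 2D input round-trips via squeeze
flat2, lead2 = flatten_leading(y)
assert torch.equal(restore_leading(flat2, lead2), y)
```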