Skip to content

Commit e16a15f

Browse files
committed
Refine documentation
Signed-off-by: Barry Kang <[email protected]>
1 parent 76ca12d commit e16a15f

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tensorrt_llm/quantization/utils/fp8_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,18 @@ def _transpose_kernel(input_ptr, output_ptr, M, N, stride_in_m, stride_in_n,
555555

556556

557557
def masked_transpose(input: torch.Tensor, n_available: int) -> torch.Tensor:
558+
"""
559+
Perform a masked transpose operation on a 2D tensor.
560+
561+
Args:
562+
input: Input tensor of shape (M, N)
563+
n_available: Number of columns to transpose (must be <= N)
564+
565+
Returns:
566+
Transposed tensor of shape (n_available, M)
567+
"""
558568
M, N = input.shape
569+
assert n_available <= N, "n_available must be less than or equal to N"
559570
BLOCK_SIZE = 32
560571
output = torch.empty((n_available, M),
561572
dtype=input.dtype,

0 commit comments

Comments
 (0)