Skip to content

Commit 862a7a3

Browse files
committed
update tensor slicing for per-tensor/row/block scales
1 parent 664ae4d commit 862a7a3

File tree

1 file changed

+34
-27
lines changed

1 file changed

+34
-27
lines changed

torchao/float8/inference.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -139,43 +139,50 @@ def _slice_scale_for_dimension(
139139
Slice the scale tensor appropriately based on the data tensor slicing.
140140
This function calculates how the scale should be sliced when the data tensor
141141
is sliced along a given dimension, taking into account the block structure.
142-
"""
143-
aten = torch.ops.aten
144142
145-
# Unsupported case for now, this would be 1 scale per data element
146-
if scale.shape == data_shape:
147-
return aten.slice.Tensor(scale, dim, start, end, step)
143+
Example:
144+
If data_shape is [256, 128] and scale shape is [1] (indicating per-tensor scaling),
145+
slicing along any dimension should return the same scale tensor.
148146
149-
# Reconstruct block sizes based on data shape and scale shape
150-
block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape)))
147+
If data_shape is [256, 128] and scale shape is [256] (indicating per-row scaling),
148+
and we slice data along dim=0 from 64 to 192, the corresponding scale
should be sliced along dim=0 from 64 to 192 as well (one scale element per row).
149+
"""
150+
aten = torch.ops.aten
151151

152-
if dim >= len(block_sizes):
153-
# Slicing beyond the dimensions we care about
152+
# Case 1: Per-tensor quantization (scalar scale)
153+
if scale.numel() <= 1:
154154
return scale
155155

156+
# Case 2: Per-row quantization (1D scale)
157+
# Scale is per-element along this dimension
158+
if scale.ndim == 1:
159+
if dim == 0:
160+
return aten.slice.Tensor(scale, 0, start, end, step)
161+
else:
162+
return scale
163+
164+
# Case 3: Per-block quantization (2D scale)
165+
block_sizes = tuple(
166+
data_shape[i] // scale.shape[i] for i in range(len(scale.shape))
167+
)
168+
156169
block_size_for_dim = block_sizes[dim]
157170

158-
if block_size_for_dim == 1:
159-
# Scale is per-element along this dimension
160-
# Slice away as normal
161-
return aten.slice.Tensor(scale, dim, start, end, step)
162-
else:
163-
# There is blocking in this dimension
164-
# Calculate which scale elements correspond to the sliced data
165-
scale_start = start // block_size_for_dim if start is not None else None
166-
scale_end = (
167-
(end + block_size_for_dim - 1) // block_size_for_dim
168-
if end is not None
169-
else None
171+
if step > 1:
172+
raise NotImplementedError(
173+
"Slicing with step > 1 is not implemented for scale tensors."
170174
)
171175

172-
# Error on Step > 1
173-
if step > 1:
174-
raise NotImplementedError(
175-
"Slicing with step > 1 is not implemented for scale tensors."
176-
)
176+
# There is blocking in this dimension
177+
# Calculate which scale elements correspond to the sliced data
178+
scale_start = start // block_size_for_dim if start is not None else None
179+
scale_end = (
180+
(end + block_size_for_dim - 1) // block_size_for_dim
181+
if end is not None
182+
else None
183+
)
177184

178-
return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
185+
return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
179186

180187

181188
def _is_rowwise_scaled(x: torch.Tensor) -> bool:

0 commit comments

Comments
 (0)