@@ -139,43 +139,50 @@ def _slice_scale_for_dimension(
     Slice the scale tensor appropriately based on the data tensor slicing.
     This function calculates how the scale should be sliced when the data tensor
     is sliced along a given dimension, taking into account the block structure.
-    """
-    aten = torch.ops.aten
 
-    # Unsupported case for now, this would be 1 scale per data element
-    if scale.shape == data_shape:
-        return aten.slice.Tensor(scale, dim, start, end, step)
+    Example:
+        If data_shape is [256, 128] and scale shape is [1] (indicating per-tensor scaling),
+        slicing along any dimension should return the same scale tensor.
 
-    # Reconstruct block sizes based on data shape and scale shape
-    block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape)))
+        If data_shape is [256, 128] and scale shape is [256] (indicating per-row scaling),
+        and we slice data along dim=0 from 64 to 192, the corresponding scale slice is scale[64:192].
+    """
+    aten = torch.ops.aten
 
-    if dim >= len(block_sizes):
-        # Slicing beyond the dimensions we care about
+    # Case 1: Per-tensor quantization (scalar scale)
+    if scale.numel() <= 1:
         return scale
 
+    # Case 2: Per-row quantization (1D scale)
+    # Scale is per-element along this dimension
+    if scale.ndim == 1:
+        if dim == 0:
+            return aten.slice.Tensor(scale, 0, start, end, step)
+        else:
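+            # Slicing along a non-row dimension leaves the per-row scale untouched.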
+            return scale
+
+    # Case 3: Per-block quantization (2D scale)
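+    # Reconstruct the block sizes from the shapes, e.g. data_shape (256, 128) with
+    # scale shape (4, 4) gives block_sizes (64, 32).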
+    block_sizes = tuple(
+        data_shape[i] // scale.shape[i] for i in range(len(scale.shape))
+    )
+
     block_size_for_dim = block_sizes[dim]
 
-    if block_size_for_dim == 1:
-        # Scale is per-element along this dimension
-        # Slice away as normal
-        return aten.slice.Tensor(scale, dim, start, end, step)
-    else:
-        # There is blocking in this dimension
-        # Calculate which scale elements correspond to the sliced data
-        scale_start = start // block_size_for_dim if start is not None else None
-        scale_end = (
-            (end + block_size_for_dim - 1) // block_size_for_dim
-            if end is not None
-            else None
+    if step > 1:
+        raise NotImplementedError(
+            "Slicing with step > 1 is not implemented for scale tensors."
         )
 
-    # Error on Step > 1
-    if step > 1:
-        raise NotImplementedError(
-            "Slicing with step > 1 is not implemented for scale tensors."
-        )
+    # There is blocking in this dimension
+    # Calculate which scale elements correspond to the sliced data
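+    # The end index is rounded up (ceil division) so that a block the slice only
+    # partially covers still keeps its scale.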
+    scale_start = start // block_size_for_dim if start is not None else None
+    scale_end = (
+        (end + block_size_for_dim - 1) // block_size_for_dim
+        if end is not None
+        else None
+    )
 
-        return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
+    return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
 
 
 def _is_rowwise_scaled(x: torch.Tensor) -> bool:
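
For reference, here is a minimal standalone sketch of the block-wise case the new code handles. The shapes, block sizes, and slice bounds below are made up for illustration; only the start/end arithmetic mirrors the patch.

import torch

# Hypothetical setup: data of shape [256, 128] quantized with one scale per
# [64, 32] block, so the scale tensor has shape [4, 4].
data = torch.randn(256, 128)
scale = torch.rand(4, 4)
block_sizes = tuple(d // s for d, s in zip(data.shape, scale.shape))  # (64, 32)

# Slice the data along dim 0 from 80 to 200 (the start is not block-aligned).
dim, start, end = 0, 80, 200
data_sliced = data[start:end]

# Same arithmetic as the patch: floor-divide the start, ceil-divide the end,
# so every block the slice touches keeps its scale.
scale_start = start // block_sizes[dim]                       # 80 // 64 = 1
scale_end = (end + block_sizes[dim] - 1) // block_sizes[dim]  # ceil(200 / 64) = 4
scale_sliced = scale[scale_start:scale_end]

print(data_sliced.shape, scale_sliced.shape)  # torch.Size([120, 128]) torch.Size([3, 4])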