|
20 | 20 | # |
21 | 21 | # This layout is handled internally in the tensor subclass. |
22 | 22 | import torch |
23 | | -from torchao.quantization.subclass import QuantizedLinearWeightBase |
| 23 | +from torch.utils._python_dispatch import return_and_correct_aliasing |
| 24 | +from typing_extensions import deprecated |
| 25 | + |
| 26 | + |
| 27 | +aten = torch.ops.aten |
24 | 28 |
|
25 | 29 |
|
26 | 30 | def down_size(size): |
@@ -129,6 +133,173 @@ def to_float( |
129 | 133 | return a * scale.unsqueeze(1) |
130 | 134 |
|
131 | 135 |
|
| 136 | +@deprecated("QuantizedLinearWeightBase is deleted from torchao. DO NOT USE!") |
| 137 | +class QuantizedLinearWeightBase(torch.Tensor): |
| 138 | + """ |
| 139 | + *** LEGACY TORCHAO TENSOR SUBCLASS *** |
| 140 | +
|
| 141 | + Note: this subclass no longer exists in torchao. No one should be importing or extending this |
| 142 | + subclass anymore. It is added back here just for internal executorch BC. DO NOT USE! |
| 143 | +
|
| 144 | + Base quantized tensor subclass for quantized linear weights. When the from_float method is used, |
| 145 | + to create an instance of any QuantizedLinearWeightBase, we assume the input |
| 146 | + weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels. |
| 147 | +
|
| 148 | + The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, |
| 149 | + regardless of the internal representation's type or orientation. |
| 150 | + """ |
| 151 | + |
| 152 | + @staticmethod |
| 153 | + def __new__(cls, int_data, transposed, shape, *args, **kwargs): |
| 154 | + kwargs["device"] = int_data.device |
| 155 | + kwargs["layout"] = ( |
| 156 | + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout |
| 157 | + ) |
| 158 | + assert "dtype" in kwargs |
| 159 | + assert not kwargs.get("requires_grad", False) |
| 160 | + kwargs["requires_grad"] = False |
| 161 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| 162 | + |
| 163 | + def __init__(self, int_data, transposed, *args, **kwargs): |
| 164 | + self.int_data = int_data |
| 165 | + |
| 166 | + self.transposed = transposed |
| 167 | + |
| 168 | + @staticmethod |
| 169 | + def _quantized_op(act_mat, w_qtensor, bias): |
| 170 | + pass |
| 171 | + |
| 172 | + def __repr__(self): |
| 173 | + return ( |
| 174 | + f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " |
| 175 | + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" |
| 176 | + ) |
| 177 | + |
| 178 | + def dequantize(self): |
| 179 | + pass |
| 180 | + |
| 181 | + def int_repr(self): |
| 182 | + pass |
| 183 | + |
| 184 | + def q_params(self): |
| 185 | + pass |
| 186 | + |
| 187 | + def half(self): |
| 188 | + return self.to(torch.float16) |
| 189 | + |
| 190 | + def _get_to_kwargs(self, *args, **kwargs): |
| 191 | + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) |
| 192 | + device = self.device if device is None else device |
| 193 | + dtype = self.dtype if dtype is None else dtype |
| 194 | + memory_format = ( |
| 195 | + memory_format if memory_format is not None else torch.preserve_format |
| 196 | + ) |
| 197 | + kwargs = { |
| 198 | + "device": device, |
| 199 | + "dtype": dtype, |
| 200 | + "memory_format": memory_format, |
| 201 | + } |
| 202 | + return kwargs |
| 203 | + |
| 204 | + def _apply_fn_to_data(self, fn): |
| 205 | + pass |
| 206 | + |
| 207 | + def _change_shape(self): |
| 208 | + pass |
| 209 | + |
| 210 | + def __tensor_flatten__(self): |
| 211 | + pass |
| 212 | + |
| 213 | + @classmethod |
| 214 | + def __tensor_unflatten__( |
| 215 | + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride |
| 216 | + ): |
| 217 | + pass |
| 218 | + |
| 219 | + @classmethod |
| 220 | + def from_float(cls, input_float): |
| 221 | + pass |
| 222 | + |
| 223 | + # __torch_function__ = torch._C._disabled_torch_function_impl |
| 224 | + |
| 225 | + @classmethod |
| 226 | + def __torch_function__(cls, func, types, args=(), kwargs=None): |
| 227 | + kwargs = {} if kwargs is None else kwargs |
| 228 | + |
| 229 | + if func is torch.nn.functional.linear: |
| 230 | + mat1, w_qtensor, bias = ( |
| 231 | + args[0], |
| 232 | + args[1], |
| 233 | + args[2] if len(args) > 2 else None, |
| 234 | + ) |
| 235 | + assert not w_qtensor.transposed |
| 236 | + return cls._quantized_op(mat1, w_qtensor, bias) |
| 237 | + |
| 238 | + try: |
| 239 | + with torch._C.DisableTorchFunctionSubclass(): |
| 240 | + return func(*args, **kwargs) |
| 241 | + except Exception: |
| 242 | + print(f"ERR: subclass doesn't implement {func}") |
| 243 | + |
| 244 | + @classmethod |
| 245 | + def __torch_dispatch__(cls, func, types, args, kwargs): |
| 246 | + # two scenarios where we currently fall back to vanilla mm: |
| 247 | + # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation |
| 248 | + # for consistency and to allow people to test |
| 249 | + # 2 - we're given non-floats - quantizing long to int8 is crazy |
| 250 | + if ( |
| 251 | + func in [aten.mm.default, aten.addmm.default] |
| 252 | + and args[0].is_floating_point() |
| 253 | + and args[0].is_cuda |
| 254 | + ): |
| 255 | + if func == aten.addmm.default: |
| 256 | + assert args[1].shape[-1] == args[2].shape[0], ( |
| 257 | + f"need mat1 shape: {args[1].shape} final" |
| 258 | + f"dim to match mat2 shape: {args[2].shape} first dim " |
| 259 | + ) |
| 260 | + mat1, w_qtensor, bias = ( |
| 261 | + args[1], |
| 262 | + args[2], |
| 263 | + args[0], |
| 264 | + ) |
| 265 | + else: |
| 266 | + assert args[0].shape[-1] == args[1].shape[0], ( |
| 267 | + f"need mat1 shape: {args[0].shape} final dim" |
| 268 | + f"to match mat2 shape: {args[1].shape} first dim" |
| 269 | + ) |
| 270 | + mat1, w_qtensor, bias = ( |
| 271 | + args[0], |
| 272 | + args[1], |
| 273 | + None if len(args) == 2 else args[2], |
| 274 | + ) |
| 275 | + # call the quantized op for the specific type |
| 276 | + # of quantized tensor subclass |
| 277 | + return cls._quantized_op(mat1, w_qtensor, bias) |
| 278 | + |
| 279 | + if func is aten.detach.default: |
| 280 | + return return_and_correct_aliasing( |
| 281 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) |
| 282 | + ) |
| 283 | + |
| 284 | + if func is aten.clone.default: |
| 285 | + return return_and_correct_aliasing( |
| 286 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) |
| 287 | + ) |
| 288 | + |
| 289 | + if func is aten.t.default: |
| 290 | + args[0].transposed = not args[0].transposed |
| 291 | + new = args[0]._change_shape(args[0].shape[::-1]) |
| 292 | + return return_and_correct_aliasing(func, args, kwargs, new) |
| 293 | + |
| 294 | + if func is aten._to_copy.default: |
| 295 | + return return_and_correct_aliasing( |
| 296 | + func, |
| 297 | + args, |
| 298 | + kwargs, |
| 299 | + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), |
| 300 | + ) |
| 301 | + |
| 302 | + |
132 | 303 | class GGMLInt4LinearWeight(QuantizedLinearWeightBase): |
133 | 304 | """ |
134 | 305 | A Tensor subclass that when applied to a weight used in a linear op/module, |
|
0 commit comments