
Commit 31ab988

andrewor14 authored and facebook-github-bot committed
Remove internal executorch dependency on torchao.quantization.subclass (pytorch#15223)
Summary: This is a really old quantization API that we recently removed in torchao (D84842047). No one should be calling it anymore. For BC, let's just copy the base class into executorch for now. We should delete this in the future.

Test Plan: CI

bypass-github-export-checks

Reviewed By: vkuzo

Differential Revision: D84921134
1 parent 5076749 commit 31ab988
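
For external code that still imports the class from torchao, a version-tolerant import along these lines keeps working on both sides of the removal. This is a sketch, not part of this commit; the absolute executorch import path is an assumption based on where subclass.py lives in the examples tree:

try:
    # Older torchao builds (before the removal in D84842047) still ship the class.
    from torchao.quantization.subclass import QuantizedLinearWeightBase
except ImportError:
    # Newer torchao removed it; fall back to the copy vendored by this commit.
    from executorch.examples.models.llama.experimental.subclass import (
        QuantizedLinearWeightBase,
    )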

File tree

3 files changed: +174 -2 lines changed


examples/models/llama/experimental/__init__.py

Whitespace-only changes.

examples/models/llama/experimental/load_gguf_q4_0.py

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@
 from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
 from gguf import ReaderTensor
 from gguf.constants import GGMLQuantizationType
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from .subclass import QuantizedLinearWeightBase
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
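
After this change, load_gguf_q4_0.py resolves QuantizedLinearWeightBase from the vendored copy in subclass.py (shown below), which is marked @deprecated. A quick check that the warning actually fires; this assumes typing_extensions' PEP 702 runtime semantics (warn on instantiation) and that the examples tree is importable as an executorch package:

import warnings

import torch

from executorch.examples.models.llama.experimental.subclass import (
    QuantizedLinearWeightBase,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Directly instantiating the vendored base class should emit the warning.
    QuantizedLinearWeightBase(
        torch.zeros(4, 4, dtype=torch.int8),  # int_data
        False,                                # transposed
        (4, 4),                               # externally visible shape
        dtype=torch.float32,                  # required by __new__
    )

assert any(issubclass(w.category, DeprecationWarning) for w in caught)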

examples/models/llama/experimental/subclass.py

Lines changed: 172 additions & 1 deletion
@@ -20,7 +20,11 @@
 #
 # This layout is handled internally in the tensor subclass.
 import torch
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from typing_extensions import deprecated
+
+
+aten = torch.ops.aten
 
 
 def down_size(size):
@@ -129,6 +133,173 @@ def to_float(
     return a * scale.unsqueeze(1)
 
 
+@deprecated("QuantizedLinearWeightBase is deleted from torchao. DO NOT USE!")
+class QuantizedLinearWeightBase(torch.Tensor):
+    """
+    *** LEGACY TORCHAO TENSOR SUBCLASS ***
+
+    Note: this subclass no longer exists in torchao. No one should be importing or extending this
+    subclass anymore. It is added back here just for internal executorch BC. DO NOT USE!
+
+    Base quantized tensor subclass for quantized linear weights. When the from_float method is used,
+    to create an instance of any QuantizedLinearWeightBase, we assume the input
+    weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels.
+
+    The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
+    regardless of the internal representation's type or orientation.
+    """
+
+    @staticmethod
+    def __new__(cls, int_data, transposed, shape, *args, **kwargs):
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        assert "dtype" in kwargs
+        assert not kwargs.get("requires_grad", False)
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, int_data, transposed, *args, **kwargs):
+        self.int_data = int_data
+        self.transposed = transposed
+
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        pass
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
+            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+        )
+
+    def dequantize(self):
+        pass
+
+    def int_repr(self):
+        pass
+
+    def q_params(self):
+        pass
+
+    def half(self):
+        return self.to(torch.float16)
+
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def _apply_fn_to_data(self, fn):
+        pass
+
+    def _change_shape(self):
+        pass
+
+    def __tensor_flatten__(self):
+        pass
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        pass
+
+    @classmethod
+    def from_float(cls, input_float):
+        pass
+
+    # __torch_function__ = torch._C._disabled_torch_function_impl
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.nn.functional.linear:
+            mat1, w_qtensor, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None,
+            )
+            assert not w_qtensor.transposed
+            return cls._quantized_op(mat1, w_qtensor, bias)
+
+        try:
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+        except Exception:
+            print(f"ERR: subclass doesn't implement {func}")
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        # two scenarios where we currently fall back to vanilla mm:
+        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
+        #     for consistency and to allow people to test
+        # 2 - we're given non-floats - quantizing long to int8 is crazy
+        if (
+            func in [aten.mm.default, aten.addmm.default]
+            and args[0].is_floating_point()
+            and args[0].is_cuda
+        ):
+            if func == aten.addmm.default:
+                assert args[1].shape[-1] == args[2].shape[0], (
+                    f"need mat1 shape: {args[1].shape} final "
+                    f"dim to match mat2 shape: {args[2].shape} first dim"
+                )
+                mat1, w_qtensor, bias = (
+                    args[1],
+                    args[2],
+                    args[0],
+                )
+            else:
+                assert args[0].shape[-1] == args[1].shape[0], (
+                    f"need mat1 shape: {args[0].shape} final dim "
+                    f"to match mat2 shape: {args[1].shape} first dim"
+                )
+                mat1, w_qtensor, bias = (
+                    args[0],
+                    args[1],
+                    None if len(args) == 2 else args[2],
+                )
+            # call the quantized op for the specific type
+            # of quantized tensor subclass
+            return cls._quantized_op(mat1, w_qtensor, bias)
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            args[0].transposed = not args[0].transposed
+            new = args[0]._change_shape(args[0].shape[::-1])
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
+        if func is aten._to_copy.default:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+            )
+
+
 class GGMLInt4LinearWeight(QuantizedLinearWeightBase):
     """
     A Tensor subclass that when applied to a weight used in a linear op/module,
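
The docstring above spells out the contract for concrete subclasses: from_float receives an out-channels x in-channels weight, __torch_function__ routes torch.nn.functional.linear to _quantized_op, and the stubbed methods are the extension points. A minimal sketch of a subclass honoring that contract; SymmetricInt8Weight and its per-channel int8 scheme are hypothetical (GGMLInt4LinearWeight above is the real in-tree example), and the absolute import path assumes the examples tree is importable:

import torch

from executorch.examples.models.llama.experimental.subclass import (
    QuantizedLinearWeightBase,
)


# Note: merely defining this subclass may emit the DeprecationWarning, since
# PEP 702 semantics also warn on subclassing a @deprecated class.
class SymmetricInt8Weight(QuantizedLinearWeightBase):
    @staticmethod
    def __new__(cls, int_data, scales, shape, **kwargs):
        # External dtype is the float dtype of the scales, not int8.
        kwargs["dtype"] = kwargs.get("dtype", scales.dtype)
        return super().__new__(cls, int_data, False, shape, **kwargs)

    def __init__(self, int_data, scales, shape, **kwargs):
        self.scales = scales
        super().__init__(int_data, False)

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        # F.linear on the subclass lands here via __torch_function__;
        # fall back to dequantize-then-matmul on plain tensors.
        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)

    def dequantize(self):
        return self.int_data.to(self.dtype) * self.scales.unsqueeze(1)

    @classmethod
    def from_float(cls, input_float):
        # Per the base class contract, weight is out-channels x in-channels.
        scales = input_float.abs().amax(dim=1).clamp(min=1e-8) / 127.0
        int_data = torch.round(input_float / scales.unsqueeze(1)).to(torch.int8)
        return cls(int_data, scales, input_float.shape)


# Usage: quantize a weight, then call F.linear as if it were a plain tensor.
qweight = SymmetricInt8Weight.from_float(torch.randn(8, 16))
out = torch.nn.functional.linear(torch.randn(2, 16), qweight)  # -> (2, 8)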
