
Commit 73940fe

andrewor14 authored and facebook-github-bot committed
Remove internal executorch dependency on torchao.quantization.subclass (pytorch#15223)
Summary: This is a really old quantization API that we recently removed from torchao (D84842047). No one should be calling it anymore. For BC, copy the base class into executorch for now; we should delete it in the future.

Test Plan: CI

bypass-github-export-checks

Reviewed By: vkuzo

Differential Revision: D84921134
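
For downstream callers the change amounts to a one-line import swap; a minimal before/after sketch, taken from the diff below:

# Before: the torchao API, deleted upstream in D84842047
# from torchao.quantization.subclass import QuantizedLinearWeightBase

# After: the copy vendored by this commit (relative import inside
# examples/models/llama/experimental/)
from .subclass import QuantizedLinearWeightBase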
1 parent 5076749 commit 73940fe

File tree

3 files changed: +175 -2 lines


examples/models/llama/experimental/__init__.py

Whitespace-only changes.

examples/models/llama/experimental/load_gguf_q4_0.py

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@
 from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
 from gguf import ReaderTensor
 from gguf.constants import GGMLQuantizationType
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from .subclass import QuantizedLinearWeightBase
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

examples/models/llama/experimental/subclass.py

Lines changed: 173 additions & 1 deletion
@@ -20,7 +20,12 @@
 #
 # This layout is handled internally in the tensor subclass.
 import torch
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from executorch.exir._warnings import deprecated
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+aten = torch.ops.aten
 
 
 def down_size(size):
@@ -129,6 +134,173 @@ def to_float(
     return a * scale.unsqueeze(1)
 
 
+@deprecated("QuantizedLinearWeightBase is deleted from torchao. DO NOT USE!")
+class QuantizedLinearWeightBase(torch.Tensor):
+    """
+    *** LEGACY TORCHAO TENSOR SUBCLASS ***
+
+    Note: this subclass no longer exists in torchao. No one should be importing or extending this
+    subclass anymore. It is added back here just for internal executorch BC. DO NOT USE!
+
+    Base quantized tensor subclass for quantized linear weights. When the from_float method is used,
+    to create an instance of any QuantizedLinearWeightBase, we assume the input
+    weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels.
+
+    The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
+    regardless of the internal representation's type or orientation.
+    """
+
+    @staticmethod
+    def __new__(cls, int_data, transposed, shape, *args, **kwargs):
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        assert "dtype" in kwargs
+        assert not kwargs.get("requires_grad", False)
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, int_data, transposed, *args, **kwargs):
+        self.int_data = int_data
+
+        self.transposed = transposed
+
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        pass
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
+            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+        )
+
+    def dequantize(self):
+        pass
+
+    def int_repr(self):
+        pass
+
+    def q_params(self):
+        pass
+
+    def half(self):
+        return self.to(torch.float16)
+
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def _apply_fn_to_data(self, fn):
+        pass
+
+    def _change_shape(self):
+        pass
+
+    def __tensor_flatten__(self):
+        pass
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        pass
+
+    @classmethod
+    def from_float(cls, input_float):
+        pass
+
+    # __torch_function__ = torch._C._disabled_torch_function_impl
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.nn.functional.linear:
+            mat1, w_qtensor, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None,
+            )
+            assert not w_qtensor.transposed
+            return cls._quantized_op(mat1, w_qtensor, bias)
+
+        try:
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+        except Exception:
+            print(f"ERR: subclass doesn't implement {func}")
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        # two scenarios where we currently fall back to vanilla mm:
+        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
+        #     for consistency and to allow people to test
+        # 2 - we're given non-floats - quantizing long to int8 is crazy
+        if (
+            func in [aten.mm.default, aten.addmm.default]
+            and args[0].is_floating_point()
+            and args[0].is_cuda
+        ):
+            if func == aten.addmm.default:
+                assert args[1].shape[-1] == args[2].shape[0], (
+                    f"need mat1 shape: {args[1].shape} final"
+                    f"dim to match mat2 shape: {args[2].shape} first dim "
+                )
+                mat1, w_qtensor, bias = (
+                    args[1],
+                    args[2],
+                    args[0],
+                )
+            else:
+                assert args[0].shape[-1] == args[1].shape[0], (
+                    f"need mat1 shape: {args[0].shape} final dim"
+                    f"to match mat2 shape: {args[1].shape} first dim"
+                )
+                mat1, w_qtensor, bias = (
+                    args[0],
+                    args[1],
+                    None if len(args) == 2 else args[2],
+                )
+            # call the quantized op for the specific type
+            # of quantized tensor subclass
+            return cls._quantized_op(mat1, w_qtensor, bias)
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            args[0].transposed = not args[0].transposed
+            new = args[0]._change_shape(args[0].shape[::-1])
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
+        if func is aten._to_copy.default:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+            )
+
+
 class GGMLInt4LinearWeight(QuantizedLinearWeightBase):
     """
     A Tensor subclass that when applied to a weight used in a linear op/module,
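
To show how this legacy base class is wired up, here is a minimal, hypothetical usage sketch (not part of the commit): DemoInt8Weight, its scale attribute, and the symmetric int8 scheme are all invented for illustration, and the import assumes subclass.py is reachable on sys.path. Calling F.linear on the wrapped weight routes through __torch_function__ above into the subclass's _quantized_op:

import torch
import torch.nn.functional as F

# Hypothetical import path; assumes subclass.py is on sys.path.
from subclass import QuantizedLinearWeightBase


class DemoInt8Weight(QuantizedLinearWeightBase):
    """Invented symmetric int8 weight, for illustration only."""

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        # No real quantized kernel here: dequantize, then run a float linear.
        return F.linear(act_mat, w_qtensor.dequantize(), bias)

    def dequantize(self):
        return self.int_data.to(torch.float32) * self.scale

    @classmethod
    def from_float(cls, weight):
        scale = weight.abs().amax() / 127.0
        int_data = torch.round(weight / scale).clamp(-127, 127).to(torch.int8)
        obj = cls(int_data, False, weight.shape, dtype=weight.dtype)
        obj.scale = scale  # stored outside the base class; illustration only
        return obj


w = DemoInt8Weight.from_float(torch.randn(8, 16))
x = torch.randn(2, 16)
y = F.linear(x, w)  # dispatched via __torch_function__ to _quantized_op
print(y.shape)  # torch.Size([2, 8])

Note that the base class deliberately leaves dequantize, from_float, and _quantized_op as stubs; every concrete subclass (such as GGMLInt4LinearWeight below) supplies its own quantization scheme.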
