Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unified AffineQuantizedTensor subclass #214

Merged
merged 2 commits into from
May 7, 2024
Merged

Conversation

jerryzh168
Copy link
Contributor

@jerryzh168 jerryzh168 commented May 4, 2024

Summary:
Creatd a AffineQuantizedTensor subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine and dequantize_affine ops)

only verified for 8da4w for executorch (q/dq representation) right now, we can make it work for other types of quantization (mostly the operator dispatching part) later

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 4, 2024
@jerryzh168 jerryzh168 changed the title Unified tensor subclass Unified QuantizedTensor subclass May 4, 2024
torchao/quantization/subclass.py Outdated Show resolved Hide resolved
torchao/quantization/subclass.py Outdated Show resolved Hide resolved
torchao/quantization/subclass.py Outdated Show resolved Hide resolved
torchao/quantization/subclass.py Show resolved Hide resolved
torchao/quantization/subclass.py Outdated Show resolved Hide resolved
torchao/quantization/subclass.py Outdated Show resolved Hide resolved
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only see is_cuda version, did you mean to add cpu when we don't have qmm?

Copy link
Contributor Author

@jerryzh168 jerryzh168 May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so right now the version is just for 8da4w + executorch. we actually haven't added cpu or cuda version. in the future:

  1. we'll add cpu/cuda version (int4mm etc.)
  2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
    cpu device + et laytout --> gives current 8da4w executorch representation
    cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
    cuda device + some layout --> gives cuda kernel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I will need to change is_cuda to is_cpu here actually

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cpu device + et laytout --> gives current 8da4w executorch representation
cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc

Barring some quant parameters, isn't the weight representation kernel independent i.e. weight elements will be int8 with values in [qmin, qmax] range?

The actual packed weight layout can be quite different and may not be able to fit in the tensor representation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to have efficient executorch in eager (e.g. for cpu or cuda), we would like to produce packed weight directly in this quantized Tensor I think.

but int8 with qmin/qmax could be an intermediate "unpacked" format that can serve as a standard representation that we can use to translate between different packing format, e.g.

  1. first the tensor is packed for some cpu efficient kernel
  2. then user calls tensor.to("cuda") (or model.to("cuda")), we can first unpack the tensor to int8 + qmin/qmax representation, and then repack to the cuda format

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kimishpatel - I don't quite follow, why is using a layout for a particular bit packing structure not the right abstraction?

Sure, some bit packing layouts are currently very useful to certain accelerators, but perhaps some of them translate to others? Maybe there's some similarities between them? I think it's also still ongoing research which one is the best (or at least it seems like it is).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

**@kimishpatel - I don't quite follow, why is using a layout for a particular bit packing structure not the right abstraction?

Because, we will continue to introduce execution dependent packing layouts as part of this tensor subclass. e.g. 5 year older x86/ARM want to pack int4/int8 tensor with A, newer ones may want it in way B and future ones want in way C (not to say about different layouts that different classes of GPUs may want). So this class assuming the responsibility of encoding all the different packing layouts does not seem right to me.

I can be convinced otherwise, but it feel that it would be cleaner if there was a canonical representation that this class assumes and then it can be either extended, for custom layouts/packing, or introduce custom ops that do custom packing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kimishpatel - yes, but that's ok though, no? I mean, wouldn't you want to pick the packing layout that's best for the available CPU? Also, we give these layouts names. You can still allocate a GPU specific layout on your raspberry PI. It just might not run (at least not efficiently) locally.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bitpacking layout seems independent from the allocator used to allocate the memory. You can use malloc to allocate a bitpacking layout that's optimized for CPU or you can use cuda malloc to allocate a bitpacking layout that's optimized for the Skylake CPU. But maybe this is an overly naive view?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I feel it is trying to do too much because the layout you want is maybe tied to kernel you wrote. Then say the output of your op is also AQT which can also have a special layout. Which means every op that you register for AQT must handle all the packed layouts. Like in another comment you had an example of gelu op being registered for nf4. Now if nf4 tensor had multiple packings/layouts then the registered op will have to interpret each of them. (Maybe you need a canonical representation where to/from tranforms are possible)

So the question is do you allow for arbitrary packing in AQT via some attribute + to/from tranformations for canonical representation OR
decouple it in a way that any special consideration like packing is handled elsewhere.

And note that this has to play well with export and other systems as well, else we are just solving eager mode problem.

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did not review but just unblocking

EDIT: Will review in at most 2h

input_target_dtype = torch.int8
input_quant_func = lambda x: QuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = M().eval()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename this class to ToyLinearModel? Also I noticed that class has to.(torch.float) calls I believe the intent is float32 in which case could we be more explicit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure.

in which case could we be more explicit?

can you clarify what do you mean by being more explicit? do you mean to put float32 in name?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean to put float32 in name?

Yes just do a to(torch.float32) instead

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean do this for m instance? instead of doing this inside ToyLinearModel

from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you expect this list to be fairly constant over time? Should we consider some dataclass like object for the config?

Also things like quant_min and quant_max should be derivable assuming the target_dtype is int8

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could have some helper functions to give us everything here in the future I think

for quant_min/quant_max, right now this is using 4 bit, we don't have torch.int4 right now, but we could add that

test/quantization/test_quant_api.py Outdated Show resolved Hide resolved
torchao/quantization/subclass.py Show resolved Hide resolved
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
)
if dtype is None:
dtype = scale.dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do this? Wouldn't it be better to just always expect a dtype instead of implicitly deriving it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is copied from other code, yeah maybe just make this required is easier, I can change it

Copy link
Contributor Author

@jerryzh168 jerryzh168 May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as I'm trying to change this, I found that this is used in kwargs, maybe it's better to keep it this way: https://github.com/pytorch/ao/blob/main/torchao/quantization/subclass.py#L90

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, what is the dtype meant to convey? If this is a the dtype of the quantized tensor, than shouldnt int_data tensor already capture that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is dtype for the "external representation" of the quantized tensor, e.g. if it's quantized from bfloat16, then the dtype will be bfloat16

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need that? You already have scale. YOu can look that up

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scale dtype may not be guaranteed to match the original fp dtype right? but this external dtype stuff may need more discussion I think, but that's kind of copied from nf4 tensor (which is a floating point tensor) I think, maybe we want something else for integer quantized tensor

torchao/quantization/subclass.py Outdated Show resolved Hide resolved
Summary:
Creatd a `QuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine
and dequantize_affine ops)

only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:
@jerryzh168 jerryzh168 changed the title Unified QuantizedTensor subclass Unified AffineQuantizedTensor subclass May 7, 2024
@jerryzh168 jerryzh168 merged commit f0bdc8f into pytorch:main May 7, 2024
9 of 12 checks passed
@@ -134,14 +139,21 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
# 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This definitely worries me. What is et layout really? And I dont feel we have put enough thoughts on whether hiding it like this, to dispatch to weight only quantization vs. weight + dyn act quantization, is a good idea

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, et layout just means we don't use packed weights + dispatch to int4mm kernels (as what we are doing for cpu and cuda device), since executorch does not have their own device, so we are reusing cpu as the device for executorch path, and thus we need some extra information (layout or something else) to generate the q/dq format the executorch is consuming.

Copy link
Contributor Author

@jerryzh168 jerryzh168 May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are dispatching on weight only quant v.s. weight + dyn act quant, not sure what you mean by hiding actually, I'm putting up another PR soon, hope it will make this clearer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean by hiding is this: you are trying to dispatch to different implementation based on device type arg and because for et there is none, we are introducing the notion of layout which to me doesnt make sense.

In fact even device type based dispatch can be used only when the implementation of the op is really device specific and now hw specific. Say for example we will support weight only quantization on ARM architectures. Then can I say that all I need to do is to change cpu impl of int4packed_mm, by putting if __arm__ kind of ifdefs, and it would work? I dont think that is true because the notion of packed weights can be architecture specific and there is no canonical representation of that, unlike say torch.Tensor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to do subclass based approach, I would expect each different flavor to inherit perhaps from AffineQuantizedTensor class, e.g.

class X86WeightOnlyQuantizedLinearWeightTensor(AffineQuantizedTensor):
    def __init__(...):
         do packing here
    def __torch_dispatch__(....):
          dispatch to x86 impl

Base affine quantized tensor subclass. When the from_float method is used,
to create an instance of any AffineQuantizedTensor

The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does externally mean here? Do you mean to say in export?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when you call dtype and shape on an instance of AffineQuantizedTensor. I agree the wording is confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, this means what it looks like for autograd engine mostly I think

**kwargs
):
kwargs["device"] = int_data.device
kwargs["layout"] = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is copy pasted from other code, not exactly sure if we need this or not

input_quant_func: Optional[Callable] = None,
dtype=None,
*args,
**kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not using args/kwargs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh this is used in __new__ I think

def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs

if func is torch.nn.functional.linear:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could it be aten.linear as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in __torch_function__ we are only catching torch functions, so we won't see aten.linear I think

HDCharles added a commit that referenced this pull request May 8, 2024
* Composing autoquant with compile

Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occuring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:

* Fused DoRA kernels (#216)

* add dora kernels

* allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Unified AffineQuantizedTensor subclass (#214)

Summary:
Creatd a `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine
and dequantize_affine ops)

only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Mark Saroufim <[email protected]>

* add expecttest to requirements.txt (#225)

* add expecttest to requirements.txt

* update

* Install dev-requirements.txt in doc build (#224)

Install dev-requirements.txt

---------

Co-authored-by: Mark Saroufim <[email protected]>

* Fix an error in subclass impl (#226)

Summary:
Accidently changed the device check code for old subclass instead of the new one, forgot to fix before landing

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Some follow up fixes for quant primitives (#220)

Summary:
att

Test Plan:
python test/quantization/test_quant_primitives.py -k test_raises

Reviewers:

Subscribers:

Tasks:

Tags:

* Composing autoquant with compile

Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occuring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:

* allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

---------

Co-authored-by: jeromeku <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of AffineQuantizedTensor.from_float could we have a factory function similar to to_nf4 in

def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not the final UI, we will need to integrate this with things like https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L129 as well, I'm not sure how to_nf4 would fit in there, but I think we could discuss further on how to expose this to end users

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is input_quant_func?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input_quant_func is for quantizing input (in dynamic quantization)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think this is a good idea. What if you have an op that takes in two tensor plus quantized weight. Now you will need to input_quant_funcs? We are expanding the execution semantics of the op, in this case linear, for which this class's __torch_dispatch__ is invoked. If you wanted to do this, I would have used another tensor subclass AffineQuantizedDynamicLinear whose semantic is more clear in terms of how it will override linear specifically. But I dont know if we can truly generalize this.

At high level, it is still not clear how we will use tensor subclass to represent quantized compute

@@ -136,7 +136,7 @@ def _get_reduction_params(block_size, input_size):

def quantize_affine(
input: torch.Tensor,
block_size: List[int],
block_size: Tuple[int, ...],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a Tuple is better for the block_size argument, since it's immutable

eps (Optional[float]: minimum scale
scale_dtype (torch.dtype): dtype for scales
zero_point_dtype (torch.dtype): dtype for zero_points
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where can the user find eps of input.dtype?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just running torch.finfo(dtype).eps I think, I didn't find a table showing this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, it's a bit of a nit, but you could just add that snippet like

minimum scale, if not provided, default to torch.finfo(input.dtype).eps

)

@classmethod
def from_float(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this a factory function instead of static class method

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure how the factory function is going to connect with the subclass tensor swap API

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, it's really just a nit to make it align with to_nf4 for NF4Tensor. I don't think it'll change behavior much.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel nf4 should align with this, since I'm not sure how to plug in with the current module modification API: https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L124 if we have to_nf4, to_int4, to_int8 etc. instead of a single unified "from_float"

args[1],
args[2] if len(args) > 2 else None,
)
if weight_qtensor.input_quant_func is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this something you'd call explicitly within the higher order wrapper or so? Or inject with an FX pass etc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the very short question, only had little time to take a look

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we discussed this one in a different place I think

input_quant_func: Optional[Callable] = None,
dtype=None,
*args,
**kwargs
Copy link
Contributor

@cpuhrsch cpuhrsch May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These *args and **kwargs are impossible to document. Let's please remove them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, will add a TODO

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless it's not possible :(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be possible

if (func is aten.detach.default or
func is aten.clone.default or
func is aten._to_copy.default):
return return_and_correct_aliasing(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the kind of things that make me wonder if AffineQuantizedTensor via tensor subclass is the right approach.

Are we trying to unify all kind of quant techniques behind this class? If so, I am not sure if thats a good idea

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, I think we should still have a legokit of ops for quantization (e.g. (mixed) integer matmul, affine quant/dequant primitives, etc.), but at least historically we've used dtypes for lower precision data types (e.g. bfloat16). AQT is a way to provide a unified interface for dtypes based on affine quantization. People don't have to use it of course, but it seems convenient for like trying a new op (for example see

@torchao.dtypes.nf4tensor.implements([torch.ops.aten.gelu.default])
) or just printing the content of a quantized tensor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. This is similar to what @jerryzh168 was suggesting about allowing users to override dispatch.

I think my main sticking point was this: the dynamic quantization seems like an aberration to the relatively ok patterns of "registering" in op for the tensor subclass as you have shown in the gelu example. Because in dynamic quant, at least for linear, we are saying I want to quantized input activation, but if i am doing weight only quant then i dont want to quantize input activation. And we are trying to encode this information in AQT via layout and other things.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kimishpatel - Yes, I agree. We should separate "injection" from the dtype of affine integer quantization.

Dynamic quantization is the process of recalculating the range and zero point of an affine integer quantized tensor on every call. It's static if you just use precalculated ranges and zero points. But morally the dtype is still an affine integer quantized tensor. It's just a particular way of representing real numbers in memory.

"injection" can be done in multiple ways. We could use tensor subclasses and reparameterize to e.g. replace F.linear with a couple ops. Or you can swap modules. Some users might even just modify the model itself and change the dtype of layers (like model.half()). Then we could also use FX passes and such to update the IR and inject this dtype. But that's separate.

Copy link
Contributor

@kimishpatel kimishpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SOrry for having reviewed this late, but it is not clear what problems AffineQuantizedTensor class will solve

dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
Summary:
Creatd a `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine
and dequantize_affine ops)

only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Mark Saroufim <[email protected]>
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* Composing autoquant with compile

Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occuring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:

* Fused DoRA kernels (pytorch#216)

* add dora kernels

* allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Unified AffineQuantizedTensor subclass (pytorch#214)

Summary:
Creatd a `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine
and dequantize_affine ops)

only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Mark Saroufim <[email protected]>

* add expecttest to requirements.txt (pytorch#225)

* add expecttest to requirements.txt

* update

* Install dev-requirements.txt in doc build (pytorch#224)

Install dev-requirements.txt

---------

Co-authored-by: Mark Saroufim <[email protected]>

* Fix an error in subclass impl (pytorch#226)

Summary:
Accidently changed the device check code for old subclass instead of the new one, forgot to fix before landing

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Some follow up fixes for quant primitives (pytorch#220)

Summary:
att

Test Plan:
python test/quantization/test_quant_primitives.py -k test_raises

Reviewers:

Subscribers:

Tasks:

Tags:

* Composing autoquant with compile

Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occuring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:

* allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

---------

Co-authored-by: jeromeku <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants