unwrap_tensor_subclass and nested tensor subclasses issue #515

Closed
vayuda opened this issue Jul 17, 2024 · 18 comments

@vayuda (Collaborator) commented Jul 17, 2024

I'm noticing strange behavior when trying to create a tensor subclass that holds another tensor subclass.
Here is a minified repro (add this to the bottom of torchao/dtypes/affine_quantized_tensor.py):

@dataclass(frozen=True)
class TestLayout(LayoutType):
    scales: torch.Tensor
    zeros: torch.Tensor
    
    def post_process(self, input: torch.Tensor) -> torch.Tensor:
        return PlainAQTLayout.from_plain(input, self.scales, self.zeros, PlainLayoutType())
    
@register_layout_cls(TestLayout)
class TestAQTLayout(PlainAQTLayout):
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.int_data.get_plain()[0], self.scale, self.zero_point
    
    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        assert isinstance(layout_type, TestLayout)
        return cls(int_data, scale, zero_point, layout_type)
    
if __name__ == "__main__":
    from torchao.quantization.quant_api import quantize_, _get_linear_subclass_inserter 
    from torchao.utils import unwrap_tensor_subclass
    def test_quant():
        def apply_test_quant(weight):
            layout_type = TestLayout(torch.tensor([1.0]), torch.tensor([0.0]))
            mapping_type = MappingType.ASYMMETRIC
            block_size = (1, 2)
            quant_min = 0
            quant_max = 8
            eps = torch.finfo(torch.float32).eps
            zero_point_dtype = torch.int32
            zero_point_domain = ZeroPointDomain.INT
            
            return to_affine_quantized(
                weight, mapping_type, block_size, torch.uint8, quant_min = quant_min,
                quant_max = quant_max, eps = eps, 
                zero_point_dtype=zero_point_dtype,
                zero_point_domain=zero_point_domain,
                layout_type=layout_type,
            )
    
        return _get_linear_subclass_inserter(apply_test_quant)
    class LinearModel(torch.nn.Module):
        def __init__(self):
            super(LinearModel, self).__init__()
            self.linear = torch.nn.Linear(32, 10)        
        def forward(self, x):
            return self.linear(x)
    test_input = torch.randn(32)
    m = LinearModel()
    m.forward(test_input)
    quantize_(m, test_quant())
    m.forward(test_input)
    m = unwrap_tensor_subclass(m)
    m = torch.compile(m, fullgraph=True)
    m.forward(test_input)

When running this code I get the following error when calling the model after compiling:

File "/home/swan/pytorch/ao/min_repro.py", line 199, in dequantize
    int_data, scale, zero_point = self.layout_tensor.get_plain()
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/swan/pytorch/ao/min_repro.py", line 854, in get_plain
    return self.int_data.get_plain()[0], self.scale, self.zero_point
           ^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_method forward(*(ParametrizedLinear(
  in_features=32, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): UnwrapTensorSubclass()
    )
  )
), FakeTensor(..., size=(32,))), **{}):
'FakeTensor' object has no attribute 'get_plain'

There might be an issue with how unwrap_tensor_subclass handles nested tensor subclasses, but I'm not sure why this doesn't work, since AffineQuantizedTensor is able to hold an AQTLayout tensor and works just fine.
You can check out https://github.com/vayuda/ao/blob/intx/torchao/dtypes/affine_quantized_tensor.py#L593 for what I'm trying to do.

@drisspg (Contributor) commented Jul 17, 2024

cc @bdhirsh

@bdhirsh (Contributor) commented Jul 18, 2024

Hmm, I pulled the latest main of torchao and patched in that change, but when I run it I get a different error (before we hit compile):

(/home/hirsheybar/local/b/pytorch-env) [[email protected] ~/local/b/pytorch/ao (main)]$ TORCH_LOGS="graph_code,aot" python torchao/dtypes/affine_quantized_tensor.py
Traceback (most recent call last):
  File "/home/hirsheybar/local/b/pytorch/ao/torchao/dtypes/affine_quantized_tensor.py", line 903, in <module>
    quantize_(m, test_quant())
  File "/home/hirsheybar/local/b/pytorch/ao/torchao/quantization/quant_api.py", line 318, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/home/hirsheybar/local/b/pytorch/ao/torchao/quantization/quant_api.py", line 170, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/hirsheybar/local/b/pytorch/ao/torchao/quantization/quant_api.py", line 166, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/home/hirsheybar/local/b/pytorch/ao/torchao/dtypes/affine_quantized_tensor.py", line 885, in apply_test_quant
    return to_affine_quantized(
  File "/home/hirsheybar/local/b/pytorch/ao/torchao/dtypes/affine_quantized_tensor.py", line 230, in from_float
    original_shape = input_float.shape
  File "/home/hirsheybar/local/b/pytorch/torch/nn/modules/module.py", line 1895, in __getattr__
    raise AttributeError(
AttributeError: 'Linear' object has no attribute 'shape'

@bdhirsh (Contributor) commented Jul 18, 2024

But two things to call out here:

(1) @vayuda you actually do not need to use m = unwrap_tensor_subclass(m) at all here if you are using torch.compile. You only need it if you want to use export / AOTInductor (for now; we want to make this not required in a few months, so let me know if you are trying to get AOTI to work or if you only care about compile). See the sketch at the end of this comment.

The story there is that unwrap_tensor_subclass() uses parametrizations to take all tensor subclass parameters on the model (the affine quantized subclass) and desugar them into their inner plain tensors. This is needed for export/AOTI, since when using AOTInductor you are in a no-Python environment, so you cannot have Python tensor subclass parameters in your state dict / model.

But for regular torch.compile usage, it's perfectly fine to have model parameters/buffers that are Python tensor subclasses.

(2) that parametrization code doesn't appear to play very well with dynamo today. Fixing that is tracked here: pytorch/pytorch#129682
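To illustrate point (1), a minimal sketch of the two paths, reusing the repro's LinearModel and test_quant (this assumes those definitions are in scope; the export call is only needed on the export/AOTI route):

import torch
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass

m = LinearModel()
quantize_(m, test_quant())            # weights become tensor subclass parameters

# torch.compile path: subclass parameters are fine as-is, no unwrapping needed.
compiled = torch.compile(m, fullgraph=True)
compiled(torch.randn(32))

# export / AOTInductor path: desugar subclass parameters into plain tensors first.
m = unwrap_tensor_subclass(m)
ep = torch.export.export(m, (torch.randn(32),))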

@jerryzh168 (Contributor) commented

Hmm, I pulled the latest main of torchao and patched in that change, but when I run it I get a different error (before we hit compile): AttributeError: 'Linear' object has no attribute 'shape'

Oh, this is because of a recent change in the quantize_ API; you'll need to wrap apply_test_quant in a call to _get_linear_subclass_inserter, like the following:

return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)

@vayuda (Collaborator, Author) commented Jul 18, 2024

Oh, this is because of a recent change in the quantize_ API; you'll need to wrap apply_test_quant in a call to _get_linear_subclass_inserter, like the following:

I updated the repro to reflect this change

You only need it if you want to use export / AOTInductor

So when I comment out unwrap_tensor_subclass, the script runs with torch.compile(backend="eager") but fails with the default inductor backend. By not unwrapping the subclass, some strange things happen that I have yet to figure out the source of (somehow dequantize is receiving a float32 input dtype instead of uint8, maybe because it is using a fake tensor or something).

I'm not sure how the eager backend works; has dynamo extracted the graph but doesn't do anything with it? I would like to use the full power of torch.compile and fuse kernels together, but I'm not interested in autograd. I don't need AOTI support for now.

I guess I can retest after pytorch/pytorch#129682 is fixed.
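For context, backend="eager" only runs Dynamo's graph capture and then executes the captured graph with ordinary eager kernels, while backend="aot_eager" additionally runs AOTAutograd (where the nested-subclass desugaring happens) and the default "inductor" backend adds code generation on top. A small debugging sketch for narrowing down which stage breaks, assuming the repro's m and test_input are in scope:

import torch

for backend in ("eager", "aot_eager", "inductor"):
    torch._dynamo.reset()  # clear compile caches between runs
    compiled = torch.compile(m, backend=backend, fullgraph=True)
    try:
        compiled(test_input)
        print(f"{backend}: ok")
    except Exception as e:
        print(f"{backend}: failed with {type(e).__name__}")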

@bdhirsh (Contributor) commented Jul 19, 2024

Interesting, I can repro too when I comment out unwrap_tensor_subclass. It also repros for me with torch.compile(m, backend="aot_eager") but not torch.compile(m, backend="eager").

@IvanKobzarev is going to take a look

@jerryzh168 (Contributor) commented Jul 29, 2024

@vayuda for the following code

    def post_process(self, input: torch.Tensor) -> torch.Tensor:
        return PlainAQTLayout.from_plain(input, self.scales, self.zeros, PlainLayoutType())

is this trying to simulate that the int_data field will be another tensor subclass like UInt4Tensor?

@vayuda (Collaborator, Author) commented Jul 30, 2024

@vayuda for the following code

    def post_process(self, input: torch.Tensor) -> torch.Tensor:
        return PlainAQTLayout.from_plain(input, self.scales, self.zeros, PlainLayoutType())

is this trying to simulate that the int_data field will be another tensor subclass like UInt4Tensor?

Yes exactly. Here is the original version:

@dataclass(frozen=True)
class IntxLayoutType(LayoutType):
    bit_size: int
    pack_dim: int = -1
    
    def post_process(self, input: torch.Tensor) -> torch.Tensor:
        from torchao.prototype.intx import to_intx
        return to_intx(input, self.bit_size, self.pack_dim)

@jerryzh168 (Contributor) commented

@vayuda can you check if this fix works? pytorch/pytorch#132096. You'll need to install PyTorch from source (with the PR applied).

@IvanKobzarev commented Jul 30, 2024

@vayuda can you check if this fix works? pytorch/pytorch#132096. You'll need to install PyTorch from source (with the PR applied).

@vayuda As a detail: I tested compilation of the original repro without ao's unwrap_tensor_subclass, and for me, after pytorch/pytorch#132096, compilation works fine; aot_autograd handles all the nested subclass logic.

@vayuda (Collaborator, Author) commented Jul 30, 2024

@vayuda As a detail: I tested compilation of the original repro without ao's unwrap_tensor_subclass, and for me, after pytorch/pytorch#132096, compilation works fine; aot_autograd handles all the nested subclass logic.

I built PyTorch from your PR branch. I now get a different error, but it is probably unrelated to this issue. Strangely, this repro errors but my full code doesn't, so I'm still satisfied with the fix haha.

Error:

Traceback (most recent call last):
  File "/home/swan/pytorch/test/ao/torchao/dtypes/affine_quantized_tensor.py", line 896, in <module>
    quantize_(m, test_quant())
  File "/home/swan/pytorch/test/env/lib/python3.10/site-packages/torchao-0.3.1-py3.10.egg/torchao/quantization/quant_api.py", line 313, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/home/swan/pytorch/test/env/lib/python3.10/site-packages/torchao-0.3.1-py3.10.egg/torchao/quantization/quant_api.py", line 170, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/swan/pytorch/test/env/lib/python3.10/site-packages/torchao-0.3.1-py3.10.egg/torchao/quantization/quant_api.py", line 166, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/home/swan/pytorch/test/env/lib/python3.10/site-packages/torchao-0.3.1-py3.10.egg/torchao/quantization/quant_api.py", line 257, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
  File "/home/swan/pytorch/test/env/lib/python3.10/site-packages/torchao-0.3.1-py3.10.egg/torchao/quantization/quant_api.py", line 257, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
AttributeError: 'Parameter' object has no attribute 'weight'

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Jul 31, 2024
get_plain_tensors() should result in a DFS of the leaves.
The error was that plain tensors (leaves) on the same level were returned before the subclasses' plain tensors, even if the subclasses come earlier in the "flatten" list.

Original issue from AO: pytorch/ao#515

Test: TBD, need to make an asymmetric subclass with dense tensors and subclasses
Pull Request resolved: #132096
Approved by: https://github.com/bdhirsh
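For intuition, the fix makes subclass flattening a depth-first traversal of the leaf tensors, so an inner subclass's plain tensors come out in the position where that subclass appears, rather than all same-level plain tensors coming first. A rough, standalone illustration of that ordering (not the actual PyTorch implementation; the helper name is hypothetical):

import torch

def flatten_to_plain_tensors(t):
    # Depth-first: a plain tensor is its own leaf; a traceable subclass
    # contributes the leaves of each of its inner tensors, in declared order.
    if not hasattr(type(t), "__tensor_flatten__"):
        return [t]
    attr_names, _ctx = t.__tensor_flatten__()
    leaves = []
    for name in attr_names:
        leaves.extend(flatten_to_plain_tensors(getattr(t, name)))
    return leaves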
@raziel (Contributor) commented Aug 6, 2024

Hi, this seems to be a blocker for working with tensor subclasses in torch.export.

A simple example that will result in:

aot_export is not currently supported with traceable tensor subclass.

import torch
import numpy as np
from torchao.dtypes.uint4 import UInt4Tensor


class ExampleModel(torch.nn.Module):
    def __init__(self):
        """Init"""
        super().__init__()
        x_uint8 = torch.randint(0, 16, (4, 8)).to(torch.uint8)
        self.x = UInt4Tensor.from_unpacked(x_uint8)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Invoke."""
        transposed = self.x.view(8, 4).to(torch.uint8)
        return torch.add(input, transposed)


def main():
    input_tensor = torch.randn(8, 4)
    model = ExampleModel()
    with torch.no_grad():
        exported_program = torch.export.export(
            model.eval(), args=(), kwargs={"input":input_tensor},
        )
    print("===exported_program====")
    print(exported_program)

if __name__ == "__main__":
    main()

Let us know if there's something we can try (the above issue is with the 2.4 release).

Also, let us know if we should log this as a separate issue, or if this is the right place to track it.

Cheers.

@jerryzh168 (Contributor) commented

Hi, this seems to be a blocker for working with tensor subclasses in torch.export.

Hi @raziel! For torch.export (and AOTI) you'd still need to use unwrap_tensor_subclass (for torch.compile with a PyTorch nightly we don't need unwrap_tensor_subclass anymore); see the end of the https://github.com/pytorch/ao/tree/main/torchao/quantization#quantization-flow-example section for docs.
Also, this workaround will no longer be needed after pytorch/pytorch#129682 is fixed.

Looks like you are trying to use uint4 quantization; please take a look at our current doc for integer affine quantization here: https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization. Here is a test for torch.export that uses the torchao API and produces some quantization ops that can be lowered to other backends (uint4 is still represented with uint8 plus a qmin/qmax restriction):

class TestExport(unittest.TestCase):
    @parameterized.expand(
        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
    )
    @run_supported_device_dtype
    def test_export(self, api, test_device, test_dtype):
        if not TORCH_VERSION_AFTER_2_4:
            self.skipTest("aoti compatibility requires 2.4+.")
        logger.info(f"TestExport: {api}, {test_device}, {test_dtype}")
        if test_dtype != torch.bfloat16:
            self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")
        m, k, n = 32, 64, 32

        class test_model(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = nn.Linear(k, n)
                self.relu = nn.ReLU()
                self.lin2 = nn.Linear(n, n)

            def forward(self, x):
                x = self.lin1(x)
                x = self.relu(x)
                x = self.lin2(x)
                return x

        x = torch.randn(m, k, dtype=test_dtype, device=test_device)
        # get float reference
        model = test_model().to(dtype=test_dtype, device=test_device).eval()
        ref_f = model(x)
        api(model)
        # running model
        ref = model(x)
        # make sure it compiles
        example_inputs = (x,)
        from torch._export import capture_pre_autograd_graph
        # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
        # we can re-enable this after non-functional IR is enabled in export
        # model = torch.export.export(model, example_inputs).module()
        model = capture_pre_autograd_graph(model, example_inputs)
        after_export = model(x)
        self.assertTrue(torch.equal(after_export, ref))
        if api is _int8da_int8w_api:
            targets = [n.target for n in model.graph.nodes]
            self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
            self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
Our UInt4 tensor subclass is an initial prototype example and it's not ready for use yet; we are actively working on sub-byte dtypes starting from #468.
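For the export path specifically, here is a condensed sketch of that flow, assuming torchao's int8_weight_only config is used (not the uint4 path, which is still a prototype); the model and shapes are made up for illustration:

import torch
from torch import nn
from torchao.quantization.quant_api import quantize_, int8_weight_only
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32)).eval()
example_inputs = (torch.randn(8, 64),)

quantize_(model, int8_weight_only())   # weights become tensor subclasses
model = unwrap_tensor_subclass(model)  # still required for export / AOTI today
ep = torch.export.export(model, example_inputs)
print(ep)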

@raziel (Contributor) commented Aug 6, 2024

Hi @jerryzh168!

Yeah, I saw pytorch/pytorch#129682. I originally posted there, but since this seemed like an issue more specific to our example, I moved the post here.

Is there a timeline for pytorch/pytorch#129682 @bdhirsh? Based on the flag, it seems you're targeting not needing unwrap_tensor_subclass until 2.6?

Looks like you are trying to use uint4 quantization
Right now we're trying 4-bit quant, and trying the existing torch.ao types, but in general we would like general support for tensor subclasses, since there are other custom types we may need.

Regarding the use of unwrap_tensor_subclass, I believe we tried it but had some issues. @xiaoqiqi177 can provide more details.

Thanks!

@jerryzh168 (Contributor) commented

Oh I see, makes sense. I'm not sure about the timeline yet; last I heard it was a few weeks. cc @tugsbayasgalan, can you give a timeline for pytorch/pytorch#129682?

@xiaoqiqi177 commented Aug 6, 2024

@jerryzh168 Regarding the use of unwrap_tensor_subclass, here are several problems I encountered for my use case:

  1. It seems that only the Linear module is supported. However, I have some non-linear modules that need to be supported as well. Is it safe to lift that constraint in unwrap_tensor_subclass?
  2. Just to make sure, unwrap_tensor_subclass for uint4 quant does not help save the PyTorch model size, right?
    e.g., without quantization, your test model in bf16 from
    def test_export(self, api, test_device, test_dtype):
    has weight p_lin1_weight: "bf16[32, 64]", p_lin2_weight: "bf16[32, 32]" in the exported torch program. However, with quantization, it has weight p_lin1_parametrizations_weight_original0: "i32[4, 8, 32, 4]", p_lin1_parametrizations_weight_original1: "bf16[32, 32, 2]", p_lin2_parametrizations_weight_original0: "i32[4, 8, 32, 4]", p_lin2_parametrizations_weight_original1: "bf16[32, 32, 2]" with group_size=32. Why do we have int32 in the new weight?
    Also, the exported program has quite a lot of long padding pad_1: "bf16[32, 1024]" = torch.ops.aten.pad.default(_to_copy_2, [0, 992]); _to_copy_2 = None. Is it expected?

(Let me know if I need to open a new thread for the details)

Overall, it would be nice to have more general support for tensor subclasses in torch.export, including uint4Tensor like the one in ao, without unwrap_tensor_subclass, such that models could be exported with weights represented in torch.uint4.
With unwrap_tensor_subclass as a workaround, could we extend the support to non-linear modules, and are we able to save model size with 4-bit quantization?

@jerryzh168 (Contributor) commented

  • It seems that only the Linear module is supported. However, I have some non-linear modules that need to be supported as well. Is it safe to lift that constraint in unwrap_tensor_subclass?

Yes, we can definitely support more than linear; please feel free to open a PR. You just need to extend this:

def unwrap_tensor_subclass(model, filter_fn=None):
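As a hypothetical usage sketch, assuming the optional filter_fn follows the (module, fqn) convention used by torchao's _replace_with_custom_fn_if_matches_filter (check the actual contract before relying on it) and that model is an already-quantized module:

import torch.nn as nn
from torchao.utils import unwrap_tensor_subclass

def my_filter(module, fqn):
    # Also unwrap subclass weights on Conv2d modules, not just Linear.
    return isinstance(module, (nn.Linear, nn.Conv2d))

model = unwrap_tensor_subclass(model, filter_fn=my_filter)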

Just to make sure, unwrap_tensor_subclass for uint4 quant does not help save the pytorch model size right

So there is some preliminary work from Intel to pack the int4 weights by default (pytorch/pytorch#129940; we haven't integrated this properly in torchao yet), but we also have intx work going on that packs the lower bits (#468), which will probably land soon; I think this is what we can start with. For export, I'm not sure what kind of format you need; if it's better to have single pack/unpack ops instead of seeing all the bit-shifting operations in the IR, we could use

def _register_custom_op(lib):
to preserve the op during export.
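torchao's _register_custom_op helper isn't shown here; as a generic alternative sketch using the public torch.library.custom_op API (PyTorch 2.4+), with a made-up op name and packing scheme, a custom op like this shows up as a single node in the exported graph instead of the underlying bit shifts:

import torch

@torch.library.custom_op("mylib::pack_uint4", mutates_args=())
def pack_uint4(x: torch.Tensor) -> torch.Tensor:
    # Pack two uint4 values (stored in uint8) into one byte along the last dim.
    assert x.dtype == torch.uint8 and x.shape[-1] % 2 == 0
    return ((x[..., ::2] << 4) | x[..., 1::2]).contiguous()

@pack_uint4.register_fake
def _(x):
    # Shape/dtype propagation so the op can be traced and exported.
    return x.new_empty(*x.shape[:-1], x.shape[-1] // 2)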

Why do we have int32 in the new weight?
The int32 weight is updated to int8 after pytorch/pytorch#129940, I think; it's available in the PyTorch nightly.

Also, the exported program has quite a lot of long padding pad_1: "bf16[32, 1024]" = torch.ops.aten.pad.default(_to_copy_2, [0, 992]); _to_copy_2 = None. Is it expected?

Not exactly sure where this comes from; we'd need to check the source code first.

(Let me know if I need to open a new thread for the details)

Yeah, I feel a new thread might make sense; the issue in this thread has actually been resolved.
