An error was encountered when setting torch._dynamo.decorators.mark_unbacked #1790

Open
songh11 opened this issue Feb 27, 2025 · 2 comments
Labels: quantize, question (Further information is requested), triaged

@songh11

songh11 commented Feb 27, 2025

Hello, I want the batch dimension to be dynamic, so I set it with torch._dynamo.mark_dynamic. But I found that a recompile is triggered between batch sizes 1 and 2. I then tried torch._dynamo.decorators.mark_unbacked, but with it quantization fails. Could you take a look at this problem?

My environment:
torch: 2.5.0
torchao: 0.8.0

Here is a minimal reproduction:

import torch


from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight
)
torch._logging.set_logs(recompiles=True, recompiles_verbose=True)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 256)

    def forward(self, x):
        return self.linear(x)

model = MyModel().cuda().eval()
model = torch.compile(model, fullgraph=True)

# quant
quantize_(model, int8_dynamic_activation_int8_weight())

example_input = torch.randn(2, 64, 128).cuda()
torch._dynamo.decorators.mark_unbacked(example_input, 0)
torch._dynamo.mark_dynamic(example_input, 0)
model(example_input)

x1 = torch.randn(1, 64, 128).cuda()
x2 = torch.randn(2, 64, 128).cuda()

print("input shape: ", x1.shape)
model(x1)
print("input shape: ", x2.shape)
model(x2)

Here is the error log:

W0227 10:58:38.277000 1279033 torch/fx/experimental/symbolic_shapes.py:5124] [0/0] failed during evaluate_expr(Ne(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0227 10:58:38.277000 1279033 torch/fx/experimental/recording.py:298] [0/0] failed while running evaluate_expr(*(Ne(u0, 1), None), **{'fx_node': False})
Traceback (most recent call last):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
return node.target(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/utils.py", line 433, in _dispatch__torch_function__
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/utils.py", line 412, in wrapper
return func(f, types, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 126, in _
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 83, in _quantized_linear_op
quantized_tensor = input_quant_func(input_tensor, **quant_kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 800, in _int8_symm_per_token_reduced_range_quant
return to_affine_quantized_intx(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 250, in from_hp_to_intx
scale, zero_point = choose_qparams_affine(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 738, in choose_qparams_affine
return _choose_qparams_affine(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 840, in _choose_qparams_affine
shape_for_reduction, reduction_dims = _get_reduction_params(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/__init__.py", line 680, in __bool__
return self.node.bool_()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 511, in bool_
return self.guard_bool("", 0)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
return retlog(fn(*args, **kwargs))
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5122, in evaluate_expr
return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5238, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)). (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 16, in forward
return self.linear(x)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2017, in get_fake_value
ret_val = wrap_fake_exception(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1574, in wrap_fake_exception
return fn()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2018, in
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2150, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/dynamo/utils.py", line 2132, in run_node
return node.target(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/utils.py", line 433, in dispatch__torch_function

return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/utils.py", line 412, in wrapper
return func(f, types, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 126, in _
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 83, in _quantized_linear_op
quantized_tensor = input_quant_func(input_tensor, **quant_kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 800, in _int8_symm_per_token_reduced_range_quant
return to_affine_quantized_intx(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 250, in from_hp_to_intx
scale, zero_point = choose_qparams_affine(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 738, in choose_qparams_affine
return _choose_qparams_affine(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_ops.py", line 1116, in call
return self._op(*args, **(kwargs or {}))
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 840, in _choose_qparams_affine
shape_for_reduction, reduction_dims = _get_reduction_params(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/init.py", line 680, in bool
return self.node.bool
()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 511, in bool

return self.guard_bool("", 0)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
return retlog(fn(*args, **kwargs))
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5122, in evaluate_expr
return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5238, in _evaluate_expr
raise self._make_data_dependent_error(
RuntimeError: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(u0, 64, 128)), LinearActivationQuantizedTensor(AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., device='cuda:0', size=(256, 128), dtype=torch.int8)... , scale=FakeTensor(..., device='cuda:0', size=(256,))... , zero_point=FakeTensor(..., device='cuda:0', size=(256,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 128), shape=torch.Size([256, 128]), device=cuda:0, dtype=torch.float32, requires_grad=False), <function _int8_symm_per_token_reduced_range_quant at 0x7fa631feac20>, quant_kwargs={})), Parameter(FakeTensor(..., device='cuda:0', size=(256,), requires_grad=True))), **{}):
Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)). (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 16, in forward
return self.linear(x)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 27, in
model(example_input)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in call
return self._torchdynamo_orig_callable(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in call
return _compile(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
out_code = transform_code_object(code, transform)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
tracer.run()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
super().run()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
return inner_fn(self, inst)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
return variables.UserFunctionVariable(fn, source=source).call_function(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
return super().call_function(tx, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/dynamo/symbolic_convert.py", line 3011, in inline_call
return cls.inline_call
(parent, func, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/dynamo/symbolic_convert.py", line 3139, in inline_call
tracer.run()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
return inner_fn(self, inst)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 897, in call_function
tensor_variable = wrap_fx_proxy(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2037, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2124, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2072, in get_fake_value
raise UserError( # noqa: B904
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)). (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 16, in forward
return self.linear(x)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 16, in forward
return self.linear(x)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
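
For reference: the per-token activation quantization here uses a block_size of (1, 1, 128) for the (u0, 64, 128) input, so the check on quant_primitives.py line 229 compares block_size[0] != input_size[0], i.e. Ne(1, u0), which is exactly the data-dependent expression the guard fails on. The guard_size_oblivious change the log suggests would look roughly like this (my hedged sketch of the suggested fix, not the shipped torchao code; the rest of the function is elided):

from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

def _get_reduction_params(block_size, input_size):
    # Sketch: only the guard on line 229 changes; everything else is elided.
    shape_for_reduction = []
    reduction_dims = []
    for i in range(len(block_size)):
        # The original `if block_size[i] != input_size[i] and block_size[i] > 1:`
        # forces a guard on Ne(1, u0) when input_size[i] is unbacked.
        # guard_size_oblivious evaluates the comparison size-obliviously,
        # treating u0 as a general (large) size, so it resolves to True
        # without guarding, as the error message recommends.
        if guard_size_oblivious(block_size[i] != input_size[i]) and block_size[i] > 1:
            ...  # reduction-shape bookkeeping, unchanged
    return shape_for_reduction, reduction_dims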

@jerryzh168
Contributor

Thanks for trying this out with dynamic shapes; I don't think we have really tested it.

First, we apply torch.compile after the torchao quantization, not before.
Second, mark_dynamic is what's mentioned in https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html, so I think we can start with that.

The following code works for me; can you run it

import torch

from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight
)
torch._logging.set_logs(recompiles=True, recompiles_verbose=True)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 256)

    def forward(self, x):
        return self.linear(x)

model = MyModel().cuda().eval()
# model = torch.compile(model, fullgraph=True)

# quant
quantize_(model, int8_dynamic_activation_int8_weight())
model = torch.compile(model, fullgraph=True)

example_input = torch.randn(2, 64, 128).cuda()
# torch._dynamo.decorators.mark_unbacked(example_input, 0)
torch._dynamo.mark_dynamic(example_input, 0)
model(example_input)

x1 = torch.randn(1, 64, 128).cuda()
x2 = torch.randn(2, 64, 128).cuda()

print("input shape: ", x1.shape)
model(x1)
print("input shape: ", x2.shape)
model(x2)

to see if this is what you are looking for?
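
As an aside on why batch sizes 1 and 2 recompile under mark_dynamic: torch.compile specializes sizes 0 and 1 by default (the "0/1 specialization" described in the dynamic shapes doc linked above), so a graph traced with a symbolic batch dimension still guards on batch != 1, and a batch-1 input triggers one extra compile. A minimal standalone sketch of that behavior (my illustration, independent of torchao):

import torch

torch._logging.set_logs(recompiles=True)

@torch.compile(fullgraph=True)
def f(x):
    return x * 2

x = torch.randn(2, 8)
torch._dynamo.mark_dynamic(x, 0)
f(x)                  # compiles once with a symbolic batch dimension
f(torch.randn(5, 8))  # reuses the dynamic graph, no recompile
f(torch.randn(1, 8))  # recompiles: size 1 hits 0/1 specialization

mark_unbacked avoids the 0/1 specialization by allocating an unbacked symbol, which is why it was worth trying here, but as the traceback shows it requires the framework code it hits to be written size-obliviously.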

jerryzh168 added the quantize, question (Further information is requested), and triaged labels on Feb 28, 2025
@songh11
Author

songh11 commented Mar 3, 2025

Thanks for your reply ❤️. I ran this code, but I still get the same error. May I ask which torch and torchao versions you are using?
