🐛 Bug
I'm looking at generating an int8 quantised PyTorch model (both weights and activations in int8) and exporting it to StableHLO via torch-xla's exported_program_to_stablehlo.
Right now I'm fairly agnostic about how the model is quantised, as long as I end up with a valid graph with int8 weights and activations (and, presumably, i32 accumulation types).
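To be concrete about the accumulation type, the following NumPy sketch (illustrative only, not any particular library's implementation) shows what I mean by int8 operands with i32 accumulation:

import numpy as np

# int8 operands; products are summed in int32 so they cannot overflow.
x_q = np.random.randint(-128, 128, size=(1, 24), dtype=np.int8)   # activations
w_q = np.random.randint(-128, 128, size=(24, 32), dtype=np.int8)  # weights
acc = x_q.astype(np.int32) @ w_q.astype(np.int32)                 # i32 accumulator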
However, there are a few ways to quantise in PyTorch, each with its own caveats and issues. The furthest I've been able to get is the reproducible script below; however, it raises the following error:
File "/app/examples/generate_weenet_mlp_int8.py", line 70, in <module>
stablehlo_program = exported_program_to_stablehlo(exported)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 618, in exported_program_to_stablehlo
bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 370, in _exported_program_to_stablehlo_bundle
raise RuntimeError(message)
RuntimeError: This model contains ops not capturable by Pytorch/XLA: aten::_fused_moving_avg_obs_fq_helper
To Reproduce
import os

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.utils.data import DataLoader, TensorDataset
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

# Ensure CPU-only execution for torch_xla and disable CUDA
os.environ["XLA_USE_BF16"] = "0"
os.environ["XLA_USE_CUDA"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Define a simple neural network model
class WeeNetMLP(nn.Module):
    def __init__(self):
        super(WeeNetMLP, self).__init__()
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(4 * 6, 32)
        self.relu1 = nn.ReLU()
        self.dense2 = nn.Linear(32, 16)
        self.relu2 = nn.ReLU()
        self.dense3 = nn.Linear(16, 8)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.relu1(x)
        x = self.dense2(x)
        x = self.relu2(x)
        x = self.dense3(x)
        x = self.relu3(x)
        return x

# Initialize the model and set to evaluation mode
model = WeeNetMLP().eval()

# Configure fake quantisation using the default QAT configuration
qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig("fbgemm"))

# Define example inputs and input shape
example_inputs = (torch.randn(1, 4, 6),)
input_shape = (1, 4, 6)

# Prepare the model for fake quantisation
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=example_inputs)

# Generate random data for calibration
calibration_data = torch.randn(100, *input_shape)

# Create a dataset and data loader for calibration
calibration_dataset = TensorDataset(calibration_data)
calibration_data_loader = DataLoader(calibration_dataset, batch_size=10)

# Calibrate the model by running each batch through the prepared model
for data in calibration_data_loader:
    prepared_model(data[0])

# After calibration, freeze the observers
prepared_model.apply(torch.ao.quantization.disable_observer)

# Export the prepared model and convert to StableHLO
exported = export(prepared_model, example_inputs)
stablehlo_program = exported_program_to_stablehlo(exported)
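Note that disable_observer only stops observer statistics updates; the fused fake-quant modules themselves stay in the model, which is presumably where the uncapturable op comes from. A quick check (my own diagnostic sketch, not from the torch docs):

from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize

# The default "fbgemm" QAT qconfig inserts fused observer/fake-quant modules,
# whose forward lowers to aten::_fused_moving_avg_obs_fq_helper on export.
fused = [name for name, module in prepared_model.named_modules()
         if isinstance(module, FusedMovingAvgObsFakeQuantize)]
print(fused)  # still non-empty after disable_observer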
Expected behavior
I would expect this to produce a StableHLO graph with int8 tensors in it.
If this can be achieved with a different quantisation method in PyTorch, that also works; the blocker here seems to be the aten::_fused_moving_avg_obs_fq_helper op specifically.
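One alternative I've been meaning to try is the PT2E export quantisation flow, where convert_pt2e replaces observers with explicit quantize/dequantize ops rather than fused fake-quant ops. A minimal sketch, untested end-to-end against this model, assuming the torch 2.4 API names (capture_pre_autograd_graph, XNNPACKQuantizer):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

model = WeeNetMLP().eval()
example_inputs = (torch.randn(1, 4, 6),)

# Capture the model pre-autograd, as the PT2E flow requires in torch 2.4.
captured = capture_pre_autograd_graph(model, example_inputs)

# Symmetric per-tensor int8 quantisation for weights and activations.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(captured, quantizer)

# Calibrate on representative data (random here, as in the repro script).
for _ in range(10):
    prepared(torch.randn(1, 4, 6))

# Replace observers with explicit quantize/dequantize ops.
converted = convert_pt2e(prepared)

# Export and lower to StableHLO.
exported = export(converted, example_inputs)
stablehlo_program = exported_program_to_stablehlo(exported)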
Environment
Reproducible on the XLA CPU and TPU backends
torch_xla version: 2.4.0
For more context: we are looking to expand torch_ao support in the near future; we appreciate you filing the bug and surfacing the use cases and issues observed. @lsy323 will help drive this issue, as mentioned earlier.