
int8 StableHLO export #8373

Open

Wheest opened this issue Nov 12, 2024 · 2 comments

Wheest commented Nov 12, 2024

🐛 Bug

I'm looking at generating an int8-quantised PyTorch model (both weights and activations in int8) and exporting it to StableHLO via torch-xla's exported_program_to_stablehlo.

Right now I'm relatively agnostic about how the model is quantised, as long as I end up with a valid graph with int8 weights and activations (presumably with i32 accumulation types).

There are a few ways to quantise in PyTorch, though, each with its own caveats. The furthest I've been able to get is the reproducible script below, which raises the error:

  File "/app/examples/generate_weenet_mlp_int8.py", line 70, in <module>
    stablehlo_program = exported_program_to_stablehlo(exported)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 618, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 370, in _exported_program_to_stablehlo_bundle
    raise RuntimeError(message)
RuntimeError: This model contains ops not capturable by Pytorch/XLA: aten::_fused_moving_avg_obs_fq_helper
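
From what I can tell, this op comes from the FusedMovingAvgObsFakeQuantize observers that prepare_fx inserts (the default fbgemm QAT qconfig uses them). A speculative, untested workaround would be to fold the observers out with convert_fx before exporting; a sketch, reusing prepared_model, example_inputs, export and exported_program_to_stablehlo from the script below:

# Speculative, untested sketch: convert the fake-quantised FX graph to a
# real int8 graph so the observer op is no longer in the exported program.
from torch.ao.quantization.quantize_fx import convert_fx

quantized_model = convert_fx(prepared_model)
exported = export(quantized_model, example_inputs)
stablehlo_program = exported_program_to_stablehlo(exported)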

To Reproduce

import os
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.utils.data import DataLoader, TensorDataset
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

# Ensure CPU-only execution for torch_xla and disable CUDA
os.environ["XLA_USE_BF16"] = "0"
os.environ["XLA_USE_CUDA"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = ""


# Define a simple neural network model
class WeeNetMLP(nn.Module):
    def __init__(self):
        super(WeeNetMLP, self).__init__()
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(4 * 6, 32)
        self.relu1 = nn.ReLU()
        self.dense2 = nn.Linear(32, 16)
        self.relu2 = nn.ReLU()
        self.dense3 = nn.Linear(16, 8)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.relu1(x)
        x = self.dense2(x)
        x = self.relu2(x)
        x = self.dense3(x)
        x = self.relu3(x)
        return x


# Initialize the model and set to evaluation mode
model = WeeNetMLP().eval()

# Configure fake quantization using QAT configuration
qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig("fbgemm"))

# Define example inputs and input shape
example_inputs = (torch.randn(1, 4, 6),)
input_shape = (1, 4, 6)

# Prepare the model for fake quantization
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=example_inputs)

# Generate random data for calibration
calibration_data = torch.randn(100, *input_shape)

# Create a dataset and data loader for calibration
calibration_dataset = TensorDataset(calibration_data)
calibration_data_loader = DataLoader(calibration_dataset, batch_size=10)

# Calibrate the model with the calibration data
for data in calibration_data_loader:
    prepared_model(data[0])  # Run the prepared model on each batch of calibration data

# After calibration
prepared_model.apply(torch.ao.quantization.disable_observer)

# Export the prepared model and lower it to StableHLO
exported = export(prepared_model, example_inputs)
stablehlo_program = exported_program_to_stablehlo(exported)

Expected behavior

I would expect this to produce a StableHLO graph with int8 tensors in it.

If this can be achieved with a different quantisation method in PyTorch, that also works. The issue here seems to be the aten::_fused_moving_avg_obs_fq_helper op from the error above.
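
For example, here is a rough, untested sketch of the PT2E quantisation flow, which I'd expect to be closer to what the exporter handles. Import paths are for torch 2.4 (where capture_pre_autograd_graph is still the pre-autograd capture entry point), and it reuses WeeNetMLP, example_inputs and calibration_data_loader from the script above:

# Rough, untested sketch of the PT2E quantisation flow on the same model.
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

pt2e_model = WeeNetMLP().eval()

# Capture the model in pre-autograd ATen form.
captured = capture_pre_autograd_graph(pt2e_model, example_inputs)

# Insert observers according to a symmetric int8 config.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(captured, quantizer)

# Calibrate with the same random data as above.
for data in calibration_data_loader:
    prepared(data[0])

# Fold observers into quantize/dequantize ops, then export and lower.
converted = convert_pt2e(prepared)
exported_pt2e = export(converted, example_inputs)
stablehlo_program = exported_program_to_stablehlo(exported_pt2e)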

Environment

  • Reproducible on XLA backend CPU/TPU
  • torch_xla version: 2.4.0
JackCaoG (Collaborator) commented:

@lsy323 is out but he can take a look when he is back.

miladm (Collaborator) commented Nov 18, 2024

More context: we are looking to expand torch_ao support in the near future; we appreciate you filing the bug and surfacing the use cases and issues you observed. @lsy323 will help drive this issue, as mentioned earlier.
