Skip to content

Commit c012133

Browse files
authored
pytorch/ao/torchao/experimental/ops/mps/test
Differential Revision: D67388057 Pull Request resolved: #1442
1 parent fe3f359 commit c012133

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

torchao/experimental/ops/mps/test/test_quantizer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,17 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Optional
87
import copy
98
import itertools
109
import os
1110
import sys
11+
import unittest
12+
from typing import Optional
1213

1314
import torch
14-
import unittest
1515

1616
from parameterized import parameterized
17-
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
18-
from torchao.experimental.quant_api import _quantize
17+
from torchao.experimental.quant_api import _quantize, UIntxWeightOnlyLinearQuantizer
1918

2019
libname = "libtorchao_ops_mps_aten.dylib"
2120
libpath = os.path.abspath(
@@ -80,7 +79,7 @@ def test_export(self, nbit):
8079
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")
8180

8281
quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
83-
exported = torch.export.export(quantized_model, (activations,))
82+
exported = torch.export.export(quantized_model, (activations,), strict=True)
8483

8584
for node in exported.graph.nodes:
8685
if node.op == "call_function":

0 commit comments

Comments
 (0)