diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index d44867739..e4edd1938 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -410,9 +410,6 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
 
     @classmethod
     def _autoquant_test(cls, act_mat, *args):
-        # if act_mat has batchsize>2 don't use this kernel
-        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>32:
-            return torch.inf
         return super()._autoquant_test(act_mat, *args)
 
 class AQInt8WeightOnlyQuantizedLinearWeight3(AQInt8WeightOnlyQuantizedLinearWeight, AQMixin):
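
For context, a minimal sketch of the guard this hunk removes (the helper name `exceeds_batch_threshold` is hypothetical, not part of torchao): the deleted check flattened the activation to two dimensions and made `_autoquant_test` return `torch.inf` whenever the effective batch size exceeded 32, which excluded this kernel from autoquant's benchmarking for larger batches.

    import torch

    def exceeds_batch_threshold(act_mat: torch.Tensor, threshold: int = 32) -> bool:
        # Flatten all leading dims so the activation is (tokens, features),
        # then compare the effective batch size against the threshold.
        return act_mat.reshape(-1, act_mat.shape[-1]).shape[0] > threshold

    # Example: a (4, 16, 128) activation flattens to 64 rows, so the removed
    # guard would have returned torch.inf and skipped timing this kernel.
    print(exceeds_batch_threshold(torch.randn(4, 16, 128)))  # True

With the guard gone, this weight subclass is benchmarked at every batch size and competes on measured timing alone.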