Skip to content

Commit 4abe4b8

Browse files
[int4-quant] Execute weights shuffling on CPU until MPS memory issue is resolved (pytorch#552)
1 parent 4563492 commit 4abe4b8

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torchao/quantization/utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,13 @@ def groupwise_affine_quantize_tensor_from_qparams(
351351

352352
int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
353353
if TORCH_VERSION_AFTER_2_5:
354+
int_data_device_type = int_data.device.type
355+
# Move to cpu, until issue with MPS memory management of temporary tensors is resolved
356+
if int_data_device_type == 'mps':
357+
int_data = int_data.cpu()
354358
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
359+
if int_data_device_type == 'mps':
360+
int_data = int_data.to(device='mps')
355361
return int_data
356362

357363
def groupwise_affine_dequantize_tensor_from_qparams(

0 commit comments

Comments
 (0)