[NF4] Various bugs in how NF4 handles `.to()` to move to a different device #1310
Labels: bug (Something isn't working)
Reproduction
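A minimal sketch that should trigger the behaviors summarized below, assuming `to_nf4` and `NF4Tensor.get_original_weight()` from `torchao.dtypes.nf4tensor` (run each numbered step on its own, since every one of them misbehaves):

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4

x_nf4 = to_nf4(torch.randn(1024, 1024, device="cuda"))

# (1) .cuda() dequantizes instead of returning an NF4Tensor on CUDA
print(type(x_nf4.cuda()))

# (2) .to(device="cuda") raises IndexError, because the handler
#     unconditionally reads args[1]
x_nf4.to(device="cuda")

# (3) .cpu() does not move the .nf4 lookup table, so dequantizing
#     afterwards fails with a device mismatch
x_nf4.cpu().get_original_weight()
```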
Summary:

- `NF4Tensor.cuda()` will dequantize -> this is unwanted
- `NF4Tensor.to(device="cuda")` will raise `IndexError`, since `args[1]` does not exist
- `NF4Tensor.cpu()` does not move the `.nf4` attribute -> the result cannot be dequantized
- `nn.Module.to(device)` moves parameters via `Tensor.to`, so modules with NF4 weights run into the same problems
That `NF4Tensor.to(torch.float32)` dequantizes is the culprit behind these troubles, and it is not consistent with AQT behavior. If `.to(dtype)` did not dequantize (only changed the apparent dtype), we would only need to implement `aten._to_copy` instead of `Tensor.cpu`, `Tensor.to`, and a myriad of others. Though I understand this design is meant to make NF4 feel more like a true dtype.
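As an illustration of that alternative (a toy sketch, not torchao's actual implementation): for a `__torch_dispatch__` wrapper subclass, `.cpu()`, `.cuda()`, and `.to(device=...)` all decompose to `aten._to_copy` before reaching the subclass, so a single handler covers every device move:

```python
import torch

class Wrapper(torch.Tensor):
    """Toy subclass: one _to_copy handler serves .cpu()/.cuda()/.to()."""

    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self.inner = inner  # stand-in for NF4's quantized payload and tables

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten._to_copy.default:
            # Move the inner payload; keep the wrapper (no dequantization).
            inner = args[0].inner.to(
                device=kwargs.get("device", args[0].inner.device),
                dtype=kwargs.get("dtype", args[0].inner.dtype),
            )
            return cls(inner)
        raise NotImplementedError(f"{func} is not supported")

if torch.cuda.is_available():
    w = Wrapper(torch.randn(4))
    y = w.cuda()  # routed through aten._to_copy, stays a Wrapper
    print(type(y), y.inner.device)
```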
I would also make `NF4Tensor.dequantize()` the method that dequantizes the tensor (consistent with plain `Tensor` behavior, though plain `Tensor.dequantize()` always returns FP32), instead of the current situation, where `NF4Tensor.dequantize()` is a static method for the lookup table while `NF4Tensor.get_original_weight()` does the dequantization.
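A sketch of what that renaming would look like at the call site (both method names come from the issue; the snippet itself is illustrative):

```python
# Today: dequantization lives in get_original_weight(), while
# NF4Tensor.dequantize is a static lookup-table helper.
w_fp = weight_nf4.get_original_weight()

# Proposed: mirror plain torch.Tensor, where .dequantize() is the
# instance method that returns the high-precision tensor.
w_fp = weight_nf4.dequantize()
```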