Pardon the naive question; I'm trying to understand how to implement a basic tensor subclass.
The problem I'm encountering is that the tensor subclass loses its attributes after a state dict containing it is saved with `torch.save` and loaded back, likely due to the use of `swap_tensors`.
Running my minimal repro prints the following error when loading the state dict containing `SimpleTensor` with `weights_only=True`, even after registering `SimpleTensor` as safe (`torch.serialization.add_safe_globals([SimpleTensor])`):
State load error: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 48
If I set `weights_only=False`, the tensor in the loaded state dict comes back as a `SimpleTensor`, but accessing it gives the following error:
Restored tensor error: 'SimpleTensor' object has no attribute 'inner_tensor'
`NF4Tensor`, on the other hand, saves and loads just fine.
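For comparison, here is a minimal sketch of the simplest case: a plain `torch.Tensor` subclass (not a wrapper subclass) whose extra attribute lives in the instance `__dict__`. `TaggedTensor` and `tag` are hypothetical names, not the `SimpleTensor` from the repro, and the sketch only exercises `weights_only=False`. As far as I understand, `Tensor.__reduce_ex__` pickles the instance `__dict__` as state and restores it with `setattr` on load, so the attribute should survive here; a wrapper subclass holding an `inner_tensor` may take a different serialization path.

```python
import io

import torch


class TaggedTensor(torch.Tensor):
    """Hypothetical plain subclass: the data lives in the tensor's own storage,
    and any extra metadata is stored in the instance __dict__."""


def roundtrip(obj):
    # Save to an in-memory buffer and load it back (trusted data, so weights_only=False).
    buf = io.BytesIO()
    torch.save(obj, buf)
    buf.seek(0)
    return torch.load(buf, weights_only=False)


t = torch.arange(4.0).as_subclass(TaggedTensor)
t.tag = "scale=0.5"  # ends up in t.__dict__, which is pickled as the object's state

loaded = roundtrip({"w": t})["w"]
print(type(loaded).__name__, loaded.tag, loaded.tolist())
```

If this baseline round-trips but the wrapper subclass does not, the difference likely lies in how the wrapper's inner tensors are stored and restored rather than in `torch.save` itself.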
Are there particular ops that need to be implemented in order to serialize a subclass?
I think the issue arises from the use of `swap_tensors`, which I've seen used in torchtune here and mentioned as necessary when loading subclasses that wrap multiple tensors.
This is with torch 2.6.
Thanks!