Hot-swapping model weights of torch.compiled or traced model #7608
Sorry for the bad experience, let's see how I can help. We never really supported …. I am trying to understand your use case: is this inference or training? I feel like parameter swapping should just work if you use the default lazy tensor, which is what we recommend for training for now. If you are using inference, then …
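For context, the "default lazy tensor" path mentioned above works roughly like this; a minimal sketch (the `Linear` module is a stand-in for the real model, which is not shown in the thread):

```python
import torch
import torch_xla.core.xla_model as xm

# Lazy-tensor mode: ops are recorded into a graph and only compiled and
# executed when the graph is cut, e.g. by xm.mark_step().
device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)   # stand-in for the real model
x = torch.randn(1, 16).to(device)

with torch.no_grad():
    out = model(x)    # builds the lazy graph; nothing runs yet
xm.mark_step()        # materializes the computation on the XLA device
print(out.cpu())
```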
Thanks for the quick reply! We have a LoRA for every user, which we apply during inference requests. We serve realtime inference, so we are trying to keep request times as low as possible; that's why we merge the LoRA into the model weights (seems to be about 25% faster compared to keeping them in separate layers).
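For reference, merging a LoRA pair into the base weight is a rank-r update applied in place; a minimal sketch, with illustrative names `lora_A`, `lora_B`, `alpha`, `r` rather than the poster's actual code:

```python
import torch

@torch.no_grad()
def merge_lora(linear: torch.nn.Linear, lora_A: torch.Tensor,
               lora_B: torch.Tensor, alpha: float, r: int) -> None:
    # W' = W + (alpha / r) * (B @ A), with lora_B: (out, r), lora_A: (r, in).
    # Written in place so the existing Parameter storage is reused
    # rather than replaced by a new tensor.
    linear.weight.add_(lora_B @ lora_A, alpha=alpha / r)
```

Doing the merge in place is what later makes the hot-swap cheap: the layer keeps the same Parameter object and only its values change.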
OK, so you are trying to combine the LoRA weight and the model weight and then swap the model parameters. I feel like the fastest path is to fix …. Are you using …?
I tried both ….

Running into: … when using the nightly build (both torch 2.4 and torch 2.5) with …
XLA only supports the contiguous memory format; do you know which memory format the code is trying to use and whether that's expected?
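A quick way to check what layout a tensor is actually using; a generic sketch, not specific to the model in question:

```python
import torch

t = torch.randn(1, 4, 8, 8).to(memory_format=torch.channels_last)
print(t.is_contiguous())                                   # False: NHWC strides
print(t.is_contiguous(memory_format=torch.channels_last))  # True
t = t.contiguous()   # convert back to the default (contiguous) layout XLA expects
```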
Just loading a UNet from diffusers, not sure what it uses internally, but this works on the non-dev build without issues 🤔
So I am able to run … with nightly …; not sure which checkpoint you are using.
@BitPhinix btw I think the failure you ran into was something new; we started seeing it on today's CI. If you use the 07/01 nightly it will likely work. We will see how to fix this new regression.
Yeah, tried with your versions and it seems to load the model fine; let's hope everything else works as well 🙏
Seems to work fine, thanks again for your help @JackCaoG 🙏

TLDR for anyone who stumbles upon this: use …. Modify the params in place of ….
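A sketch of what the working setup plausibly looks like, assuming the `openxla` dynamo backend and a merged base+LoRA state dict (the module and weight values here are stand-ins, not the poster's code):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)   # stand-in for the real model
compiled = torch.compile(model, backend="openxla")
x = torch.randn(1, 16).to(device)

@torch.no_grad()
def hot_swap(module: torch.nn.Module, new_weights: dict) -> None:
    # Copy into the existing Parameter storage instead of assigning new
    # Parameter objects; the compiled graph keeps referencing the same
    # tensors, so no retrace/recompile is triggered.
    for name, param in module.named_parameters():
        param.copy_(new_weights[name].to(param.device))

# e.g. a freshly merged base+LoRA state dict (random stand-in values here)
merged = {k: torch.randn_like(v) for k, v in model.state_dict().items()}
hot_swap(model, merged)
out = compiled(x)   # runs with the swapped weights
```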
I've been trying to get module weight hot-swapping working for almost two days now. Every approach I found was blocked by some bug in torch XLA or PyTorch itself.
Here is what I tried:

1. `torch.func.functional_call`, blocked by "[dynamo][inline inbuilt nn modules] `torch.func.functional_call` doesn't work with compiled models" (pytorch#97909); see the sketch after this list
2. … `torch.func.functional_call` inside it. The same issue as 1, didn't do anything
3. Mutating the model params directly: `Trying to access XLA data while an async operation is in flight: UNKNOWN_SCALAR[]`
4. Replacing `torch.compile` with `torch.jit.trace`: …
5. `torch.jit.script`: couldn't get `torch.jit.script` to work + I'm not sure it would actually end up being close in perf to `torch.compile`
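For reference, the `torch.func.functional_call` pattern from items 1–2 looks roughly like this (a sketch; compiling exactly this kind of wrapper is what the linked dynamo issue blocks):

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(16, 16)   # stand-in for the real model
x = torch.randn(1, 16)

def forward_with(weights: dict, x: torch.Tensor) -> torch.Tensor:
    # Run `model` with `weights` substituted for its registered
    # parameters/buffers, without mutating the module itself.
    return functional_call(model, weights, (x,))

out = forward_with(dict(model.named_parameters()), x)
compiled = torch.compile(forward_with)   # this combination is what breaks
```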
I'm out of ideas. It's annoying because this kind of model weight swapping should be possible without significant hassle; it works out of the box with Jax / Flax, as the model weights are just parameters. I highly appreciate the work on torch xla; this just ended up being quite frustrating as each possible path was blocked by one issue or another. Any pointers on what else I could try? I need to compile/trace the model since we use it for inference, and otherwise it's too slow.
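For comparison, the Flax pattern the poster alludes to, where the weights are plain inputs to the jitted function; a generic sketch:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(16)   # stand-in for the real model
x = jnp.ones((1, 16))
params = model.init(jax.random.PRNGKey(0), x)["params"]

@jax.jit
def forward(params, x):
    # Weights are ordinary pytree inputs: swapping them means passing a
    # different pytree, with no retracing as long as shapes/dtypes match.
    return model.apply({"params": params}, x)

out = forward(params, x)   # pass per-user merged params here instead
```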