
Hot-swapping model weights of torch.compiled or traced model #7608

Closed
BitPhinix opened this issue Jul 2, 2024 · 11 comments

@BitPhinix (Contributor) commented Jul 2, 2024

I've been trying to get module weight hot-swapping working for almost two days now. Every approach I found was blocked by some bug in torch_xla or PyTorch itself.

Here is what I tried:

  • torch.func.functional_call
  • Mutating the model params directly
  • Replacing torch.compile with torch.jit.trace
  • torch.jit.script: not an option, since the module I'm working with is quite big and not at all compatible with the baseline requirements for torch.jit.script, plus I'm not sure it would actually end up being close in perf to torch.compile

I'm out of ideas. It's annoying because this kind of model weight swapping should be possible without significant hassle; it works out of the box with JAX/Flax, where model weights are just parameters. I really appreciate the work on torch_xla; this just ended up being quite frustrating, as each possible path was blocked by one issue or another. Any pointers to what else I could try? I need to compile/trace the model since we use it for inference and it's too slow otherwise.
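
For reference, here is a minimal sketch of the kind of hot swap being attempted, using the torch.func.functional_call approach from the list above on a toy nn.Linear (plain eager PyTorch accepts this; the trouble here is specific to compiled/traced models on XLA):

import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Linear(4, 4)  # toy stand-in for the real module

# Fresh weights to hot-swap in without rebuilding the module.
new_params = {name: torch.randn_like(p) for name, p in model.named_parameters()}

x = torch.randn(2, 4)
# Run the module with the supplied parameters instead of the registered ones.
out = functional_call(model, new_params, (x,))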

@JackCaoG (Collaborator) commented Jul 2, 2024

Sorry for the bad experience; let's see how I can help.

We never really supported torch.jit.trace, which is why you hit a lot of blockers on that path. The torch.compile issue seems to be related to nn modules currently being allowed in graph, so dynamo will not trace into them; AFAIK the PyTorch team is trying to change that (dynamo will try to inline into nn modules).

I am trying to understand your use case: is this inference or training? I feel like parameter swapping should just work if you use the default lazy tensor, which is what we recommend for training for now. If this is inference, then torch.compile is probably the way to go, and I can see what I can do there.

@BitPhinix (Contributor Author)

Thanks for the quick reply! We have a LoRA for every user which we apply during inference requests. We do real-time inference, so we are trying to keep request times as low as possible; that's why we merge the LoRA into the model weights (it seems to be about 25% faster than keeping them in separate layers).
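
For context, a minimal sketch of that merge step, assuming the standard LoRA parametrization (shapes and names here are illustrative, not the actual model):

import torch

def merge_lora(weight, lora_A, lora_B, scale):
    # Standard LoRA update: W' = W + scale * (B @ A),
    # with A of shape (r, in_features) and B of shape (out_features, r).
    return weight + scale * (lora_B @ lora_A)

W = torch.randn(8, 16)  # base linear weight (out_features=8, in_features=16)
A = torch.randn(4, 16)  # rank r = 4
B = torch.randn(8, 4)
merged = merge_lora(W, A, B, scale=1.0)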

@JackCaoG (Collaborator) commented Jul 2, 2024

OK, so you are trying to combine the LoRA weights with the model weights and then swap the model parameters. I feel like the fastest path is to fix mutating the model params directly.

Are you using the openxla_eval backend or openxla? You can try both. If you have a repro of the issue, I can see if I can debug it. If the issue is in pytorch/xla, I can fix it. I am just a bit worried that dynamo will realize that the Python object ID of the model parameters changed after swapping and will recompile (dynamo is very paranoid, haha).
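
For reference, a minimal sketch of picking either of the two dynamo backends mentioned here, with a toy nn.Linear standing in for the real model (assumes torch_xla is installed and an XLA device is available):

import torch
import torch.nn as nn
import torch_xla

device = torch_xla.device()
model = nn.Linear(8, 8).to(device)

compiled_eval = torch.compile(model, backend="openxla_eval")  # one of the two backends discussed
compiled = torch.compile(model, backend="openxla")            # the one the thread ends up recommending

x = torch.randn(1, 8, device=device)
out = compiled(x)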

@BitPhinix (Contributor Author)

I tried both openxla_eval and openxla as the backend, ran into #7138 for the openxla one, and gave up. Trying again with the nightly build right now 🙏

@BitPhinix (Contributor Author)

Running into:

[rank0]: NotImplementedError: Could not run 'aten::empty.memory_format' with arguments from the 'XLA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::empty.memory_format' is only available for these backends: [CPU, MPS, Meta, QuantizedCPU, QuantizedMeta, MkldnnCPU, SparseCPU, SparseMeta, SparseCsrCPU, SparseCsrMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastXLA, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

when using the nightly build (both torch 2.4 and torch 2.5) with 2.4.0+gite7a8700.

@JackCaoG (Collaborator) commented Jul 2, 2024

XLA only supports the contiguous memory format. Do you know what memory format the code is trying to use, and whether that's expected?
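
For reference, checking which memory format a tensor uses looks like this (the channels_last tensor here is just illustrative; as it turns out below, the failure was a nightly regression rather than a memory-format problem):

import torch

t = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
print(t.is_contiguous())                                   # False
print(t.is_contiguous(memory_format=torch.channels_last))  # True
t = t.contiguous()  # back to the default contiguous layout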

@BitPhinix (Contributor Author) commented Jul 2, 2024

import torch
from diffusers import UNet2DConditionModel

# checkpoint (the diffusers checkpoint id) and device (the XLA device) are defined elsewhere.
components = {
    "unet": UNet2DConditionModel.from_pretrained(
        checkpoint, subfolder="unet", use_safetensors=True, torch_dtype=torch.bfloat16
    ),
}

components['unet'].to(device)  # device is xla, it fails here

Just loading a UNet from diffusers; I'm not sure what this uses internally, but it works on the non-dev build without issues 🤔

@JackCaoG (Collaborator) commented Jul 2, 2024

so I am able to run

import torch
import torch_xla
from diffusers import UNet2DConditionModel

checkpoint = "runwayml/stable-diffusion-v1-5"
device = torch_xla.device()

components = {
    "unet":
        UNet2DConditionModel.from_pretrained(
            checkpoint,
            subfolder="unet",
            use_safetensors=True,
            torch_dtype=torch.bfloat16),
}

components['unet'].to(device)  # device is xla, it fails here

with nightly

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

I'm not sure which checkpoint you are using.

@JackCaoG (Collaborator) commented Jul 2, 2024

@BitPhinix btw, I think the failure you ran into is something new; we started seeing it on today's CI. If you use the 07/01 nightly, it will likely work. We will see how to fix this new regression.

@BitPhinix (Contributor Author)

Yeah, I tried with your versions and it seems to load the model fine; let's hope everything else goes well too 🙏

@BitPhinix (Contributor Author) commented Jul 4, 2024

Seems to work fine, thanks again for your help @JackCaoG 🙏

TL;DR for anyone who stumbles upon this:

  • Use openxla as the backend, not openxla_eval, and use the nightly builds mentioned above.
  • Modify the params of model._orig_mod on the compiled model in place. Be sure to modify the params themselves, not the data inside the params.
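
A minimal sketch of that recipe, with a toy nn.Linear standing in for the real model and random tensors standing in for the merged LoRA weights:

import torch
import torch.nn as nn
import torch_xla

device = torch_xla.device()
model = nn.Linear(16, 16).to(device)                # toy stand-in for the real model
compiled = torch.compile(model, backend="openxla")  # not openxla_eval

x = torch.randn(1, 16, device=device)
out = compiled(x)  # first call compiles

# Later, when new (e.g. LoRA-merged) weights are ready on the same device:
new_state = {name: torch.randn_like(p) for name, p in model.named_parameters()}

with torch.no_grad():
    # Swap on the original module behind the compiled wrapper, copying into the
    # Parameter itself (not param.data).
    for name, param in compiled._orig_mod.named_parameters():
        param.copy_(new_state[name])

out = compiled(x)  # runs with the swapped weights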
