
[Lora] Speed up lora loading #4994

Merged
merged 13 commits into main from speed_up_loading
Sep 12, 2023

Conversation

patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Sep 12, 2023

This PR refactors LoRA loading a bit (removing some boilerplate code) and speeds up the loading process by being smarter about device and dtype placement and by adding low_cpu_mem_usage support.

The following should be sped up by at least a factor of 2.

import time

import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.to("cuda")

# pipe.load_lora_weights("stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors", low_cpu_mem_usage=True)
# file = hf_hub_download("TheLastBen/Papercut_SDXL", filename="papercut.safetensors")
file = hf_hub_download("hf-internal-testing/sdxl-0.9-daiton-lora", filename="daiton-xl-lora-test.safetensors")
state_dict = load_file(file)
# Move the LoRA tensors to the target device and dtype up front so that
# load_lora_weights does not have to cast or copy them again.
state_dict = {k: v.to(device="cuda", dtype=torch.float16) for k, v in state_dict.items() if torch.is_tensor(v)}

# Time only the LoRA loading step itself.
start_time = time.time()
pipe.load_lora_weights(state_dict, low_cpu_mem_usage=True)
print(time.time() - start_time)
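
For context, the low_cpu_mem_usage path avoids materializing randomly initialized weights that would immediately be overwritten by the checkpoint. A minimal sketch of the idea, using a plain torch.nn.Linear as a stand-in for a LoRA layer (hypothetical illustration, not the actual diffusers implementation):

import torch

# 1. Construct the module on the "meta" device: no memory is allocated and no
#    random initialization runs, so construction is essentially free.
with torch.device("meta"):
    lora_layer = torch.nn.Linear(1024, 1024, bias=False)

# 2. Materialize the parameter directly from a checkpoint tensor that is
#    already on the target device/dtype, skipping the usual
#    CPU-init-then-copy round trip.
checkpoint_weight = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)  # stand-in for a loaded tensor
lora_layer.weight = torch.nn.Parameter(checkpoint_weight)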

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Sep 12, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten patrickvonplaten changed the title speed up lora loading [Lora] Speed up lora loading Sep 12, 2023
@patrickvonplaten patrickvonplaten merged commit 37cb819 into main Sep 12, 2023
@patrickvonplaten patrickvonplaten deleted the speed_up_loading branch September 12, 2023 15:51
@patrickvonplaten
Contributor Author

Testing scripts here:

1.) Pure loading time:

import time

import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.to("cuda")

# pipe.load_lora_weights("stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors", low_cpu_mem_usage=True)
# file = hf_hub_download("TheLastBen/Papercut_SDXL", filename="papercut.safetensors")
file = hf_hub_download("hf-internal-testing/sdxl-0.9-daiton-lora", filename="daiton-xl-lora-test.safetensors")
state_dict = load_file(file)
state_dict = {k: v.to(device="cuda", dtype=torch.float16) for k, v in state_dict.items() if torch.is_tensor(v)}

# Time only the LoRA loading step itself.
start_time = time.time()
pipe.load_lora_weights(state_dict, low_cpu_mem_usage=True)
print(time.time() - start_time)

2.) LoRA fusing/unfusing:

import time

import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.to("cuda")

# pipe.load_lora_weights("stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors", low_cpu_mem_usage=True)
# file = hf_hub_download("TheLastBen/Papercut_SDXL", filename="papercut.safetensors")
file = hf_hub_download("hf-internal-testing/sdxl-0.9-daiton-lora", filename="daiton-xl-lora-test.safetensors")
state_dict = load_file(file)
state_dict = {k: v.to(device="cuda", dtype=torch.float16) for k, v in state_dict.items() if torch.is_tensor(v)}
pipe.load_lora_weights(state_dict, low_cpu_mem_usage=True)

# Time fusing the LoRA weights into the base model...
start_time = time.time()
pipe.fuse_lora()
print(time.time() - start_time)
# ...and unfusing them again.
start_time = time.time()
pipe.unfuse_lora()
print(time.time() - start_time)
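
For intuition, fusing merges the low-rank update directly into the base weight so inference pays no extra matmul, and unfusing subtracts the same delta to restore the original weight (up to floating-point rounding). A minimal sketch with hypothetical helpers, not the actual diffusers implementation:

import torch

def fuse(base_weight, lora_down, lora_up, scale=1.0):
    # W_fused = W + scale * (up @ down), with down of shape (r, in) and up of shape (out, r)
    return base_weight + scale * (lora_up @ lora_down)

def unfuse(fused_weight, lora_down, lora_up, scale=1.0):
    # Subtracting the same low-rank delta recovers the original weight.
    return fused_weight - scale * (lora_up @ lora_down)

W = torch.randn(8, 8)
down, up = torch.randn(4, 8), torch.randn(8, 4)
fused = fuse(W, down, up, scale=0.5)
assert torch.allclose(unfuse(fused, down, up, scale=0.5), W, atol=1e-5)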

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* speed up lora loading

* Apply suggestions from code review

* up

* up

* Fix more

* Correct more

* Apply suggestions from code review

* up

* Fix more

* Fix more -

* up

* up
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024