Skip to content

Commit 69b389a

Browse files
committed
Patch for duplicate timesteps in DPM Solver Single Step, until huggingface/diffusers#4231 gets resolved
1 parent 83d03d1 commit 69b389a

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

sdkit/models/model_loader/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def load_diffusers_model(context: Context, model_path, config_file_path, convert
128128
from sdkit.utils import gc, has_amd_gpu
129129

130130
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
131+
from . import diffusers_bugfixes
131132

132133
log.info("loading on diffusers")
133134

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from diffusers.schedulers import DPMSolverSinglestepScheduler
2+
import numpy as np
3+
import torch
4+
5+
# temporary patch, while waiting for PR: https://github.com/huggingface/diffusers/pull/4231
6+
7+
old_set_timesteps = DPMSolverSinglestepScheduler.set_timesteps
8+
9+
10+
def set_timesteps_remove_duplicates(self, num_inference_steps: int, device=None):
11+
old_set_timesteps(self, num_inference_steps, device)
12+
13+
timesteps = self.timesteps.cpu().detach().numpy().astype(np.int64)
14+
15+
# when num_inference_steps == num_train_timesteps, we can end up with
16+
# duplicates in timesteps.
17+
_, unique_indices = np.unique(timesteps, return_index=True)
18+
timesteps = timesteps[np.sort(unique_indices)]
19+
20+
self.timesteps = torch.from_numpy(timesteps).to(device)
21+
22+
self.num_inference_steps = len(timesteps)
23+
24+
self.order_list = self.get_order_list(self.num_inference_steps)
25+
26+
27+
DPMSolverSinglestepScheduler.set_timesteps = set_timesteps_remove_duplicates

tests/samplers/test_samplers_txt2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def sampler_test():
3939
width=512,
4040
height=512,
4141
sampler_name=sampler_name,
42-
num_inference_steps=25,
42+
num_inference_steps=50,
4343
)[0]
4444

4545
expected_image = Image.open(f"{EXPECTED_DIR}/1.4-txt-{sampler_name}-42-512x512-50-cuda.png")

0 commit comments

Comments
 (0)