diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 49c07a504985..8b60c02929d6 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -128,6 +128,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
             `lambda(t)`.
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -165,11 +168,16 @@ def __init__(
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -207,6 +215,11 @@ def __init__(
         else:
             raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
 
+        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
+            raise ValueError(
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+            )
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
@@ -267,17 +280,24 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         elif self.config.use_lu_lambdas:
             lambdas = np.flip(log_sigmas.copy())
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas)
         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
@@ -831,7 +851,9 @@ def step(
 
         # Improve numerical stability for small number of steps
         lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
-            self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
+            self.config.euler_at_final
+            or (self.config.lower_order_final and len(self.timesteps) < 15)
+            or self.config.final_sigmas_type == "zero"
         )
         lower_order_second = (
             (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
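With `final_sigmas_type="zero"` the multistep solver now integrates all the way to `sigma = 0`, and the `step()` change above forces a lower-order update on that final step (the log-SNR `lambda(t)` diverges as `sigma` approaches 0, so the higher-order updates are not well defined there). A minimal sketch of the two settings, assuming a diffusers build that includes this change:

```python
from diffusers import DPMSolverMultistepScheduler

# New default: the sampling trajectory ends exactly at sigma = 0.
scheduler = DPMSolverMultistepScheduler(final_sigmas_type="zero")
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.sigmas[-1])  # tensor(0.)

# Previous behavior: end at the smallest sigma of the training schedule.
scheduler = DPMSolverMultistepScheduler(final_sigmas_type="sigma_min")
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.sigmas[-1])  # smallest training sigma, strictly > 0
```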
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index 5d8f3fdf49cd..aea3c4c4b9f6 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -165,6 +165,10 @@ def __init__(
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -783,7 +787,6 @@ def _init_step_index(self, timestep):
 
         self._step_index = step_index
 
-    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
     def step(
         self,
         model_output: torch.FloatTensor,
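The deprecated algorithm types keep working until the scheduled 1.0.0 removal; constructing a scheduler with one of them should only emit a deprecation warning via `deprecate(...)`. A rough sketch of the expected behavior (the warning text comes from the messages added above):

```python
import warnings

from diffusers import DPMSolverMultistepScheduler

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "zero" is rejected for the deprecated types, so pin "sigma_min" here.
    DPMSolverMultistepScheduler(
        algorithm_type="sde-dpmsolver",
        final_sigmas_type="sigma_min",
    )

assert any("deprecated" in str(w.message) for w in caught)
```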
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index dea033822e14..f4ee8b8bdc5d 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
         algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
             `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
             paper, and the `dpmsolver++` type implements the algorithms in the
             [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
@@ -122,6 +122,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -150,9 +153,14 @@ def __init__(
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
     ):
+        if algorithm_type == "dpmsolver":
+            deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose `dpmsolver++` instead"
+            deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -189,6 +197,11 @@ def __init__(
         else:
             raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
 
+        if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
+            raise ValueError(
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+            )
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
@@ -267,11 +280,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
+            )
+        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -285,6 +305,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
             )
             self.register_to_config(lower_order_final=True)
 
+        if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
+            logger.warning(
+                f"`final_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
+            )
+            self.register_to_config(lower_order_final=True)
+
         self.order_list = self.get_order_list(num_inference_steps)
 
         # add an index counter for schedulers that allow duplicated timesteps
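Because the singlestep scheduler only implements the non-SDE solvers, the new `"zero"` default is only valid together with `dpmsolver++`; the guard above rejects the other combination at construction time. A sketch, assuming this diffusers version:

```python
from diffusers import DPMSolverSinglestepScheduler

try:
    # The deprecated `dpmsolver` type cannot end the schedule at sigma = 0.
    DPMSolverSinglestepScheduler(algorithm_type="dpmsolver", final_sigmas_type="zero")
except ValueError as err:
    print(err)  # points the caller to `final_sigmas_type="sigma_min"`

# The deprecated type still constructs with the previous final sigma.
scheduler = DPMSolverSinglestepScheduler(
    algorithm_type="dpmsolver", final_sigmas_type="sigma_min"
)
```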
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 7fe71941b4e7..515cea5bc4ba 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -32,6 +32,7 @@ def get_scheduler_config(self, **kwargs):
             "euler_at_final": False,
             "lambda_min_clipped": -float("inf"),
             "variance_type": None,
+            "final_sigmas_type": "sigma_min",
         }
 
         config.update(**kwargs)
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index b1f14cb530e6..251a1506f39b 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -30,6 +30,7 @@ def get_scheduler_config(self, **kwargs):
             "solver_type": "midpoint",
             "lambda_min_clipped": -float("inf"),
             "variance_type": None,
+            "final_sigmas_type": "sigma_min",
         }
 
         config.update(**kwargs)
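The test configs pin `"final_sigmas_type": "sigma_min"` so the existing numerical expectations keep matching the previous default, while `config.update(**kwargs)` still lets individual tests exercise the new default. A hypothetical test in the style of these suites (the test name and method body are illustrative, not part of the diff):

```python
def test_final_sigmas_type_zero(self):
    # Hypothetical: override the pinned config value to exercise the new default.
    scheduler_class = self.scheduler_classes[0]
    scheduler_config = self.get_scheduler_config(final_sigmas_type="zero")
    scheduler = scheduler_class(**scheduler_config)

    scheduler.set_timesteps(num_inference_steps=10)
    assert scheduler.sigmas[-1] == 0
```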