[Schedulers] Add exponential sigmas / exponential noise schedule #9499

Merged (4 commits, Sep 23, 2024)
36 changes: 36 additions & 0 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -158,6 +158,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
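For context, a minimal usage sketch of the new flag. The checkpoint name is an assumption for illustration; any pipeline built on `EulerDiscreteScheduler` works the same way:

from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

# Hypothetical checkpoint chosen for illustration only.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Rebuild the scheduler from its existing config with exponential sigmas enabled.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_exponential_sigmas=True
)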
@@ -186,6 +188,7 @@ def __init__(
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
@@ -235,6 +238,7 @@ def __init__(

self.is_scale_input_called = False
self.use_karras_sigmas = use_karras_sigmas
self.use_exponential_sigmas = use_exponential_sigmas

self._step_index = None
self._begin_index = None
@@ -332,6 +336,12 @@ def set_timesteps(
raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
if timesteps is not None and self.config.use_karras_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
if self.config.use_exponential_sigmas and self.config.use_karras_sigmas:
raise ValueError(
"Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`"
)
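A quick hedged illustration of the new guard; note the check lives in `set_timesteps`, so constructing the scheduler with both flags succeeds and the error is raised on the first call:

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(use_karras_sigmas=True, use_exponential_sigmas=True)
try:
    scheduler.set_timesteps(num_inference_steps=20)
except ValueError as err:
    print(err)  # Cannot set both `config.use_exponential_sigmas = True` and ...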
if (
timesteps is not None
and self.config.timestep_type == "continuous"
@@ -396,6 +406,10 @@ def set_timesteps(
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
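The `_sigma_to_t` call above maps each converted sigma back to a (possibly fractional) training timestep by interpolating in log-sigma space. A conceptual sketch only, not the scheduler's actual helper, assuming `log_sigmas` holds the log of the training sigmas in ascending timestep order:

import numpy as np

def sigma_to_t_sketch(sigma: float, log_sigmas: np.ndarray) -> float:
    # Interpolate the timestep index at which the training schedule
    # reaches log(sigma); fractional results are expected.
    timesteps = np.arange(len(log_sigmas))
    return float(np.interp(np.log(sigma), log_sigmas, timesteps))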

if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
elif self.config.final_sigmas_type == "zero":
@@ -468,6 +482,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas

# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""

# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None

if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None

sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
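Since the schedule is linear in log-space, consecutive sigmas decay by a constant ratio, i.e. the schedule is geometric. A small self-contained check (sigma_min, sigma_max, and the step count are illustrative values, not the scheduler's defaults):

import math
import torch

sigma_max, sigma_min, steps = 80.0, 0.03, 10  # illustrative values only
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), steps).exp()
# Evenly spaced log-sigmas imply a constant step-to-step ratio.
ratio = (sigma_min / sigma_max) ** (1.0 / (steps - 1))
assert torch.allclose(sigmas[1:] / sigmas[:-1], torch.full((steps - 1,), ratio))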

def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps