diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py index 350736f..f5880d8 100755 --- a/examples/diffusers/functional.py +++ b/examples/diffusers/functional.py @@ -13,7 +13,7 @@ import skrample.scheduling as scheduling from skrample.common import predict_flow from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling import functional, structured +from skrample.sampling import functional, models, structured from skrample.sampling.interface import StructuredFunctionalAdapter model_id = "black-forest-labs/FLUX.1-dev" @@ -68,7 +68,8 @@ def sample_callback(x: torch.Tensor, n: int, t: float, s: float) -> None: block_state["latents"] = sampler.sample_model( sample=block_state["latents"], - model=sampler.model_with_predictor(call_model, wrapper.predictor), + model=call_model, + model_transform=models.FlowModel, steps=block_state["num_inference_steps"], callback=sample_callback, ) diff --git a/examples/functional.py b/examples/functional.py index 3bb7152..f4559b3 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -7,10 +7,9 @@ from tqdm import tqdm from transformers.models.clip import CLIPTextModel, CLIPTokenizer -import skrample.common import skrample.pytorch.noise as noise import skrample.scheduling as scheduling -from skrample.sampling import functional, structured +from skrample.sampling import functional, models, structured from skrample.sampling.interface import StructuredFunctionalAdapter with torch.inference_mode(): @@ -57,8 +56,7 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: t, torch.cat([text_embeds, torch.zeros_like(text_embeds)]), ).sample.chunk(2) - p = conditioned + (cfg - 1) * (conditioned - unconditioned) - return skrample.common.predict_epsilon(x, p, s, schedule.sigma_transform) + return conditioned + (cfg - 1) * (conditioned - unconditioned) if isinstance(sampler, functional.FunctionalHigher): steps = sampler.adjust_steps(steps) @@ -67,6 +65,7 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: bar = tqdm(total=steps) sample = sampler.generate_model( model=call_model, + model_transform=models.EpsilonModel, steps=steps, rng=lambda: rng.generate().to(dtype=dtype, device=device), callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), diff --git a/examples/predictions.py b/examples/predictions.py new file mode 100755 index 0000000..c0e2f44 --- /dev/null +++ b/examples/predictions.py @@ -0,0 +1,126 @@ +#! /usr/bin/env python + +import json + +import huggingface_hub as hf +import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from PIL import Image +from safetensors.torch import load_file +from tqdm import tqdm +from transformers.models.clip import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +import skrample.pytorch.noise as noise +import skrample.scheduling as scheduling +from skrample.common import Predictor, predict_epsilon, predict_sample, predict_velocity +from skrample.sampling import functional, models, structured +from skrample.sampling.interface import StructuredFunctionalAdapter + +with torch.inference_mode(): + device: torch.device = torch.device("cuda") + dtype: torch.dtype = torch.float16 + steps: int = 15 + cfg: float = 8 + seed = torch.Generator("cpu").manual_seed(0) + prompts = ["dreamy analog photograph of a kitten in a stained glass church", "blurry, noisy, cropped"] + + schedule = scheduling.Scaled() + + sampler_snr = StructuredFunctionalAdapter(schedule, structured.DPM(order=1)) + sampler_df = functional.RKUltra(schedule, order=1) + + base = "stabilityai/stable-diffusion-xl-base-1.0" + + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer") + tokenizer_2: CLIPTokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2") + text_encoder: CLIPTextModelWithProjection = CLIPTextModelWithProjection.from_pretrained( + base, subfolder="text_encoder", device_map=device, torch_dtype=dtype + ) + text_encoder_2: CLIPTextModel = CLIPTextModel.from_pretrained( + base, subfolder="text_encoder_2", device_map=device, torch_dtype=dtype + ) + image_encoder: AutoencoderKL = AutoencoderKL.from_pretrained( # type: ignore + base, subfolder="vae", device_map=device, torch_dtype=torch.float32 + ) + + text_embeds: torch.Tensor = text_encoder( + tokenizer(prompts, padding="max_length", return_tensors="pt").input_ids.to(device=device), + output_hidden_states=True, + ).hidden_states[-2] + te2_out = text_encoder_2( + tokenizer_2(prompts, padding="max_length", return_tensors="pt").input_ids.to(device=device), + output_hidden_states=True, + ) + text_embeds = torch.cat([text_embeds, te2_out.hidden_states[-2]], dim=-1) + pooled_embeds: torch.Tensor = te2_out.pooler_output + + time_embeds = text_embeds.new([[4096, 4096, 0, 0, 4096, 4096]]).repeat(2, 1) + + configs: tuple[tuple[models.ModelTransform, Predictor, str, str], ...] = ( + (models.EpsilonModel, predict_epsilon, base, ""), + (models.VelocityModel, predict_velocity, "terminusresearch/terminus-xl-velocity-v2", ""), + (models.XModel, predict_sample, "ByteDance/SDXL-Lightning", "sdxl_lightning_1step_unet_x0.safetensors"), + ) + + for transform, predictor, url, weights in configs: + model_steps = 1 if transform is models.XModel else steps + model_cfg = 1 if transform is models.XModel else cfg + + if weights: + model: UNet2DConditionModel = UNet2DConditionModel.from_config( # type: ignore + json.load(open(hf.hf_hub_download(base, "config.json", subfolder="unet"))), + device_map=device, + torch_dtype=dtype, + ) + model.load_state_dict(load_file(hf.hf_hub_download(url, weights))) + model = model.to(device=device, dtype=dtype) # pyright: ignore [reportCallIssue] + else: + model: UNet2DConditionModel = UNet2DConditionModel.from_pretrained( # type: ignore + url, subfolder="unet", device_map=device, torch_dtype=dtype + ) + + def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: + conditioned, unconditioned = model( + x.expand([x.shape[0] * 2, *x.shape[1:]]), + t, + text_embeds, + added_cond_kwargs={"text_embeds": pooled_embeds, "time_ids": time_embeds}, + ).sample.chunk(2) + return conditioned + (model_cfg - 1) * (conditioned - unconditioned) + + rng = noise.Random.from_inputs((1, 4, 128, 128), seed.clone_state()) + bar = tqdm(total=model_steps) + sample = sampler_snr.generate_model( + model=call_model, + model_transform=transform, + steps=model_steps, + rng=lambda: rng.generate().to(dtype=dtype, device=device), + callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), + ) + + image: torch.Tensor = image_encoder.decode( + sample.to(dtype=image_encoder.dtype) / image_encoder.config.scaling_factor # type: ignore + ).sample[0] # type: ignore + Image.fromarray( + ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() + ).save(f"{predictor.__name__}.png") + + rng = noise.Random.from_inputs((1, 4, 128, 128), seed.clone_state()) + bar = tqdm(total=sampler_df.adjust_steps(model_steps)) + sample = sampler_df.generate_model( + model=call_model, + model_transform=transform, + steps=sampler_df.adjust_steps(model_steps), + rng=lambda: rng.generate().to(dtype=dtype, device=device), + callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), + ) + + image: torch.Tensor = image_encoder.decode( + sample.to(dtype=image_encoder.dtype) / image_encoder.config.scaling_factor # type: ignore + ).sample[0] # type: ignore + Image.fromarray( + ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() + ).save(f"{transform.__name__}.png") + + model = model.to(device="meta") # pyright: ignore [reportCallIssue] diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index e51c5a9..97c4b2b 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -14,7 +14,7 @@ import skrample.scheduling as scheduling from skrample.common import SigmaTransform, sigma_complement, sigma_polar, spowf -from skrample.sampling import functional, structured +from skrample.sampling import functional, models, structured from skrample.sampling.interface import StructuredFunctionalAdapter OKLAB_XYZ_M1 = np.array( @@ -58,9 +58,9 @@ def colors(hue_steps: int) -> Generator[list[float]]: yield oklch_to_srgb(np.array([lighness_actual, chroma_actual, hue], dtype=np.float64)) -TRANSFORMS: dict[str, tuple[float, SigmaTransform]] = { - "polar": (14.6, sigma_polar), - "complement": (1.0, sigma_complement), +TRANSFORMS: dict[str, tuple[float, SigmaTransform, models.ModelTransform]] = { + "polar": (1.0, sigma_polar, models.EpsilonModel), + "complement": (1.0, sigma_complement, models.FlowModel), } SAMPLERS: dict[str, structured.StructuredSampler | functional.FunctionalSampler] = { "euler": structured.Euler(), @@ -195,7 +195,8 @@ def callback(x: float, n: int, t: float, s: float) -> None: sampler.sample_model( sample=sample, - model=lambda x, t, s: x + math.sin(t / schedule.base_timesteps * args.curve) * (s + 1), + model=lambda x, t, s: x - math.sin(t / schedule.base_timesteps * args.curve), + model_transform=TRANSFORMS[args.transform][2], steps=adjusted, rng=random, callback=callback, diff --git a/skrample/common.py b/skrample/common.py index a75eca9..f8984c7 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -143,6 +143,16 @@ def merge_noise[T: Sample](sample: T, noise: T, sigma: float, sigma_transform: S return sample * sigma_v + noise * sigma_u # type: ignore +def divf(lhs: float, rhs: float) -> float: + "Float division with infinity" + if rhs != 0: + return lhs / rhs + elif lhs == 0: + raise ZeroDivisionError + else: + return math.copysign(math.inf, lhs) + + def safe_log(x: float) -> float: "Returns inf rather than throw an err" try: diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 4abb882..41e1249 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -12,6 +12,7 @@ from skrample.common import RNG, DictOrProxy, FloatSchedule, Predictor, Sample, SigmaTransform from . import tableaux +from .models import ModelTransform type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any] "Return is ignored" @@ -38,31 +39,17 @@ def fractional_step( return result -def to_derivative_polar[T: Sample](sample: T, prediction: T, sigma: float, transform: SigmaTransform) -> T: - sigma_u, sigma_v = transform(sigma) - return (sample - (sigma_v * prediction)) / sigma_u # pyright: ignore [reportReturnType] - - -def to_derivative_complement[T: Sample](sample: T, prediction: T, sigma: float, transform: SigmaTransform) -> T: - return (sample - prediction) / sigma # pyright: ignore [reportReturnType] - - -type DerivativeTransform[T: Sample] = Callable[[T, T, float, SigmaTransform], T] - - -def step_tableau_derive[T: Sample]( +def step_tableau[T: Sample]( tableau: tableaux.Tableau | tableaux.ExtendedTableau, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, step: int, schedule: FloatSchedule, - transform: SigmaTransform, - derivative_io: tuple[DerivativeTransform[T], DerivativeTransform[T]], + sigma_transform: SigmaTransform, step_size: int = 1, epsilon: float = 1e-8, ) -> tuple[T, ...]: - to_d, from_d = derivative_io - nodes, weights = tableau[0], tableau[1:] derivatives: list[T] = [] @@ -74,92 +61,34 @@ def step_tableau_derive[T: Sample]( for frac_sc, icoeffs in zip(fractions, (t[1] for t in nodes), strict=True): sigma_i = frac_sc[1] if icoeffs: - X: T = common.euler( # pyright: ignore [reportAssignmentType] + X: T = model_transform.forward( # pyright: ignore [reportAssignmentType] sample, - from_d( - sample, - math.sumprod(derivatives, icoeffs) / math.fsum(icoeffs), # pyright: ignore [reportArgumentType] - S0, - transform, - ), + math.sumprod(derivatives, icoeffs) / math.fsum(icoeffs), # pyright: ignore [reportArgumentType] S0, sigma_i, - transform, + sigma_transform, ) else: X = sample # Do not call model on timestep = 0 or sigma = 0 if any(abs(v) < epsilon for v in frac_sc): - derivatives.append(to_d(sample, X, S0, transform)) + derivatives.append(model_transform.backward(sample, X, S0, S1, sigma_transform)) else: - derivatives.append(to_d(X, model(X, *frac_sc), sigma_i, transform)) + derivatives.append(model(X, *frac_sc)) return tuple( # pyright: ignore [reportReturnType] - common.euler( + model_transform.forward( sample, - from_d( - sample, - math.sumprod(derivatives, w), # pyright: ignore [reportArgumentType] - S0, - transform, - ), + math.sumprod(derivatives, w), # pyright: ignore [reportArgumentType] S0, S1, - transform, + sigma_transform, ) for w in weights ) -def step_tableau[T: Sample]( - tableau: tableaux.Tableau | tableaux.ExtendedTableau, - sample: T, - model: SampleableModel[T], - step: int, - schedule: FloatSchedule, - transform: SigmaTransform, - step_size: int = 1, - epsilon: float = 1e-8, -) -> tuple[T, ...]: - if transform is common.sigma_complement: - return step_tableau_derive( - tableau, - sample, - model, - step, - schedule, - transform, - (to_derivative_complement, common.predict_flow), - step_size, - epsilon, - ) - elif transform is common.sigma_polar: - return step_tableau_derive( - tableau, - sample, - model, - step, - schedule, - transform, - (to_derivative_polar, common.predict_epsilon), - step_size, - epsilon, - ) - - return step_tableau_derive( - tableau, - sample, - model, - step, - schedule, - transform, - ((lambda x, p, s, t: p), (lambda x, p, s, t: p)), - step_size, - epsilon, - ) - - @dataclasses.dataclass(frozen=True) class FunctionalSampler(ABC): schedule: scheduling.SkrampleSchedule @@ -183,6 +112,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -195,6 +125,7 @@ def sample_model[T: Sample]( def generate_model[T: Sample]( self, model: SampleableModel[T], + model_transform: ModelTransform, rng: RNG[T], steps: int, include: slice = slice(None), @@ -215,7 +146,7 @@ def generate_model[T: Sample]( ) / self.merge_noise(0.0, 1.0, steps, 0) # Rescale sample by initial sigma. Mostly just to handle quirks with Scaled - return self.sample_model(sample, model, steps, include, rng, callback) + return self.sample_model(sample, model, model_transform, steps, include, rng, callback) @dataclasses.dataclass(frozen=True) @@ -242,6 +173,7 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, @@ -251,6 +183,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -259,7 +192,7 @@ def sample_model[T: Sample]( schedule: FloatSchedule = self.schedule.schedule(steps) for n in list(range(steps))[include]: - sample = self.step(sample, model, n, schedule, rng) + sample = self.step(sample, model, model_transform, n, schedule, rng) if callback: callback(sample, n, *schedule[n] if n < len(schedule) else (0, 0)) @@ -291,7 +224,7 @@ class RKUltra(FunctionalHigher, FunctionalSinglestep): providers: DictOrProxy[int, tableaux.TableauProvider[tableaux.Tableau | tableaux.ExtendedTableau]] = ( MappingProxyType( { - 2: tableaux.RK2.Ralston, + 2: tableaux.RK2.Heun, 3: tableaux.RK3.Ralston, 4: tableaux.RK4.Ralston, 5: tableaux.RK5.Nystrom, @@ -327,11 +260,20 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, ) -> T: - return step_tableau(self.tableau(), sample, model, step, schedule, self.schedule.sigma_transform)[0] + return step_tableau( + self.tableau(), + sample, + model, + model_transform, + step, + schedule, + self.schedule.sigma_transform, + )[0] @dataclasses.dataclass(frozen=True) @@ -355,22 +297,26 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, ) -> T: - dt, scale = common.scaled_delta_step(step, schedule, self.schedule.sigma_transform) - k1 = model(sample, *schedule[step]) - result: T = sample * scale + k1 * dt # type: ignore + sigma_from = schedule[step][1] + sigma_to = schedule[step + 1][1] if step + 1 < len(schedule) else 0 + result: T = model_transform.forward(sample, k1, sigma_from, sigma_to, self.schedule.sigma_transform) + + eta_t = model_transform.to_eta(sigma_from, self.schedule.sigma_transform) + eta_s = model_transform.to_eta(sigma_to, self.schedule.sigma_transform) + dt = eta_s - eta_t # Multiplying by step size here is kind of an asspull, but so is this whole solver so... - if ( - step + 1 < len(schedule) - and self.evaluator(sample, result) / max(self.evaluator(0, result), 1e-16) > self.threshold * dt - ): - k2 = model(result, *schedule[step + 1]) - result: T = sample * scale + (k1 + k2) / 2 * dt # type: ignore + if step + 1 < len(schedule) and self.evaluator(sample, result) / max( + self.evaluator(0, result), 1e-16 + ) > self.threshold * abs(dt): + k2: T = (k1 + model(result, *schedule[step + 1])) / 2 # pyright: ignore [reportAssignmentType] + result: T = model_transform.forward(sample, k2, sigma_from, sigma_to, self.schedule.sigma_transform) return result @@ -388,7 +334,7 @@ class RKMoire(FunctionalAdaptive, FunctionalHigher): """Providers for a given order, starting from 2. Falls back to RKE2.Heun""" - threshold: float = 1e-3 + threshold: float = 5e-5 initial: float = 1 / 50 "Percent of schedule to take as an initial step." @@ -428,6 +374,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -455,11 +402,18 @@ def sample_model[T: Sample]( if step_next < len(schedule): sample_high, sample_low = step_tableau( - tab, sample, model, step, schedule, self.schedule.sigma_transform, step_size + tab, sample, model, model_transform, step, schedule, self.schedule.sigma_transform, step_size ) - delta = common.scaled_delta_step(step, schedule, self.schedule.sigma_transform, step_size)[0] - delta_next = common.scaled_delta_step(step_next, schedule, self.schedule.sigma_transform, step_size)[0] + eta_t = model_transform.to_eta(schedule[step][1], self.schedule.sigma_transform) + eta_s = model_transform.to_eta(schedule[step_next][1], self.schedule.sigma_transform) + delta = abs(eta_s - eta_t) + + eta_s1 = model_transform.to_eta( + schedule[step_next + step_size][1] if step_next + step_size < len(schedule) else 0, + self.schedule.sigma_transform, + ) + delta_next = abs(eta_s1 - eta_s) # Normalize against pure error error = self.evaluator(sample_low, sample_high) / max(self.evaluator(0, sample_high), epsilon) @@ -476,7 +430,7 @@ def sample_model[T: Sample]( else: # Save the extra euler call since the 2nd weight isn't used sample_high = step_tableau( - tab[:2], sample, model, step, schedule, self.schedule.sigma_transform, step_size + tab[:2], sample, model, model_transform, step, schedule, self.schedule.sigma_transform, step_size )[0] sample = sample_high diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index cfb39d3..d2f83ce 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -1,7 +1,8 @@ import dataclasses from skrample.common import RNG, FloatSchedule, Sample -from skrample.sampling import functional, structured + +from . import functional, models, structured @dataclasses.dataclass(frozen=True) @@ -17,6 +18,7 @@ def sample_model[T: Sample]( self, sample: T, model: functional.SampleableModel[T], + model_transform: models.ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -28,7 +30,8 @@ def sample_model[T: Sample]( for n in list(range(len(schedule)))[include]: timestep, sigma = schedule[n] - prediction = model(self.sampler.scale_input(sample, sigma, self.schedule.sigma_transform), timestep, sigma) + output = model(self.sampler.scale_input(sample, sigma, self.schedule.sigma_transform), timestep, sigma) + prediction = model_transform.to_x(sample, output, sigma, self.schedule.sigma_transform) sksamples = self.sampler.sample( sample, diff --git a/skrample/sampling/models.py b/skrample/sampling/models.py new file mode 100644 index 0000000..87d8134 --- /dev/null +++ b/skrample/sampling/models.py @@ -0,0 +1,169 @@ +import math + +from skrample.common import Sample, SigmaTransform, divf, predict_epsilon + + +class DiffusionModel: + """Implements euler method forward and backward through novel method described in https://diffusionflow.github.io/ + Base data type is X̂, or sample prediction""" + + @classmethod + def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + "output -> X̂" + return output + + @classmethod + def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + "X̂ -> output" + return x + + @classmethod + def to_z[T: Sample](cls, sample: T, sigma: float, sigma_transform: SigmaTransform) -> T: + "zₜ -> z̃ₜ" + sigma_t, _alpha_t = sigma_transform(sigma) + z_t = sample / sigma_t + return z_t # pyright: ignore [reportReturnType] + + @classmethod + def from_z[T: Sample](cls, z: T, sigma: float, sigma_transform: SigmaTransform) -> T: + "z̃ₜ -> zₜ" + sigma_t, _alpha_t = sigma_transform(sigma) + z_t = z * sigma_t + return z_t # pyright: ignore [reportReturnType] + + @classmethod + def to_eta(cls, sigma: float, sigma_transform: SigmaTransform) -> float: + "σₜ -> ηₜ" + sigma_t, alpha_t = sigma_transform(sigma) + eta_t = divf(alpha_t, sigma_t) + return eta_t + + @classmethod + def to_h(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + "Shorthand for σₜ, σₛ -> ηₛ-ηₜ" + return cls.to_eta(sigma_to, sigma_transform) - cls.to_eta(sigma_from, sigma_transform) + + @classmethod + def forward[T: Sample]( + cls, sample: T, output: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform + ) -> T: + """Perform the Euler method. + z̃ₛ = z̃ₜ + output · (ηₛ - ηₜ) + Equation (5) @ https://diffusionflow.github.io/""" + if math.isinf(h := cls.to_h(sigma_from, sigma_to, sigma_transform)): + return cls.to_x(sample, output, sigma_from, sigma_transform) + else: + z_t = cls.to_z(sample, sigma_from, sigma_transform) + z_s = z_t + output * h + return cls.from_z(z_s, sigma_to, sigma_transform) # pyright: ignore [reportReturnType] + + @classmethod + def backward[T: Sample]( + cls, sample: T, result: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform + ) -> T: + """Undo the Euler method. + output = (z̃ₛ - z̃ₜ) / (ηₛ - ηₜ) + Equation (5) @ https://diffusionflow.github.io/""" + if math.isinf(h := cls.to_h(sigma_from, sigma_to, sigma_transform)): + return cls.from_x(sample, result, sigma_from, sigma_transform) + else: + z_t = cls.to_z(sample, sigma_from, sigma_transform) + z_s = cls.to_z(result, sigma_to, sigma_transform) + return (z_s - z_t) / h # pyright: ignore [reportReturnType] + + +class XModel(DiffusionModel): + "Equivalent to DiffusionModel, for type checking" + + +class EpsilonModel(DiffusionModel): + "Typically used with the variance-preserving (VP) noise schedule" + + @classmethod + def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + return predict_epsilon(sample, output, sigma, sigma_transform) + + @classmethod + def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + output = (sample - alpha_t * x) / sigma_t + return output # pyright: ignore [reportReturnType] + + @classmethod + def to_z[T: Sample](cls, sample: T, sigma: float, sigma_transform: SigmaTransform) -> T: + _sigma_t, alpha_t = sigma_transform(sigma) + z_t = sample / alpha_t + return z_t # pyright: ignore [reportReturnType] + + @classmethod + def from_z[T: Sample](cls, z: T, sigma: float, sigma_transform: SigmaTransform) -> T: + _sigma_t, alpha_t = sigma_transform(sigma) + z_t = z * alpha_t + return z_t # pyright: ignore [reportReturnType] + + @classmethod + def to_eta(cls, sigma: float, sigma_transform: SigmaTransform) -> float: + sigma_t, alpha_t = sigma_transform(sigma) + eta_t = divf(sigma_t, alpha_t) + return eta_t + + +class FlowModel(DiffusionModel): + "Typically used with the linear noise schedule" + + @classmethod + def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return (sample - sigma_t * output) / (alpha_t + sigma_t) # pyright: ignore [reportReturnType] + + @classmethod + def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return (sample - (alpha_t + sigma_t) * x) / sigma_t # pyright: ignore [reportReturnType] + + @classmethod + def to_z[T: Sample](cls, sample: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return sample / (alpha_t + sigma_t) # pyright: ignore [reportReturnType] + + @classmethod + def from_z[T: Sample](cls, z: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return z * (alpha_t + sigma_t) # pyright: ignore [reportReturnType] + + @classmethod + def to_eta(cls, sigma: float, sigma_transform: SigmaTransform) -> float: + sigma_t, alpha_t = sigma_transform(sigma) + return sigma_t / (alpha_t + sigma_t) + + +class VelocityModel(DiffusionModel): + """Typically used with the variance-preserving (VP) noise schedule. + Currently just converts output to X during forward and backward.""" + + @classmethod + def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return alpha_t * sample - sigma_t * output # pyright: ignore [reportReturnType] + + @classmethod + def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return (alpha_t * sample - x) / sigma_t # pyright: ignore [reportReturnType] + + @classmethod + def forward[T: Sample]( + cls, sample: T, output: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform + ) -> T: + output = cls.to_x(sample, output, sigma_from, sigma_transform) + return DiffusionModel.forward(sample, output, sigma_from, sigma_to, sigma_transform) + + @classmethod + def backward[T: Sample]( + cls, sample: T, result: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform + ) -> T: + output: T = DiffusionModel.backward(sample, result, sigma_from, sigma_to, sigma_transform) + return cls.from_x(sample, output, sigma_from, sigma_transform) + + +type ModelTransform = DiffusionModel | type[DiffusionModel] diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 4d322a3..4abc577 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -17,6 +17,7 @@ from skrample.common import predict_flow as FLOW from skrample.common import predict_velocity as VELOCITY from skrample.sampling.functional import RKUltra +from skrample.sampling.models import EpsilonModel, FlowModel from skrample.sampling.structured import DPM, Euler, SKSamples, StructuredSampler, UniPC from skrample.sampling.tableaux import RK2 from skrample.scheduling import SkrampleSchedule @@ -198,7 +199,10 @@ def test_heun_scaled() -> None: margin = 1e-8 sigma_transform = sigma_polar - for predictor in [(EPSILON, "epsilon"), (VELOCITY, "v_prediction")]: + for predictor in [ + (EpsilonModel, "epsilon"), + # (EpsilonModel, "v_prediction"), # They do Heun on epsilon-hat not v-hat which we don't support yet. + ]: for steps in 2, 3, 30, 31, 200, 201: df: HeunDiscreteScheduler = HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type=predictor[1]) # type: ignore @@ -225,18 +229,18 @@ def test_heun_scaled() -> None: )[0] sk_sample = sk.generate_model( - sk.model_with_predictor(lambda x, t, s: fake_model(x), predictor[0]), + lambda x, t, s: fake_model(x), + predictor[0], lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), steps, initial=sk_sample, ) - compare_tensors(df_sample, sk_sample, message=f"{steps}", margin=margin) + compare_tensors(df_sample, sk_sample, message=f"{steps} {predictor[1]}", margin=margin) def test_heun_flow() -> None: margin = 1e-8 - predictor: Predictor = FLOW sigma_transform = sigma_complement for steps in 2, 3, 30, 31, 200, 201: df: FlowMatchHeunDiscreteScheduler = FlowMatchHeunDiscreteScheduler.from_config(FLOW_CONFIG) # type: ignore @@ -259,7 +263,8 @@ def test_heun_flow() -> None: df_sample: torch.Tensor = df.step(fake_model(df_sample), sample=df_sample, timestep=t)[0] # type: ignore sk_sample = sk.generate_model( - sk.model_with_predictor(lambda x, t, s: fake_model(x), predictor), + lambda x, t, s: fake_model(x), + FlowModel, lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), steps, initial=sk_sample, diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 6b02360..2d7cf60 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -1,15 +1,29 @@ +import itertools import math import random from dataclasses import replace import numpy as np +import pytest import torch from testing_common import compare_tensors -from skrample.common import MergeStrategy, bashforth, sigma_complement, sigmoid, softmax, spowf +from skrample.common import ( + MergeStrategy, + SigmaTransform, + bashforth, + euler, + predict_flow, + sigma_complement, + sigma_polar, + sigmoid, + softmax, + spowf, +) from skrample.diffusers import SkrampleWrapperScheduler from skrample.sampling import tableaux from skrample.sampling.interface import StructuredFunctionalAdapter +from skrample.sampling.models import EpsilonModel, FlowModel, ModelTransform, VelocityModel, XModel from skrample.sampling.structured import ( DPM, SPC, @@ -51,6 +65,32 @@ def test_sigmas_to_timesteps() -> None: compare_tensors(torch.tensor(timesteps), torch.tensor(timesteps_inv), margin=0) # shocked this rounds good +@pytest.mark.parametrize( + ("model_transform", "sigma_transform"), + itertools.product([EpsilonModel, FlowModel, VelocityModel, XModel], [sigma_complement, sigma_polar]), +) +def test_model_transforms(model_transform: ModelTransform, sigma_transform: SigmaTransform) -> None: + sample = 0.8 + output = 0.3 + sigma = 0.2 + + x = model_transform.to_x(sample, output, sigma, sigma_transform) + o = model_transform.from_x(sample, x, sigma, sigma_transform) + assert abs(output - o) < 1e-12 + + z = model_transform.to_z(sample, sigma, sigma_transform) + s = model_transform.from_z(z, sigma, sigma_transform) + assert abs(sample - s) < 1e-12 + + sigma_next = 0.05 + for sigma_next in 0.05, 0: # extra 0 to validate X̂ + snr = euler( + sample, model_transform.to_x(sample, output, sigma, sigma_transform), sigma, sigma_next, sigma_transform + ) + df = model_transform.forward(sample, output, sigma, sigma_next, sigma_transform) + assert abs(snr - df) < 1e-12 + + def test_sampler_generics() -> None: eps = 1e-12 for sampler in [ @@ -192,7 +232,7 @@ def fake_model(x: float, _: float, s: float) -> float: noise = [random.random() for _ in range(steps)] rng = iter(noise) - sample_f = adapter.sample_model(sample, fake_model, steps, rng=lambda: next(rng)) + sample_f = adapter.sample_model(sample, fake_model, FlowModel, steps, rng=lambda: next(rng)) rng = iter(noise) float_schedule = schedule.schedule(steps) @@ -201,7 +241,7 @@ def fake_model(x: float, _: float, s: float) -> float: for n, (t, s) in enumerate(float_schedule): results = sampler.sample( sample_s, - fake_model(sample_s, t, s), + predict_flow(sample_s, fake_model(sample_s, t, s), s, schedule.sigma_transform), n, float_schedule, schedule.sigma_transform,