diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py
new file mode 100755
index 0000000..5820d81
--- /dev/null
+++ b/examples/diffusers/functional.py
@@ -0,0 +1,102 @@
+#! /usr/bin/env python
+
+from typing import ClassVar
+
+import torch
+from diffusers.modular_pipelines.components_manager import ComponentsManager
+from diffusers.modular_pipelines.flux.denoise import FluxDenoiseStep, FluxLoopDenoiser
+from diffusers.modular_pipelines.flux.modular_blocks import TEXT2IMAGE_BLOCKS
+from diffusers.modular_pipelines.flux.modular_pipeline import FluxModularPipeline
+from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks
+from tqdm import tqdm
+
+import skrample.scheduling as scheduling
+from skrample.diffusers import SkrampleWrapperScheduler
+from skrample.sampling import functional, models, structured
+from skrample.sampling.interface import StructuredFunctionalAdapter
+
+model_id = "black-forest-labs/FLUX.1-dev"
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+schedule = scheduling.FlowShift(scheduling.Linear(), shift=2)
+wrapper = SkrampleWrapperScheduler(
+    sampler=structured.Euler(), schedule=schedule, model=models.FlowModel(), allow_dynamic=False
+)
+
+# Equivalent to structured example
+sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True))
+# Native functional example
+sampler = functional.RKUltra(schedule, 4)
+# Dynamic model calls
+sampler = functional.FastHeun(schedule)
+# Dynamic step sizes
+sampler = functional.RKMoire(schedule)
+
+
+class FunctionalDenoise(FluxDenoiseStep):
+    # Exclude the after_denoise block
+    block_classes: ClassVar[list[type[ModularPipelineBlocks]]] = [FluxLoopDenoiser]
+    block_names: ClassVar[list[str]] = ["denoiser"]
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        if isinstance(sampler, functional.FunctionalHigher):
+            block_state["num_inference_steps"] = sampler.adjust_steps(block_state["num_inference_steps"])
+        progress = tqdm(total=block_state["num_inference_steps"])
+
+        i = 0
+
+        def call_model(sample: torch.Tensor, timestep: float, sigma: float) -> torch.Tensor:
+            nonlocal i, components, block_state, progress
+            block_state["latents"] = sample
+            components, block_state = self.loop_step(
+                components,
+                block_state,  # type: ignore
+                i=i,
+                t=sample.new_tensor([timestep] * len(sample)),
+            )
+            return block_state["noise_pred"]  # type: ignore
+
+        def sample_callback(x: torch.Tensor, n: int, t: float, s: float) -> None:
+            nonlocal i
+            progress.update(n + 1 - progress.n)
+            i = n + 1
+
+        block_state["latents"] = sampler.sample_model(
+            sample=block_state["latents"],
+            model=call_model,
+            model_transform=models.FlowModel(),
+            steps=block_state["num_inference_steps"],
+            callback=sample_callback,
+        )
+
+        self.set_block_state(state, block_state)  # type: ignore
+        return components, state  # type: ignore
+
+
+blocks.sub_blocks["denoise"] = FunctionalDenoise()
+
+cm = ComponentsManager()
+cm.enable_auto_cpu_offload()
+pipe = blocks.init_pipeline(components_manager=cm)
+pipe.load_components(["text_encoder"], repo=model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+pipe.load_components(["tokenizer"], repo=model_id, subfolder="tokenizer")
+pipe.load_components(["text_encoder_2"], repo=model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
+pipe.load_components(["tokenizer_2"], repo=model_id, subfolder="tokenizer_2")
+pipe.load_components(["transformer"], repo=model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe.load_components(["vae"], repo=model_id, subfolder="vae", torch_dtype=torch.bfloat16)
+
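+# The remaining stock blocks still resolve a scheduler component, so the skrample wrapper is registered as usual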
+pipe.register_components(scheduler=wrapper)
+
+
+pipe(  # type: ignore
+    prompt="sharp, high dynamic range photograph of a kitten on a beach of rainbow pebbles",
+    generator=torch.Generator("cpu").manual_seed(42),
+    width=1024,
+    height=1024,
+    num_inference_steps=25,
+    guidance_scale=2.5,
+).get("images")[0].save("diffusers_functional.png")
diff --git a/examples/diffusers/wrapper.py b/examples/diffusers/wrapper.py
index 9565691..1d064e7 100755
--- a/examples/diffusers/wrapper.py
+++ b/examples/diffusers/wrapper.py
@@ -4,10 +4,10 @@
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 
 import skrample.pytorch.noise as sknoise
-import skrample.sampling as sampling
+import skrample.sampling.structured as sampling
 import skrample.scheduling as scheduling
-from skrample.common import predict_flow
 from skrample.diffusers import SkrampleWrapperScheduler
+from skrample.sampling.models import FlowModel
 
 pipe: FluxPipeline = FluxPipeline.from_pretrained(  # type: ignore
     "black-forest-labs/FLUX.1-dev",
@@ -17,7 +17,7 @@
 pipe.scheduler = scheduler = SkrampleWrapperScheduler(
     sampler=sampling.DPM(order=2, add_noise=True),
     schedule=scheduling.FlowShift(scheduling.Linear(), shift=2.0),
-    predictor=predict_flow,
+    model=FlowModel(),
     noise_type=sknoise.Brownian,
     allow_dynamic=False,
 )
diff --git a/examples/diffusers/wrapper_from.py b/examples/diffusers/wrapper_from.py
index e009b20..f33ad55 100755
--- a/examples/diffusers/wrapper_from.py
+++ b/examples/diffusers/wrapper_from.py
@@ -4,7 +4,7 @@
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 
 import skrample.pytorch.noise as sknoise
-import skrample.sampling as sampling
+import skrample.sampling.structured as sampling
 from skrample.diffusers import SkrampleWrapperScheduler
 
 pipe: FluxPipeline = FluxPipeline.from_pretrained(  # type: ignore
diff --git a/examples/functional.py b/examples/functional.py
new file mode 100755
index 0000000..63cc269
--- /dev/null
+++ b/examples/functional.py
@@ -0,0 +1,77 @@
+#! /usr/bin/env python
+
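+# Text-to-image example driven end-to-end by skrample's functional samplers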
+import torch
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from PIL import Image
+from tqdm import tqdm
+from transformers.models.clip import CLIPTextModel, CLIPTokenizer
+
+import skrample.pytorch.noise as noise
+import skrample.scheduling as scheduling
+from skrample.sampling import functional, models, structured
+from skrample.sampling.interface import StructuredFunctionalAdapter
+
+with torch.inference_mode():
+    device: torch.device = torch.device("cuda")
+    dtype: torch.dtype = torch.float16
+    url: str = "Lykon/dreamshaper-8"
+    seed = torch.Generator("cpu").manual_seed(0)
+    steps: int = 25
+    cfg: float = 3
+
+    schedule = scheduling.Karras(scheduling.Scaled())
+
+    # Equivalent to structured example
+    sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True))
+    # Native functional example
+    sampler = functional.RKUltra(schedule, 4)
+    # Dynamic model calls
+    sampler = functional.FastHeun(schedule)
+    # Dynamic step sizes
+    sampler = functional.RKMoire(schedule)
+
+    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer")
+    text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
+        url, subfolder="text_encoder", device_map=device, torch_dtype=dtype
+    )
+    model: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(  # type: ignore
+        url, subfolder="unet", device_map=device, torch_dtype=dtype
+    )
+    image_encoder: AutoencoderKL = AutoencoderKL.from_pretrained(  # type: ignore
+        url, subfolder="vae", device_map=device, torch_dtype=dtype
+    )
+
+    text_embeds: torch.Tensor = text_encoder(
+        tokenizer(
+            "bright colorful fantasy art of a kitten in a field of rainbow flowers",
+            padding="max_length",
+            return_tensors="pt",
+        ).input_ids.to(device=device)
+    ).last_hidden_state
+
+    def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor:
+        conditioned, unconditioned = model(
+            x.expand([x.shape[0] * 2, *x.shape[1:]]),
+            t,
+            torch.cat([text_embeds, torch.zeros_like(text_embeds)]),
+        ).sample.chunk(2)
+        return conditioned + (cfg - 1) * (conditioned - unconditioned)
+
+    if isinstance(sampler, functional.FunctionalHigher):
+        steps = sampler.adjust_steps(steps)
+
+    rng = noise.Random.from_inputs((1, 4, 80, 80), seed)
+    bar = tqdm(total=steps)
+    sample = sampler.generate_model(
+        model=call_model,
+        model_transform=models.NoiseModel(),
+        steps=steps,
+        rng=lambda: rng.generate().to(dtype=dtype, device=device),
+        callback=lambda x, n, t, s: bar.update(n + 1 - bar.n),
+    )
+
+    image: torch.Tensor = image_encoder.decode(sample / image_encoder.config.scaling_factor).sample[0]  # type: ignore
+    Image.fromarray(
+        ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy()
+    ).save("functional.png")
diff --git a/examples/raw.py b/examples/structured.py
similarity index 82%
rename from examples/raw.py
rename to examples/structured.py
index 2bc9580..dffb601 100755
--- a/examples/raw.py
+++ b/examples/structured.py
@@ -7,9 +7,9 @@
 from tqdm import tqdm
 from transformers.models.clip import CLIPTextModel, CLIPTokenizer
 
-import skrample.common
-import skrample.sampling as sampling
+import skrample.sampling.structured as structured
 import skrample.scheduling as scheduling
+from skrample.sampling import models
 
 with torch.inference_mode():
     device: torch.device = torch.device("cuda")
@@ -20,8 +20,8 @@
     cfg: float = 3
 
     schedule: scheduling.SkrampleSchedule = scheduling.Karras(scheduling.Scaled())
-    sampler: sampling.SkrampleSampler = sampling.DPM(order=2, add_noise=True)
-    predictor: skrample.common.Predictor = skrample.common.predict_epsilon
+    sampler: structured.StructuredSampler = structured.DPM(order=2, add_noise=True)
+    transform: models.DiffusionModel = models.NoiseModel()
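+    # NoiseModel is epsilon prediction, the usual objective for SD1.x-era UNets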
 
     tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer")
     text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
         url, subfolder="text_encoder", device_map=device, torch_dtype=dtype
@@ -43,9 +43,10 @@
     ).last_hidden_state
 
     sample: torch.Tensor = torch.randn([1, 4, 80, 80], generator=seed).to(dtype=dtype, device=device)
-    previous: list[sampling.SKSamples[torch.Tensor]] = []
+    previous: list[structured.SKSamples[torch.Tensor]] = []
+    float_schedule = schedule.schedule(steps)
 
-    for n, (timestep, sigma) in enumerate(tqdm(schedule.schedule(steps))):
+    for n, (timestep, sigma) in enumerate(tqdm(float_schedule)):
         conditioned, unconditioned = model(
             sample.expand([sample.shape[0] * 2, *sample.shape[1:]]),
             timestep,
@@ -53,13 +54,13 @@
         ).sample.chunk(2)
         model_output: torch.Tensor = conditioned + (cfg - 1) * (conditioned - unconditioned)
 
-        prediction = predictor(sample, model_output, sigma, schedule.sigma_transform)
+        prediction = transform.to_x(sample, model_output, sigma, schedule.sigma_transform)
 
         sampler_output = sampler.sample(
             sample=sample,
             prediction=prediction,
             step=n,
-            sigma_schedule=schedule.sigmas(steps),
+            schedule=float_schedule,
             sigma_transform=schedule.sigma_transform,
             noise=torch.randn(sample.shape, generator=seed).to(dtype=sample.dtype, device=sample.device),
             previous=tuple(previous),
@@ -71,4 +72,4 @@
     image: torch.Tensor = image_encoder.decode(sample / image_encoder.config.scaling_factor).sample[0]  # type: ignore
     Image.fromarray(
         ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy()
-    ).save("raw.png")
+    ).save("structured.png")
diff --git a/pyproject.toml b/pyproject.toml
index 2a32793..e82ebc7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ scripts = ["skrample[all]", "matplotlib>=3.10.1"]
 test = [
     "skrample[scripts]",
     "accelerate>=1.3",
-    "diffusers>=0.32",
+    "diffusers>=0.35",
     "protobuf>=5.29",
     "pyright>=1.1.400",
     "pytest-xdist>=3.6.1",
diff --git a/scripts/overhead.py b/scripts/overhead.py
index 3bff208..a513892 100755
--- a/scripts/overhead.py
+++ b/scripts/overhead.py
@@ -5,12 +5,12 @@
 import torch
 
 from skrample.diffusers import SkrampleWrapperScheduler
-from skrample.sampling import DPM
+from skrample.sampling.structured import Euler
 from skrample.scheduling import Beta, FlowShift, SigmoidCDF
 
 
 def bench_wrapper() -> int:
-    wrapper = SkrampleWrapperScheduler(DPM(), Beta(FlowShift(SigmoidCDF())))
+    wrapper = SkrampleWrapperScheduler(Euler(), Beta(FlowShift(SigmoidCDF())))
     wrapper.set_timesteps(1000)
 
     clock = perf_counter_ns()
diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py
index a432121..39c572d 100755
--- a/scripts/plot_skrample.py
+++ b/scripts/plot_skrample.py
@@ -1,7 +1,7 @@
 #! /usr/bin/env python
 
 import math
-from argparse import ArgumentParser
+from argparse import ArgumentParser, BooleanOptionalAction
 from collections.abc import Generator
 from dataclasses import replace
 from pathlib import Path
@@ -12,9 +12,10 @@
 import numpy as np
 from numpy.typing import NDArray
 
-import skrample.sampling as sampling
 import skrample.scheduling as scheduling
 from skrample.common import SigmaTransform, sigma_complement, sigma_polar, spowf
+from skrample.sampling import functional, models, structured
+from skrample.sampling.interface import StructuredFunctionalAdapter
 
 OKLAB_XYZ_M1 = np.array(
     [
@@ -57,21 +58,24 @@ def colors(hue_steps: int) -> Generator[list[float]]:
         yield oklch_to_srgb(np.array([lighness_actual, chroma_actual, hue], dtype=np.float64))
 
 
-TRANSFORMS: dict[str, SigmaTransform] = {
-    "polar": sigma_polar,
-    "complement": sigma_complement,
+TRANSFORMS: dict[str, tuple[float, SigmaTransform, models.DiffusionModel]] = {
+    "polar": (1.0, sigma_polar, models.NoiseModel()),
+    "complement": (1.0, sigma_complement, models.FlowModel()),
 }
 
-SAMPLERS: dict[str, sampling.SkrampleSampler] = {
-    "euler": sampling.Euler(),
-    "adams": sampling.Adams(),
-    "dpm": sampling.DPM(),
-    "unip": sampling.UniP(),
-    "unipc": sampling.UniPC(),
-    "spc": sampling.SPC(),
+SAMPLERS: dict[str, structured.StructuredSampler | functional.FunctionalSampler] = {
+    "euler": structured.Euler(),
+    "adams": structured.Adams(),
+    "dpm": structured.DPM(),
+    "unip": structured.UniP(),
+    "unipc": structured.UniPC(),
+    "spc": structured.SPC(),
+    "rku": functional.RKUltra(scheduling.Linear()),
+    "rkm": functional.RKMoire(scheduling.Linear()),
+    "fheun": functional.FastHeun(scheduling.Linear()),
 }
 
 for k, v in list(SAMPLERS.items()):
-    if isinstance(v, sampling.HighOrderSampler):
-        for o in range(1, v.max_order() + 1):
+    if isinstance(v, structured.StructuredMultistep | functional.FunctionalHigher):
+        for o in range(v.min_order(), min(v.max_order() + 1, 9)):
             if o != v.order:
                 SAMPLERS[k + str(o)] = replace(v, order=o)
@@ -88,7 +92,6 @@ def colors(hue_steps: int) -> Generator[list[float]]:
     "exponential": (scheduling.Exponential, {}),
     "karras": (scheduling.Karras, {}),
     "flow": (scheduling.FlowShift, {}),
-    "flow_mu": (scheduling.FlowShift, {"mu": 1}),
     "hyper": (scheduling.Hyper, {}),
     "vyper": (scheduling.Hyper, {"scale": -2}),
     "hype": (scheduling.Hyper, {"tail": False}),
@@ -100,12 +103,13 @@ def colors(hue_steps: int) -> Generator[list[float]]:
 # Common
 parser = ArgumentParser()
 parser.add_argument("file", type=Path)
-parser.add_argument("--steps", "-s", type=int, default=20)
+parser.add_argument("--steps", "-s", type=int, default=25)
 subparsers = parser.add_subparsers(dest="command")
 
 # Samplers
 parser_sampler = subparsers.add_parser("samplers")
-parser_sampler.add_argument("--curve", "-k", type=int, default=10)
+parser_sampler.add_argument("--adjust", type=bool, default=True, action=BooleanOptionalAction)
+parser_sampler.add_argument("--curve", "-k", type=int, default=30)
 parser_sampler.add_argument("--transform", "-t", type=str, choices=list(TRANSFORMS.keys()), default="polar")
 parser_sampler.add_argument(
     "--sampler",
@@ -154,40 +158,61 @@ def colors(hue_steps: int) -> Generator[list[float]]:
     plt.ylabel("Sample")
     plt.title("Skrample Samplers")
 
-    schedule = scheduling.Linear(base_timesteps=10_000)
+    schedule = scheduling.Hyper(
+        scheduling.Linear(
+            sigma_start=TRANSFORMS[args.transform][0],
+            base_timesteps=10_000,
+            custom_transform=TRANSFORMS[args.transform][1],
+        ),
+        -2,
+        False,
+    )
+
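+    # Structured samplers are wrapped in the functional adapter so both kinds share one plotting path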
+    def sample_model(
+        sampler: structured.StructuredSampler | functional.FunctionalSampler, steps: int
+    ) -> tuple[list[float], list[float]]:
+        if isinstance(sampler, structured.StructuredSampler):
+            sampler = StructuredFunctionalAdapter(schedule, sampler)
+        else:
+            sampler = replace(sampler, schedule=schedule)
 
-    def sample_model(sampler: sampling.SkrampleSampler, schedule: NDArray[np.float64]) -> list[float]:
-        previous: list[sampling.SKSamples] = []
         sample = 1.0
         sampled_values = [sample]
-        for step, sigma in enumerate(schedule):
-            result = sampler.sample(
-                sample=sample,
-                prediction=math.sin(sigma * args.curve),
-                step=step,
-                sigma_schedule=schedule,
-                sigma_transform=TRANSFORMS[args.transform],
-                previous=tuple(previous),
-                noise=random(),
-            )
-            previous.append(result)
-            sample = result.final
-            sampled_values.append(sample)
-        return sampled_values
-
-    plt.plot(
-        [*schedule.sigmas(schedule.base_timesteps), 0],
-        sample_model(sampling.Euler(), schedule.sigmas(schedule.base_timesteps)),
-        label="Reference",
-        color=next(COLORS),
-    )
+        timesteps = [0.0]
+
+        def callback(x: float, n: int, t: float, s: float) -> None:
+            nonlocal sampled_values, timesteps
+            sampled_values.append(x)
+            timesteps.insert(-1, t / schedule.base_timesteps)
+
+        if isinstance(sampler, functional.RKMoire) and args.adjust:
+            adjusted = schedule.base_timesteps
+        elif isinstance(sampler, functional.FunctionalHigher) and args.adjust:
+            adjusted = sampler.adjust_steps(steps)
+        else:
+            adjusted = steps
+
+        sampler.sample_model(
+            sample=sample,
+            model=lambda x, t, s: x - math.sin(t / schedule.base_timesteps * args.curve),
+            model_transform=TRANSFORMS[args.transform][2],
+            steps=adjusted,
+            rng=random,
+            callback=callback,
+        )
+
+        return timesteps, sampled_values
+
+    plt.plot(*sample_model(structured.Euler(), schedule.base_timesteps), label="Reference", color=next(COLORS))
 
     for sampler in [SAMPLERS[s] for s in args.sampler]:
-        sigmas = schedule.sigmas(args.steps)
         label = type(sampler).__name__
-        if isinstance(sampler, sampling.HighOrderSampler) and sampler.order != type(sampler).order:
+        if (
+            isinstance(sampler, structured.StructuredMultistep | functional.FunctionalHigher)
+            and sampler.order != type(sampler).order
+        ):
             label += " " + str(sampler.order)
-        plt.plot([*sigmas, 0], sample_model(sampler, sigmas), label=label, color=next(COLORS), linestyle="--")
+        plt.plot(*sample_model(sampler, args.steps), label=label, color=next(COLORS), linestyle="--")
 
 elif args.command == "schedules":
     plt.xlabel("Step")
@@ -205,12 +230,12 @@ def sample_model(sampler: sampling.SkrampleSampler, schedule: NDArray[np.float64
         for mod_label, (mod_type, mod_props) in [  # type: ignore
             # Destructure
             m for m in [(mod1, MODIFIERS[mod1]), (mod2, MODIFIERS[mod2])] if m[1]
         ]:
-            composed = mod_type(schedule, **mod_props)
+            composed = mod_type(composed, **mod_props)
             label += "_" + mod_label
 
         label = " ".join([s.capitalize() for s in label.split("_")])
 
-        data = np.concatenate([composed.schedule(args.steps), [[0, 0]]], dtype=np.float64)
+        data = np.concatenate([composed.schedule_np(args.steps), [[0, 0]]], dtype=np.float64)
 
         timesteps = data[:, 0] / composed.base_timesteps
         sigmas = data[:, 1] / data[:, 1].max()
diff --git a/scripts/spc.py b/scripts/spc.py
index 3e4837b..991a9fe 100755
--- a/scripts/spc.py
+++ b/scripts/spc.py
@@ -9,9 +9,9 @@
 import numpy as np
 from numpy.typing import NDArray
 
-import skrample.sampling as sampling
+import skrample.sampling.structured as sampling
 import skrample.scheduling as scheduling
-from skrample.common import SigmaTransform, sigma_complement, sigma_polar
+from skrample.common import FloatSchedule, SigmaTransform, sigma_complement, sigma_polar
 
 parser = ArgumentParser()
 parser.add_argument("out", type=FileType("w"))
@@ -33,17 +33,17 @@ class Row:
 
 
 def sample_model(
-    sampler: sampling.SkrampleSampler, schedule: NDArray[np.float64], curve: int, transform: SigmaTransform
+    sampler: sampling.StructuredSampler, schedule: FloatSchedule, curve: int, transform: SigmaTransform
 ) -> NDArray:
     previous: list[sampling.SKSamples] = []
     sample = 1.0
     sampled_values = [sample]
-    for step, sigma in enumerate(schedule):
+    for step, (timestep, sigma) in enumerate(schedule):
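+        # sin(sigma * curve) stands in for a real model, making sampler error cheap to measure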
         result = sampler.sample(
             sample=sample,
             prediction=math.sin(sigma * curve),
             step=step,
-            sigma_schedule=schedule,
+            schedule=schedule,
             sigma_transform=transform,
             previous=tuple(previous),
             noise=random(),
@@ -54,9 +54,9 @@ def sample_model(
     return np.array(sampled_values)
 
 
-samplers: set[sampling.SkrampleSampler] = {sampling.Euler(), sampling.Adams(order=2), sampling.DPM(order=2)}
+samplers: set[sampling.StructuredSampler] = {sampling.Euler(), sampling.Adams(order=2), sampling.DPM(order=2)}
 for v in samplers.copy():
-    if isinstance(v, sampling.HighOrderSampler):
+    if isinstance(v, sampling.StructuredMultistep):
         for o in range(2, v.max_order() + 1):
             samplers.add(replace(v, order=o))
 
@@ -65,17 +65,17 @@ def sample_model(
 table: list[Row] = []
 for t in [sigma_polar, sigma_complement]:
     for k in args.curves:
-        reference = sample_model(sampling.Euler(), schedule.sigmas(schedule.base_timesteps), k, t)
+        reference = sample_model(sampling.Euler(), schedule.schedule(schedule.base_timesteps), k, t)
         for h in args.steps:
             reference_aliased = np.interp(np.linspace(0, 1, h + 1), np.linspace(0, 1, len(reference)), reference)
             for pe in samplers:
                 for ce in samplers:
                     spc = sampling.SPC(predictor=pe, corrector=ce)
-                    sampled = sample_model(spc, schedule.sigmas(h), k, t)
+                    sampled = sample_model(spc, schedule.schedule(h), k, t)
                     table.append(
                         Row(
-                            type(pe).__name__ + (str(pe.order) if isinstance(pe, sampling.HighOrderSampler) else ""),
-                            type(ce).__name__ + (str(ce.order) if isinstance(ce, sampling.HighOrderSampler) else ""),
+                            type(pe).__name__ + (str(pe.order) if isinstance(pe, sampling.StructuredMultistep) else ""),
+                            type(ce).__name__ + (str(ce.order) if isinstance(ce, sampling.StructuredMultistep) else ""),
                             t.__name__,
                             k,
                             h,
diff --git a/skrample/common.py b/skrample/common.py
index 4f9adaa..137ec25 100644
--- a/skrample/common.py
+++ b/skrample/common.py
@@ -1,8 +1,9 @@
 import enum
 import math
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from functools import lru_cache
 from itertools import repeat
+from types import MappingProxyType
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -20,8 +21,15 @@
 type SigmaTransform = Callable[[float], tuple[float, float]]
 "Transforms a single noise sigma into a pair"
 
-type Predictor[S: Sample] = Callable[[S, S, float, SigmaTransform], S]
-"sample, output, sigma, sigma_transform"
+
+type DictOrProxy[T, U] = MappingProxyType[T, U] | dict[T, U]  # Mapping does not implement __or__
+"Simple union type for a possibly immutable dictionary"
+
+type FloatSchedule = Sequence[tuple[float, float]]
+"Sequence of timestep, sigma"
+
+type RNG[T: Sample] = Callable[[], T]
+"Distribution should match model, typically normal"
 
 
 @enum.unique
@@ -66,36 +74,67 @@ def sigma_polar(sigma: float) -> tuple[float, float]:
     return math.sin(theta), math.cos(theta)
 
 
-def predict_epsilon[T: Sample](sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
-    "If a model does not specify, this is usually what it needs."
+def get_sigma_uv(step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform) -> tuple[float, float]:
+    """Gets sigma u/v with bounds check.
+    If step >= len(schedule), the sigma is assumed to be zero."""
+    return sigma_transform(schedule[step][1] if step < len(schedule) else 0)
+
+
+def scaled_delta(sigma: float, sigma_next: float, sigma_transform: SigmaTransform) -> tuple[float, float]:
+    "Returns delta (h) and scale factor to perform the euler method."
     sigma_u, sigma_v = sigma_transform(sigma)
-    return (sample - sigma_u * output) / sigma_v  # type: ignore
+    sigma_u_next, sigma_v_next = sigma_transform(sigma_next)
+
+    scale = sigma_u_next / sigma_u
+    delta = sigma_v_next - sigma_v * scale  # aka `h` or `dt`
+    return delta, scale
 
 
-def predict_sample[T: Sample](sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
-    "No prediction. Only for single step afaik."
-    return output
+def euler[T: Sample](sample: T, prediction: T, sigma: float, sigma_next: float, sigma_transform: SigmaTransform) -> T:
+    "Perform the euler method using scaled_delta"
+    # Returns delta, scale so prediction is first
+    return math.sumprod((prediction, sample), scaled_delta(sigma, sigma_next, sigma_transform))  # type: ignore
 
 
-def predict_velocity[T: Sample](sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
-    "Rare, models will usually explicitly say they require velocity/vpred/zero terminal SNR"
+def scaled_delta_step(
+    step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform, step_size: int = 1
+) -> tuple[float, float]:
+    """Returns delta (h) and scale factor to perform the euler method.
+    If step + step_size > len(schedule), assumes the next timestep and sigma are zero"""
+    step_next = step + step_size
+    return scaled_delta(schedule[step][1], schedule[step_next][1] if step_next < len(schedule) else 0, sigma_transform)
+
+
+def euler_step[T: Sample](
+    sample: T, prediction: T, step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform, step_size: int = 1
+) -> T:
+    "Perform the euler method using scaled_delta_step"
+    return math.sumprod((prediction, sample), scaled_delta_step(step, schedule, sigma_transform, step_size))  # type: ignore
+
+
+def merge_noise[T: Sample](sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T:
     sigma_u, sigma_v = sigma_transform(sigma)
-    return sigma_v * sample - sigma_u * output  # type: ignore
+    return sample * sigma_v + noise * sigma_u  # type: ignore
 
 
-def predict_flow[T: Sample](sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
-    "Flow matching models use this, notably FLUX.1 and SD3"
-    # TODO(beinsezii): this might need to be u * output. Don't trust diffusers
-    # Our tests will fail if we do so, leaving here for now.
+def divf(lhs: float, rhs: float) -> float:
+    "Float division with infinity"
+    if rhs != 0:
+        return lhs / rhs
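+    # 0 / 0 has no sign to give the infinity, so it stays an error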
-    return sample - sigma * output  # type: ignore
+    elif lhs == 0:
+        raise ZeroDivisionError
+    else:
+        return math.copysign(math.inf, lhs)
 
 
-def safe_log(x: float) -> float:
-    "Returns inf rather than throw an err"
-    try:
+def ln(x: float) -> float:
+    "Natural logarithm with infinity"
+    if x > 0:
         return math.log(x)
-    except ValueError:
-        return math.inf
+    elif x < 0:
+        raise ValueError
+    else:
+        return -math.inf
 
 
 def normalize(regular_array: NDArray[np.float64], start: float, end: float = 0) -> NDArray[np.float64]:
@@ -128,6 +167,14 @@ def spowf[T: Sample](x: T, f: float) -> T:
     return abs(x) ** f * (-1 * (x < 0) | 1)  # type: ignore
 
 
+def mean(x: Sample) -> float:
+    "For an array this returns mean().item(). For a float this returns x"
+    if isinstance(x, float | int):
+        return x
+    else:
+        return x.mean().item()
+
+
 @lru_cache
 def bashforth(order: int) -> tuple[float, ...]:  # tuple return so lru isnt mutable
     "Bashforth coefficients for a given order"
diff --git a/skrample/diffusers.py b/skrample/diffusers.py
index e7f7664..9b58908 100644
--- a/skrample/diffusers.py
+++ b/skrample/diffusers.py
@@ -9,8 +9,9 @@
 from numpy.typing import NDArray
 from torch import Tensor
 
-from skrample import sampling, scheduling
-from skrample.common import MergeStrategy, Predictor, predict_epsilon, predict_flow, predict_sample, predict_velocity
+import skrample.sampling.structured as sampling
+from skrample import scheduling
+from skrample.common import FloatSchedule, MergeStrategy
 from skrample.pytorch.noise import (
     BatchTensorNoise,
     Random,
@@ -18,14 +19,15 @@
     TensorNoiseProps,
     schedule_to_ramp,
 )
-from skrample.sampling import SkrampleSampler, SKSamples
+from skrample.sampling.models import DataModel, DiffusionModel, FlowModel, NoiseModel, VelocityModel
+from skrample.sampling.structured import SKSamples, StructuredSampler
 from skrample.scheduling import ScheduleCommon, ScheduleModifier, SkrampleSchedule
 
 if TYPE_CHECKING:
     from diffusers.configuration_utils import ConfigMixin
 
 
-DIFFUSERS_CLASS_MAP: dict[str, tuple[type[SkrampleSampler], dict[str, Any]]] = {
+DIFFUSERS_CLASS_MAP: dict[str, tuple[type[StructuredSampler], dict[str, Any]]] = {
     "DDIMScheduler": (sampling.Euler, {}),
     "DDPMScheduler": (sampling.DPM, {"add_noise": True, "order": 1}),
     "DPMSolverMultistepScheduler": (sampling.DPM, {}),
@@ -65,10 +67,10 @@
     ("algorithm_type", "sde-dpmsolver"): ("add_noise", True),
     ("algorithm_type", "sde-dpmsolver++"): ("add_noise", True),
     # Complex types
-    ("prediction_type", "epsilon"): ("skrample_predictor", predict_epsilon),
-    ("prediction_type", "flow"): ("skrample_predictor", predict_flow),
-    ("prediction_type", "sample"): ("skrample_predictor", predict_sample),
-    ("prediction_type", "v_prediction"): ("skrample_predictor", predict_velocity),
+    ("prediction_type", "epsilon"): ("skrample_predictor", NoiseModel()),
+    ("prediction_type", "flow"): ("skrample_predictor", FlowModel()),
+    ("prediction_type", "sample"): ("skrample_predictor", DataModel()),
+    ("prediction_type", "v_prediction"): ("skrample_predictor", VelocityModel()),
     ("use_beta_sigmas", True): ("skrample_modifier", scheduling.Beta),
     ("use_exponential_sigmas", True): ("skrample_modifier", scheduling.Exponential),
     ("use_karras_sigmas", True): ("skrample_modifier", scheduling.Karras),
@@ -93,17 +95,17 @@
 class ParsedDiffusersConfig:
     "Values read from a combination of the diffusers config and provided types"
 
-    sampler: type[SkrampleSampler]
+    sampler: type[StructuredSampler]
     sampler_props: dict[str, Any]
     schedule: type[SkrampleSchedule]
     schedule_props: dict[str, Any]
     schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]]
-    predictor: Predictor
+    model: DiffusionModel
 
 
 def parse_diffusers_config(
     config: "dict[str, Any] | ConfigMixin",
-    sampler: type[SkrampleSampler] | None = None,
+    sampler: type[StructuredSampler] | None = None,
     schedule: type[SkrampleSchedule] | None = None,
 ) -> ParsedDiffusersConfig:
     """Reads a diffusers scheduler or config as a set of skrample classes and properties.
@@ -120,11 +122,11 @@ def parse_diffusers_config(
     }
 
     if "skrample_predictor" in remapped:
-        predictor: Predictor = remapped.pop("skrample_predictor")
+        model: DiffusionModel = remapped.pop("skrample_predictor")
     elif "shift" in remapped:  # should only be flow
-        predictor = predict_flow
+        model = FlowModel()
     else:
-        predictor = predict_epsilon
+        model = NoiseModel()
 
     if not sampler:
         sampler, sampler_props = DIFFUSERS_CLASS_MAP.get(diffusers_class, (sampling.DPM, {}))
@@ -132,7 +134,7 @@
         sampler_props = {}
 
     if not schedule:
-        if predictor is predict_flow:
+        if isinstance(model, FlowModel):
            schedule = scheduling.Linear
         elif remapped.get("rescale_betas_zero_snr", False):
             schedule = scheduling.ZSNR
@@ -140,16 +142,16 @@
             schedule = scheduling.Scaled
 
     # Adjust sigma_start to match scaled beta for sd1/xl
-    if "sigma_start" not in remapped and predictor is not predict_flow and issubclass(schedule, scheduling.Linear):
+    if "sigma_start" not in remapped and not isinstance(model, FlowModel) and issubclass(schedule, scheduling.Linear):
         scaled_keys = [f.name for f in dataclasses.fields(scheduling.Scaled)]
         # non-uniform misses a whole timestep
         scaled = scheduling.Scaled(**{k: v for k, v in remapped.items() if k in scaled_keys} | {"uniform": True})
-        sigma_start: float = scaled.sigmas(1).item()
+        sigma_start: float = scaled.sigmas(1)[0]
         remapped["sigma_start"] = math.sqrt(sigma_start)
 
     schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] = []
 
-    if predictor is predict_flow:
+    if isinstance(model, FlowModel):
         flow_keys = [f.name for f in dataclasses.fields(scheduling.FlowShift)]
         schedule_modifiers.append((scheduling.FlowShift, {k: v for k, v in remapped.items() if k in flow_keys}))
 
@@ -168,7 +170,7 @@
         schedule=schedule,
         schedule_props={k: v for k, v in remapped.items() if k in schedule_keys},
         schedule_modifiers=schedule_modifiers,
-        predictor=predictor,
+        model=model,
     )
 
 
@@ -179,10 +181,14 @@ def attr_dict[T: Any](**kwargs: T) -> OrderedDict[str, T]:
     return od
 
 
-def as_diffusers_config(sampler: SkrampleSampler, schedule: SkrampleSchedule, predictor: Predictor) -> dict[str, Any]:
+def as_diffusers_config(
+    sampler: StructuredSampler,
+    schedule: SkrampleSchedule,
+    model: DiffusionModel,
+) -> dict[str, Any]:
     "Converts skrample classes back into a diffusers-readable config. Not comprehensive"
     skrample_config = dataclasses.asdict(sampler)
-    skrample_config["skrample_predictor"] = predictor
+    skrample_config["skrample_predictor"] = model
 
     if isinstance(schedule, ScheduleModifier):
         for modifier in schedule.all:
@@ -208,9 +214,9 @@ class SkrampleWrapperScheduler[T: TensorNoiseProps | None]:
     Best effort approach. Most of the items presented in .config are fake, and many function inputs are ignored.
     A general rule of thumb is it will always prioritize the skrample properties over the incoming properties."""
 
-    sampler: SkrampleSampler
+    sampler: StructuredSampler
     schedule: SkrampleSchedule
-    predictor: Predictor[Tensor] = predict_epsilon
+    model: DiffusionModel = NoiseModel()  # noqa: RUF009 # is immutable
     noise_type: type[TensorNoiseCommon[T]] = Random  # type: ignore # Unsure why?
     noise_props: T | None = None
     compute_scale: torch.dtype | None = torch.float32
@@ -233,10 +239,10 @@ def __post_init__(self) -> None:
     def from_diffusers_config[N: TensorNoiseProps | None](  # pyright fails if you use the outer generic
         cls,
         config: "dict[str, Any] | ConfigMixin",
-        sampler: type[SkrampleSampler] | None = None,
+        sampler: type[StructuredSampler] | None = None,
         schedule: type[SkrampleSchedule] | None = None,
         schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] = [],
-        predictor: Predictor[Tensor] | None = None,
+        model: DiffusionModel | None = None,
         noise_type: type[TensorNoiseCommon[N]] = Random,
         compute_scale: torch.dtype | None = torch.float32,
         sampler_props: dict[str, Any] = {},
@@ -262,7 +268,7 @@ def from_diffusers_config[N: TensorNoiseProps | None](  # pyright fails if you u
         return cls(
             built_sampler,
             built_schedule,
-            predictor or parsed.predictor,
+            model or parsed.model,
             noise_type=noise_type,  # type: ignore # think these are weird because of the defaults?
             noise_props=noise_props,  # type: ignore
             compute_scale=compute_scale,
@@ -271,9 +277,13 @@ def from_diffusers_config[N: TensorNoiseProps | None](  # pyright fails if you u
         )
 
     @property
-    def schedule_np(self) -> NDArray[np.float64]:
+    def schedule_float(self) -> FloatSchedule:
         return scheduling.schedule_lru(self.schedule, self._steps)
 
+    @property
+    def schedule_np(self) -> NDArray[np.float64]:
+        return scheduling.np_schedule_lru(self.schedule, self._steps)
+
     @property
     def schedule_pt(self) -> Tensor:
         return torch.from_numpy(self.schedule_np).to(self._device)
 
     @property
     def sigmas(self) -> Tensor:
@@ -290,7 +300,7 @@ def sigmas(self) -> Tensor:
     @property
     def init_noise_sigma(self) -> float:
-        return self.sampler.scale_input(1, self.schedule_np[0, 1].item(), sigma_transform=self.schedule.sigma_transform)
+        return self.sampler.scale_input(1, self.schedule_float[0][1], sigma_transform=self.schedule.sigma_transform)
 
     @property
     def order(self) -> int:
@@ -299,7 +309,7 @@ def order(self) -> int:
     @property
     def config(self) -> OrderedDict[str, Any]:
         # Diffusers expects the frozen shift value
-        return attr_dict(**(self.fake_config | as_diffusers_config(self.sampler, self._schedule, self.predictor)))
+        return attr_dict(**(self.fake_config | as_diffusers_config(self.sampler, self._schedule, self.model)))
 
     def time_shift(self, mu: float, sigma: float, t: Tensor) -> Tensor:
         return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
@@ -403,7 +413,7 @@ def step(
             noise = None
 
         sample_cast = sample.to(dtype=self.compute_scale)
-        prediction = self.predictor(
+        prediction = self.model.to_x(
             sample_cast,
             model_output.to(dtype=self.compute_scale),
             schedule[step, 1].item(),
@@ -412,7 +422,7 @@ def step(
         sampled = self.sampler.sample(
             sample=sample_cast,
             prediction=prediction,
-            sigma_schedule=schedule[:, 1],
+            schedule=self.schedule_float,
             step=step,
             noise=noise,
             previous=tuple(self._previous),
diff --git a/skrample/sampling/__init__.py b/skrample/sampling/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py
new file mode 100644
index 0000000..7687a2f
--- /dev/null
+++ b/skrample/sampling/functional.py
@@ -0,0 +1,451 @@
+import dataclasses
+import math
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from types import MappingProxyType
+from typing import Any
+
+import numpy as np
+
+from skrample import common, scheduling
+from skrample.common import RNG, DictOrProxy, FloatSchedule, Sample, SigmaTransform
+
+from . import models, tableaux
+
+type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any]
+"Return is ignored"
+type SampleableModel[T: Sample] = Callable[[T, float, float], T]
+"sample, timestep, sigma"
+
+
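+# Interpolates (timestep, sigma) pairs at fractional positions between schedule steps,
+# letting Runge-Kutta stages evaluate inside a single step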
+def fractional_step(
+    schedule: FloatSchedule,
+    current: int,
+    idx: tuple[float, ...],
+) -> tuple[tuple[float, float], ...]:
+    schedule_np = np.array([*schedule, (0, 0)])
+    idx_np = np.array(idx) / len(schedule) + current / len(schedule)
+    scale = np.linspace(0, 1, len(schedule_np))
+
+    # TODO (beinsezii): better 2d interpolate
+    result = tuple(
+        zip(
+            (np.interp(idx_np, scale, schedule_np[:, 0])).tolist(),
+            (np.interp(idx_np, scale, schedule_np[:, 1])).tolist(),
+        )
+    )
+    return result
+
+
+def step_tableau[T: Sample](
+    tableau: tableaux.Tableau | tableaux.ExtendedTableau,
+    sample: T,
+    model: SampleableModel[T],
+    model_transform: models.DiffusionModel,
+    step: int,
+    schedule: FloatSchedule,
+    sigma_transform: SigmaTransform,
+    derivative_transform: models.DiffusionModel | None = None,
+    step_size: int = 1,
+    epsilon: float = 1e-8,
+) -> tuple[T, ...]:
+    nodes, weights = tableau[0], tableau[1:]
+
+    if derivative_transform:
+        model = models.ModelConvert(model_transform, derivative_transform).wrap_model_call(model, sigma_transform)
+        model_transform = derivative_transform
+
+    derivatives: list[T] = []
+    S0 = schedule[step][1]
+    S1 = schedule[step + step_size][1] if step + step_size < len(schedule) else 0
+
+    fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in nodes))
+
+    for frac_sc, icoeffs in zip(fractions, (t[1] for t in nodes), strict=True):
+        sigma_i = frac_sc[1]
+        if icoeffs:
+            X: T = model_transform.forward(  # pyright: ignore [reportAssignmentType]
+                sample,
+                math.sumprod(derivatives, icoeffs) / math.fsum(icoeffs),  # pyright: ignore [reportArgumentType]
+                S0,
+                sigma_i,
+                sigma_transform,
+            )
+        else:
+            X = sample
+
+        # Do not call model on timestep = 0 or sigma = 0
+        if any(abs(v) < epsilon for v in frac_sc):
+            derivatives.append(model_transform.backward(sample, X, S0, S1, sigma_transform))
+        else:
+            derivatives.append(model(X, *frac_sc))
+
+    return tuple(  # pyright: ignore [reportReturnType]
+        model_transform.forward(
+            sample,
+            math.sumprod(derivatives, w),  # pyright: ignore [reportArgumentType]
+            S0,
+            S1,
+            sigma_transform,
+        )
+        for w in weights
+    )
+
+
+@dataclasses.dataclass(frozen=True)
+class FunctionalSampler(ABC):
+    schedule: scheduling.SkrampleSchedule
+
+    def merge_noise[T: Sample](self, sample: T, noise: T, steps: int, start: int) -> T:
+        sigmas = self.schedule.sigmas(steps)
+        sigma = sigmas[start] if start < len(sigmas) else 0
+        return common.merge_noise(sample, noise, sigma, self.schedule.sigma_transform)
+
+    @abstractmethod
+    def sample_model[T: Sample](
+        self,
+        sample: T,
+        model: SampleableModel[T],
+        model_transform: models.DiffusionModel,
+        steps: int,
+        include: slice = slice(None),
+        rng: RNG[T] | None = None,
+        callback: SampleCallback | None = None,
+    ) -> T:
+        """Runs the noisy sample through the model for a given range `include` of total steps.
+        Calls callback every step with sampled value."""
+
+    def generate_model[T: Sample](
+        self,
+        model: SampleableModel[T],
+        model_transform: models.DiffusionModel,
+        rng: RNG[T],
+        steps: int,
+        include: slice = slice(None),
+        initial: T | None = None,
+        callback: SampleCallback | None = None,
+    ) -> T:
+        """Equivalent to `sample_model` except the noise is handled automatically
+        rather than being pre-added to the initial value"""
+
+        if initial is None and include.start is None:  # Short circuit for common case
+            sample: T = rng()
+        else:
+            sample: T = self.merge_noise(
+                0 if initial is None else initial,  # type: ignore
+                rng(),
+                steps,
+                include.start or 0,
+            ) / self.merge_noise(0.0, 1.0, steps, 0)
+            # Rescale sample by initial sigma. Mostly just to handle quirks with Scaled
+
+        return self.sample_model(sample, model, model_transform, steps, include, rng, callback)
+
+
+@dataclasses.dataclass(frozen=True)
+class FunctionalHigher(FunctionalSampler):
+    order: int = 1
+
+    @staticmethod
+    def min_order() -> int:
+        return 1
+
+    @staticmethod
+    @abstractmethod
+    def max_order() -> int: ...
+
+    def adjust_steps(self, steps: int) -> int:
+        "Adjust the steps to approximate an equal amount of model calls"
+        return round(steps / self.order)
+
+
+@dataclasses.dataclass(frozen=True)
+class FunctionalDerivative(FunctionalHigher):
+    derivative_transform: models.DiffusionModel | None = models.DataModel()  # noqa: RUF009 # is immutable
+    "Transform model to this space when computing higher order samples."
+
+
+@dataclasses.dataclass(frozen=True)
+class FunctionalSinglestep(FunctionalSampler):
+    @abstractmethod
+    def step[T: Sample](
+        self,
+        sample: T,
+        model: SampleableModel[T],
+        model_transform: models.DiffusionModel,
+        step: int,
+        schedule: FloatSchedule,
+        rng: RNG[T] | None = None,
+    ) -> T: ...
+
+    def sample_model[T: Sample](
+        self,
+        sample: T,
+        model: SampleableModel[T],
+        model_transform: models.DiffusionModel,
+        steps: int,
+        include: slice = slice(None),
+        rng: RNG[T] | None = None,
+        callback: SampleCallback | None = None,
+    ) -> T:
+        schedule: FloatSchedule = self.schedule.schedule(steps)
+
+        for n in list(range(steps))[include]:
+            sample = self.step(sample, model, model_transform, n, schedule, rng)
+
+            if callback:
+                callback(sample, n, *schedule[n] if n < len(schedule) else (0, 0))
+
+        return sample
+
+
+@dataclasses.dataclass(frozen=True)
+class FunctionalAdaptive(FunctionalSampler):
+    type Evaluator[T: Sample] = Callable[[T, T], float]
+
+    @staticmethod
+    def mse[T: Sample](a: T, b: T) -> float:
+        error: T = abs(a - b) ** 2  # type: ignore
+        return common.mean(error)
+
+    evaluator: Evaluator = mse
+    "Function used to measure error of two samples"
+    threshold: float = 1e-2
+    "Target error threshold for a given evaluation"
+
+
+@dataclasses.dataclass(frozen=True)
+class RKUltra(FunctionalDerivative, FunctionalSinglestep):
+    "Implements almost every single method from https://en.wikipedia.org/wiki/List_of_Runge–Kutta_methods"  # noqa: RUF002
+
+    order: int = 2
+
+    providers: DictOrProxy[int, tableaux.TableauProvider[tableaux.Tableau | tableaux.ExtendedTableau]] = (
+        MappingProxyType(
+            {
+                2: tableaux.RK2.Heun,
+                3: tableaux.RK3.Ralston,
+                4: tableaux.RK4.Ralston,
+                5: tableaux.RKE5.CashKarp,
+            }
+        )
+    )
+    """Providers for a given order, starting from 2.
+    Order 1 is always the Euler method."""
+
+    @staticmethod
+    def max_order() -> int:
+        return 99
+
+    def tableau(self, order: int | None = None) -> tableaux.Tableau:
+        if order is None:
+            order = self.order
+
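+        # Use the largest provided tableau that does not exceed the requested order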
+        if order >= 2 and (morder := max(o for o in self.providers.keys() if o <= order)):
+            return self.providers[morder].tableau()[:2]
+        else:  # Euler / RK1
+            return tableaux.RK1
+
+    def adjust_steps(self, steps: int) -> int:
+        stages = self.tableau()[0]
+        calls = len(stages)
+
+        # Add back the skipped calls on penultimate T
+        adjusted = steps / calls + sum(abs(1 - f[0]) < 1e-8 for f in stages) / calls
+
+        return max(round(adjusted), 1)
+
+    def step[T: Sample](
+        self,
+        sample: T,
+        model: SampleableModel[T],
+        model_transform: models.DiffusionModel,
+        step: int,
+        schedule: FloatSchedule,
+        rng: RNG[T] | None = None,
+    ) -> T:
+        return step_tableau(
+            self.tableau(),
+            sample,
+            model,
+            model_transform,
+            step,
+            schedule,
+            self.schedule.sigma_transform,
+            derivative_transform=self.derivative_transform,
+        )[0]
+
+
+@dataclasses.dataclass(frozen=True)
+class FastHeun(FunctionalAdaptive, FunctionalSinglestep, FunctionalHigher):
+    order: int = 2
+
+    threshold: float = 5e-2
+
+    @staticmethod
+    def min_order() -> int:
+        return 2
+
+    @staticmethod
+    def max_order() -> int:
+        return 2
+
+    def adjust_steps(self, steps: int) -> int:
+        return round(steps * 0.75 + 0.25)
+
+    def step[T: Sample](
+        self,
+        sample: T,
+        model: SampleableModel[T],
+        model_transform: models.DiffusionModel,
+        step: int,
+        schedule: FloatSchedule,
+        rng: RNG[T] | None = None,
+    ) -> T:
+        k1 = model(sample, *schedule[step])
+        sigma_from = schedule[step][1]
+        sigma_to = schedule[step + 1][1] if step + 1 < len(schedule) else 0
+        result: T = model_transform.forward(sample, k1, sigma_from, sigma_to, self.schedule.sigma_transform)
+
+        # Multiplying by step size here is kind of an asspull, but so is this whole solver so...
+        if step + 1 < len(schedule) and self.evaluator(sample, result) / max(
+            self.evaluator(0, result), 1e-16
+        ) > self.threshold * abs(sigma_from - sigma_to):
+            k2: T = (k1 + model(result, *schedule[step + 1])) / 2  # pyright: ignore [reportAssignmentType]
+            result: T = model_transform.forward(sample, k2, sigma_from, sigma_to, self.schedule.sigma_transform)
+
+        return result
+
+
+@dataclasses.dataclass(frozen=True)
+class RKMoire(FunctionalAdaptive, FunctionalDerivative):
+    order: int = 2
+
+    providers: DictOrProxy[int, tableaux.TableauProvider[tableaux.ExtendedTableau]] = MappingProxyType(
+        {
+            2: tableaux.RKE2.Heun,
+            3: tableaux.RKE3.BogackiShampine,
+            5: tableaux.RKE5.Fehlberg,
+        }
+    )
+    """Providers for a given order, starting from 2.
+    Falls back to RKE2.Heun"""
+
+    threshold: float = 1e-4
+
+    initial: float = 1 / 50
+    "Percent of schedule to take as an initial step."
+    maximum: float = 1 / 4
+    "Percent of schedule to take as a maximum step."
+    adaption: float = 0.3
+    "How fast to adjust step size in relation to error"
+    discard: float = float("inf")
+    "If the final adjustment down is more than this, the entire previous step is discarded."
+
+    rescale_init: bool = True
+    "Scale initial by a tableau's model evals."
+    rescale_max: bool = False
+    "Scale maximum by a tableau's model evals."
+
+    @staticmethod
+    def min_order() -> int:
+        return 2
+
+    @staticmethod
+    def max_order() -> int:
+        return 99
+
+    def adjust_steps(self, steps: int) -> int:
+        return steps
+
+    def tableau(self, order: int | None = None) -> tableaux.ExtendedTableau:
+        if order is None:
+            order = self.order
+
+        if order >= 2 and (morder := max(o for o in self.providers.keys() if o <= order)):
+            return self.providers[morder].tableau()
+        else:
+            return tableaux.RKE2.Heun.tableau()
+
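+    # Adaptive driver: the embedded error estimate grows or shrinks the integer step size while walking the schedule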
+    def sample_model[T: Sample](
+        self,
+        sample: T,
+        model: SampleableModel[T],
+        model_transform: models.DiffusionModel,
+        steps: int,
+        include: slice = slice(None),
+        rng: RNG[T] | None = None,
+        callback: SampleCallback | None = None,
+    ) -> T:
+        tab = self.tableau()
+
+        initial = self.initial
+        maximum = self.maximum
+        if self.rescale_init:
+            initial *= len(tab[0]) / 2  # Heun is base so / 2
+        if self.rescale_max:
+            maximum *= len(tab[0]) / 2  # Heun is base so / 2
+
+        step_size: int = max(round(steps * initial), 1)
+        epsilon: float = 1e-16  # lgtm
+
+        schedule: FloatSchedule = self.schedule.schedule(steps)
+
+        indices: list[int] = list(range(steps))[include]
+        step: int = indices[0]
+
+        while step <= indices[-1]:
+            step_next = min(step + step_size, indices[-1] + 1)
+
+            if step_next < len(schedule):
+                sample_high, sample_low = step_tableau(
+                    tab,
+                    sample,
+                    model,
+                    model_transform,
+                    step,
+                    schedule,
+                    self.schedule.sigma_transform,
+                    self.derivative_transform,
+                    step_size,
+                )
+
+                sigma0 = schedule[step][1]
+                sigma1 = schedule[step_next][1]
+                sigma2 = schedule[step_next + step_size][1] if step_next + step_size < len(schedule) else 0
+                # Offset adjustment by dt2 / dt to account for non-linearity
+                # Basically if we want a 50% larger step but the next dt will already be 25% larger,
+                # we should only set a 20% larger step ie 1.5 / 1.25
+                slope = abs(sigma0 - sigma1) / abs(sigma1 - sigma2)
+
+                # Normalize against pure error
+                error = self.evaluator(sample_low, sample_high) / max(self.evaluator(0, sample_high), epsilon)
+                adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / slope
+                step_size = max(round(min(step_size * adjustment, steps * maximum)), 1)
+
+                # Only discard if it will actually decrease step size
+                if step_next - step > step_size and 1 / max(adjustment, epsilon) > self.discard:
+                    continue
+
+            else:  # Save the extra euler call since the 2nd weight isn't used
+                sample_high = step_tableau(
+                    tab[:2],
+                    sample,
+                    model,
+                    model_transform,
+                    step,
+                    schedule,
+                    self.schedule.sigma_transform,
+                    self.derivative_transform,
+                    step_size,
+                )[0]
+
+            sample = sample_high
+
+            if callback:
+                callback(sample, step_next - 1, *schedule[step] if step < len(schedule) else (0, 0))
+
+            step = step_next
+
+        return sample
diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py
new file mode 100644
index 0000000..f91683c
--- /dev/null
+++ b/skrample/sampling/interface.py
@@ -0,0 +1,55 @@
+import dataclasses
+
+from skrample.common import RNG, FloatSchedule, Sample
+
+from . import functional, models, structured
+
+
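+# Runs a stateful structured sampler behind the functional interface by managing `previous` internally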
+@dataclasses.dataclass(frozen=True)
+class StructuredFunctionalAdapter(functional.FunctionalSampler):
+    sampler: structured.StructuredSampler
+
+    def merge_noise[T: Sample](self, sample: T, noise: T, steps: int, start: int) -> T:
+        sigmas = self.schedule.sigmas(steps)
+        sigma = sigmas[start] if start < len(sigmas) else 0
+        return self.sampler.merge_noise(sample, noise, sigma, self.schedule.sigma_transform)
+
+    def sample_model[T: Sample](
+        self,
+        sample: T,
+        model: functional.SampleableModel[T],
+        model_transform: models.DiffusionModel,
+        steps: int,
+        include: slice = slice(None),
+        rng: RNG[T] | None = None,
+        callback: functional.SampleCallback | None = None,
+    ) -> T:
+        previous: list[structured.SKSamples[T]] = []
+        schedule: FloatSchedule = self.schedule.schedule(steps)
+
+        for n in list(range(len(schedule)))[include]:
+            timestep, sigma = schedule[n]
+
+            output = model(self.sampler.scale_input(sample, sigma, self.schedule.sigma_transform), timestep, sigma)
+            prediction = model_transform.to_x(sample, output, sigma, self.schedule.sigma_transform)
+
+            sksamples = self.sampler.sample(
+                sample,
+                prediction,
+                n,
+                schedule,
+                self.schedule.sigma_transform,
+                noise=rng() if rng and self.sampler.require_noise else None,
+                previous=tuple(previous),
+            )
+
+            if self.sampler.require_previous > 0:
+                previous.append(sksamples)
+                previous = previous[max(len(previous) - self.sampler.require_previous, 0) :]
+
+            sample = sksamples.final
+
+            if callback:
+                callback(sample, n, *schedule[n] if n < len(schedule) else (0, 0))
+
+        return sample
diff --git a/skrample/sampling/models.py b/skrample/sampling/models.py
new file mode 100644
index 0000000..8e116d1
--- /dev/null
+++ b/skrample/sampling/models.py
@@ -0,0 +1,226 @@
+import abc
+import dataclasses
+import math
+from collections.abc import Callable
+from functools import wraps
+
+from skrample.common import Sample, SigmaTransform
+
+
+@dataclasses.dataclass(frozen=True)
+class DiffusionModel(abc.ABC):
+    """Common framework for diffusion model sampling."""
+
+    @abc.abstractmethod
+    def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
+        "output -> X̂"
+
+    @abc.abstractmethod
+    def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T:
+        "X̂ -> output"
+
+    @abc.abstractmethod
+    def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
+        "σₜ, σₛ -> Γ"
+
+    @abc.abstractmethod
+    def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
+        "σₜ, σₛ -> Δ"
+
+    def forward[T: Sample](
+        self, sample: T, output: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform
+    ) -> T:
+        "sample * Γ + output * Δ"
+        gamma = self.gamma(sigma_from, sigma_to, sigma_transform)
+        delta = self.delta(sigma_from, sigma_to, sigma_transform)
+        return math.sumprod((sample, output), (gamma, delta))  # pyright: ignore [reportReturnType, reportArgumentType]
+
+    def backward[T: Sample](
+        self, sample: T, result: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform
+    ) -> T:
+        "(output - sample * Γ) / Δ"
+        gamma = self.gamma(sigma_from, sigma_to, sigma_transform)
+        delta = self.delta(sigma_from, sigma_to, sigma_transform)
+        return (result - sample * gamma) / delta  # pyright: ignore [reportReturnType]
+
+
+@dataclasses.dataclass(frozen=True)
+class DataModel(DiffusionModel):
+    """X-Prediction
+    Predicts the clean image.
+    Usually for single step models."""
+
+    def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
+        "output -> X̂"
+        return output
+
+    def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T:
+        "X̂ -> output"
+        return x
+
+    def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
+        "σₜ, σₛ -> Γ"
+        sigma_t, _alpha_t = sigma_transform(sigma_from)
+        sigma_s, _alpha_s = sigma_transform(sigma_to)
+        return sigma_s / sigma_t
+
+    def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
+        "σₜ, σₛ -> Δ"
+        sigma_t, alpha_t = sigma_transform(sigma_from)
+        sigma_s, alpha_s = sigma_transform(sigma_to)
+        return alpha_s - (alpha_t * sigma_s) / sigma_t
+
+
+@dataclasses.dataclass(frozen=True)
+class NoiseModel(DiffusionModel):
+    """Ε-Prediction
+    Predicts the added noise.
+    If a model does not specify, this is usually what it needs."""  # noqa: RUF002
+
+    def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
+        sigma_t, alpha_t = sigma_transform(sigma)
+        return (sample - sigma_t * output) / alpha_t  # pyright: ignore [reportReturnType]
+
+    def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T:
+        sigma_t, alpha_t = sigma_transform(sigma)
+        return (sample - alpha_t * x) / sigma_t  # pyright: ignore [reportReturnType]
+
+    def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
+        _sigma_t, alpha_t = sigma_transform(sigma_from)
+        _sigma_s, alpha_s = sigma_transform(sigma_to)
+        return alpha_s / alpha_t
+
+    def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
+        sigma_t, alpha_t = sigma_transform(sigma_from)
+        sigma_s, alpha_s = sigma_transform(sigma_to)
+        return sigma_s - (alpha_s * sigma_t) / alpha_t
+
+
+@dataclasses.dataclass(frozen=True)
+class FlowModel(DiffusionModel):
+    """U-Prediction.
+    Flow matching models use this, notably FLUX.1 and SD3"""
+
+    def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
+        sigma_t, alpha_t = sigma_transform(sigma)
+        return (sample - sigma_t * output) / (alpha_t + sigma_t)  # pyright: ignore [reportReturnType]
+
+    def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T:
+        sigma_t, alpha_t = sigma_transform(sigma)
+        return (sample - (alpha_t + sigma_t) * x) / sigma_t  # pyright: ignore [reportReturnType]
+
+    def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
+        sigma_t, alpha_t = sigma_transform(sigma_from)
+        sigma_s, alpha_s = sigma_transform(sigma_to)
+        return (sigma_s + alpha_s) / (sigma_t + alpha_t)
+
+    def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
+        sigma_t, alpha_t = sigma_transform(sigma_from)
+        sigma_s, alpha_s = sigma_transform(sigma_to)
+        return (alpha_t * sigma_s - alpha_s * sigma_t) / (alpha_t + sigma_t)
+
+
+@dataclasses.dataclass(frozen=True)
+class VelocityModel(DiffusionModel):
+    """V-Prediction.
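+        # factor**sigma_mean fades the bias in with noise level: no-op on a clean sample, full strength at pure noise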
    Rare, models will usually explicitly say they require velocity/vpred/zero terminal SNR"""

    def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
        sigma_t, alpha_t = sigma_transform(sigma)
        return alpha_t * sample - sigma_t * output  # pyright: ignore [reportReturnType]

    def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T:
        sigma_t, alpha_t = sigma_transform(sigma)
        return (alpha_t * sample - x) / sigma_t  # pyright: ignore [reportReturnType]

    def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
        sigma_t, alpha_t = sigma_transform(sigma_from)
        sigma_s, alpha_s = sigma_transform(sigma_to)
        return (sigma_s / sigma_t) * (1 - alpha_t * alpha_t) + alpha_s * alpha_t

    def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
        sigma_t, alpha_t = sigma_transform(sigma_from)
        sigma_s, alpha_s = sigma_transform(sigma_to)
        return alpha_t * sigma_s - alpha_s * sigma_t


@dataclasses.dataclass(frozen=True)
class FakeModel(DiffusionModel):
    "Marker for transforms that are only used for alternative sampling of other models."


@dataclasses.dataclass(frozen=True)
class ScaleX(FakeModel):
    "X / Sample prediction with sampling bias"

    bias: float = 3
    """Bias for sample prediction.
    Higher values create a stronger image."""

    def x_scale(self, sigma_t: float, alpha_t: float) -> float:
        # Remap -∞ → 0 → ∞ » 0 → 1 → log(∞)
        if self.bias < 0:
            # -∞ → 0⁻ » 0⁺ → 1⁻
            factor = 1 / math.log(math.e - self.bias)
        else:
            # 0 → ∞ » 1 → log(∞)
            factor = math.log(math.e + self.bias)

        # Rescale sigma_t to average bias scale on VP and NV schedules
        sigma_mean = sigma_t / (sigma_t + alpha_t)
        return factor**sigma_mean

    def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T:
        sigma_t, alpha_t = sigma_transform(sigma)
        return output * self.x_scale(sigma_t, alpha_t)  # pyright: ignore [reportReturnType]

    def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T:
        sigma_t, alpha_t = sigma_transform(sigma)
        return x / self.x_scale(sigma_t, alpha_t)  # pyright: ignore [reportReturnType]

    def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
        sigma_t, _alpha_t = sigma_transform(sigma_from)
        sigma_s, _alpha_s = sigma_transform(sigma_to)
        return sigma_s / sigma_t

    def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float:
        sigma_t, alpha_t = sigma_transform(sigma_from)
        sigma_s, alpha_s = sigma_transform(sigma_to)
        return (alpha_s - alpha_t * sigma_s / sigma_t) * self.x_scale(sigma_t, alpha_t)


@dataclasses.dataclass(frozen=True)
class ModelConvert:
    transform_from: DiffusionModel
    transform_to: DiffusionModel

    def output_to[T: Sample](self, sample: T, output_from: T, sigma: float, sigma_transform: SigmaTransform) -> T:
        if self.transform_to is self.transform_from:
            return output_from
        else:
            return self.transform_to.from_x(
                sample,
                self.transform_from.to_x(sample, output_from, sigma, sigma_transform),
                sigma,
                sigma_transform,
            )

    def output_from[T: Sample](self, sample: T, output_to: T, sigma: float, sigma_transform: SigmaTransform) -> T:
        if self.transform_from is self.transform_to:
            return output_to
        else:
            return self.transform_from.from_x(
                sample,
                self.transform_to.to_x(sample, output_to, sigma, sigma_transform),
                sigma,
                sigma_transform,
            )

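    # Wraps a model callable so its raw outputs are re-expressed in `transform_to`'s prediction space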
diff --git a/skrample/sampling.py b/skrample/sampling/structured.py similarity index 78% rename from skrample/sampling.py rename to skrample/sampling/structured.py index 76948de..6fc60db 100644 --- a/skrample/sampling.py +++ b/skrample/sampling/structured.py @@ -3,9 +3,9 @@ from dataclasses import dataclass, replace import numpy as np -from numpy.typing import NDArray -from skrample.common import Sample, SigmaTransform, bashforth, safe_log, softmax, spowf +from skrample import common +from skrample.common import FloatSchedule, Sample, SigmaTransform, divf, ln, merge_noise, softmax, spowf @dataclass(frozen=True) @@ -27,7 +27,7 @@ class SKSamples[T: Sample]: @dataclass(frozen=True) -class SkrampleSampler(ABC): +class StructuredSampler(ABC): """Generic sampler structure with basic configurables and a stateless design. Abstract class not to be used directly. @@ -44,18 +44,13 @@ def require_previous(self) -> int: "How many prior samples the sampler needs in `previous: list[T]`" return 0 - @staticmethod - def get_sigma(step: int, sigma_schedule: NDArray) -> float: - "Just returns zero if step > len" - return sigma_schedule[step].item() if step < len(sigma_schedule) else 0 - @abstractmethod def sample[T: Sample]( self, sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), @@ -75,15 +70,14 @@ def scale_input[T: Sample](self, sample: T, sigma: float, sigma_transform: Sigma return sample def merge_noise[T: Sample](self, sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: - sigma_u, sigma_v = sigma_transform(sigma) - return sample * sigma_v + noise * sigma_u # type: ignore + return merge_noise(sample, noise, sigma, sigma_transform) def __call__[T: Sample]( self, sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), @@ -92,7 +86,7 @@ def __call__[T: Sample]( sample=sample, prediction=prediction, step=step, - sigma_schedule=sigma_schedule, + schedule=schedule, sigma_transform=sigma_transform, noise=noise, previous=previous, @@ -100,7 +94,7 @@ def __call__[T: Sample]( @dataclass(frozen=True) -class HighOrderSampler(SkrampleSampler): +class StructuredMultistep(StructuredSampler): """Samplers inheriting this trait support order > 1, and will require `previous` be managed and passed to function accordingly.""" @@ -119,7 +113,7 @@ def max_order() -> int: def require_previous(self) -> int: return max(min(self.order, self.max_order()), self.min_order()) - 1 - def effective_order(self, step: int, schedule: NDArray, previous: tuple[SKSamples, ...]) -> int: + def effective_order(self, step: int, schedule: FloatSchedule, previous: tuple[SKSamples, ...]) -> int: "The order used in calculation given a step, schedule length, and previous sample count" return max( 1, # not min_order because previous may be < min.
samplers should check effective >= min @@ -134,7 +128,7 @@ def effective_order(self, step: int, schedule: NDArray, previous: tuple[SKSample @dataclass(frozen=True) -class StochasticSampler(SkrampleSampler): +class StructuredStochastic(StructuredSampler): add_noise: bool = False "Flag for whether or not to add the given noise" @@ -144,7 +138,7 @@ def require_noise(self) -> bool: @dataclass(frozen=True) -class Euler(SkrampleSampler): +class Euler(StructuredSampler): """Basic sampler, the "safe" choice.""" def sample[T: Sample]( @@ -152,30 +146,20 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), ) -> SKSamples[T]: - sigma = self.get_sigma(step, sigma_schedule) - sigma_next = self.get_sigma(step + 1, sigma_schedule) - - sigma_u, sigma_v = sigma_transform(sigma) - sigma_u_next, sigma_v_next = sigma_transform(sigma_next) - - scale = sigma_u_next / sigma_u - delta = sigma_v_next - sigma_v * scale # aka `h` or `dt` - final = sample * scale + prediction * delta - - return SKSamples( # type: ignore - final=final, + return SKSamples( + final=common.euler_step(sample, prediction, step, schedule, sigma_transform), prediction=prediction, sample=sample, ) @dataclass(frozen=True) -class DPM(HighOrderSampler, StochasticSampler): +class DPM(StructuredMultistep, StructuredStochastic): """Good sampler, supports basically everything. Recommended default. https://arxiv.org/abs/2211.01095 @@ -191,19 +175,16 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), ) -> SKSamples[T]: - sigma = self.get_sigma(step, sigma_schedule) - sigma_next = self.get_sigma(step + 1, sigma_schedule) - - sigma_u, sigma_v = sigma_transform(sigma) - sigma_u_next, sigma_v_next = sigma_transform(sigma_next) + sigma_u, sigma_v = common.get_sigma_uv(step, schedule, sigma_transform) + sigma_u_next, sigma_v_next = common.get_sigma_uv(step + 1, schedule, sigma_transform) - lambda_ = safe_log(sigma_v) - safe_log(sigma_u) - lambda_next = safe_log(sigma_v_next) - safe_log(sigma_u_next) + lambda_ = ln(divf(sigma_v, sigma_u)) + lambda_next = ln(divf(sigma_v_next, sigma_u_next)) h = abs(lambda_next - lambda_) if noise is not None and self.add_noise: @@ -222,13 +203,10 @@ def sample[T: Sample]( # 1st order final -= (sigma_v_next * exp2) * prediction - effective_order = self.effective_order(step, sigma_schedule, previous) + if (effective_order := self.effective_order(step, schedule, previous)) >= 2: + sigma_u_prev, sigma_v_prev = common.get_sigma_uv(step - 1, schedule, sigma_transform) - if effective_order >= 2: - sigma_prev = self.get_sigma(step - 1, sigma_schedule) - sigma_u_prev, sigma_v_prev = sigma_transform(sigma_prev) - - lambda_prev = safe_log(sigma_v_prev) - safe_log(sigma_u_prev) + lambda_prev = ln(divf(sigma_v_prev, sigma_u_prev)) h_prev = lambda_ - lambda_prev r = h_prev / h # math people and their var names... 
@@ -237,9 +215,8 @@ def sample[T: Sample]( D1_0 = (1.0 / r) * (prediction - prediction_prev) if effective_order >= 3: - sigma_prev2 = self.get_sigma(step - 2, sigma_schedule) - sigma_u_prev2, sigma_v_prev2 = sigma_transform(sigma_prev2) - lambda_prev2 = safe_log(sigma_v_prev2) - safe_log(sigma_u_prev2) + sigma_u_prev2, sigma_v_prev2 = common.get_sigma_uv(step - 2, schedule, sigma_transform) + lambda_prev2 = ln(divf(sigma_v_prev2, sigma_u_prev2)) h_prev2 = lambda_prev - lambda_prev2 r_prev2 = h_prev2 / h @@ -263,7 +240,7 @@ def sample[T: Sample]( @dataclass(frozen=True) -class Adams(HighOrderSampler, Euler): +class Adams(StructuredMultistep, Euler): "Higher order extension to Euler using the Adams-Bashforth coefficients on the model prediction" order: int = 2 @@ -277,27 +254,27 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), ) -> SKSamples[T]: - effective_order = self.effective_order(step, sigma_schedule, previous) + effective_order = self.effective_order(step, schedule, previous) predictions = [prediction, *reversed([p.prediction for p in previous[-effective_order + 1 :]])] weighted_prediction: T = math.sumprod( predictions[:effective_order], # type: ignore - bashforth(effective_order), + common.bashforth(effective_order), ) return replace( - super().sample(sample, weighted_prediction, step, sigma_schedule, sigma_transform, noise, previous), + super().sample(sample, weighted_prediction, step, schedule, sigma_transform, noise, previous), prediction=prediction, ) @dataclass(frozen=True) -class UniP(HighOrderSampler): +class UniP(StructuredMultistep): "Just the solver from UniPC without any correction stages." fast_solve: bool = False @@ -314,24 +291,19 @@ def unisolve[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] 
= (), prediction_next: Sample | None = None, ) -> T: "Passing `prediction_next` is equivalent to UniC, otherwise behaves as UniP" - sigma = self.get_sigma(step, sigma_schedule) - - sigma = self.get_sigma(step, sigma_schedule) - sigma_u, sigma_v = sigma_transform(sigma) - lambda_ = safe_log(sigma_v) - safe_log(sigma_u) - effective_order = self.effective_order(step, sigma_schedule, previous) + sigma_u, sigma_v = common.get_sigma_uv(step, schedule, sigma_transform) + sigma_u_next, sigma_v_next = common.get_sigma_uv(step + 1, schedule, sigma_transform) - sigma_next = self.get_sigma(step + 1, sigma_schedule) - sigma_u_next, sigma_v_next = sigma_transform(sigma_next) - lambda_next = safe_log(sigma_v_next) - safe_log(sigma_u_next) + lambda_ = ln(divf(sigma_v, sigma_u)) + lambda_next = ln(divf(sigma_v_next, sigma_u_next)) h = abs(lambda_next - lambda_) # hh = -h if self.predict_x0 else h @@ -345,11 +317,12 @@ def unisolve[T: Sample]( rks: list[float] = [] D1s: list[Sample] = [] + effective_order = self.effective_order(step, schedule, previous) for n in range(1, effective_order): step_prev_N = step - n prediction_prev_N = previous[-n].prediction - sigma_u_prev_N, sigma_v_prev_N = sigma_transform(self.get_sigma(step_prev_N, sigma_schedule)) - lambda_pO = safe_log(sigma_v_prev_N) - safe_log(sigma_u_prev_N) + sigma_u_prev_N, sigma_v_prev_N = common.get_sigma_uv(step_prev_N, schedule, sigma_transform) + lambda_pO = ln(divf(sigma_v_prev_N, sigma_u_prev_N)) rk = (lambda_pO - lambda_) / h if math.isfinite(rk): # for subnormal rks.append(rk) @@ -401,13 +374,13 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), ) -> SKSamples[T]: return SKSamples( # type: ignore - final=self.unisolve(sample, prediction, step, sigma_schedule, sigma_transform, noise, previous), + final=self.unisolve(sample, prediction, step, schedule, sigma_transform, noise, previous), prediction=prediction, sample=sample, ) @@ -419,7 +392,7 @@ class UniPC(UniP): The additional correction essentially adds +1 order on top of what is set. https://arxiv.org/abs/2302.04867""" - solver: SkrampleSampler | None = None + solver: StructuredSampler | None = None "If not set, defaults to `UniSolver(order=self.order)`" @staticmethod @@ -442,7 +415,7 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), @@ -452,7 +425,7 @@ def sample[T: Sample]( previous[-1].sample, previous[-1].prediction, step - 1, - sigma_schedule, + schedule, sigma_transform, noise, previous[:-1], @@ -463,7 +436,7 @@ def sample[T: Sample]( sample, prediction, step, - sigma_schedule, + schedule, sigma_transform, noise, previous, @@ -471,13 +444,13 @@ def sample[T: Sample]( @dataclass(frozen=True) -class SPC(SkrampleSampler): +class SPC(StructuredSampler): """Simple predictor-corrector. 
Uses basic blended correction against the previous sample.""" - predictor: SkrampleSampler = Euler() + predictor: StructuredSampler = Euler() "Sampler for the current step" - corrector: SkrampleSampler = Adams(order=4) + corrector: StructuredSampler = Adams(order=4) "Sampler to correct the previous step" bias: float = 0 @@ -502,7 +475,7 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), @@ -518,14 +491,14 @@ def sample[T: Sample]( prior.sample, prior.prediction, step - 1, - sigma_schedule, + schedule, sigma_transform, prior.noise, offset_previous[:-1], ).final if self.adaptive: - p, c = sigma_transform(self.get_sigma(step, sigma_schedule)) + p, c = common.get_sigma_uv(step, schedule, sigma_transform) else: p, c = 0, 0 @@ -543,6 +516,6 @@ def sample[T: Sample]( sample = sample * p + corrected * c # type: ignore return replace( - self.predictor.sample(sample, prediction, step, sigma_schedule, sigma_transform, noise, previous), + self.predictor.sample(sample, prediction, step, schedule, sigma_transform, noise, previous), noise=noise, # the corrector may or may not need noise so we always store ) diff --git a/skrample/sampling/tableaux.py b/skrample/sampling/tableaux.py new file mode 100644 index 0000000..a2ea3ca --- /dev/null +++ b/skrample/sampling/tableaux.py @@ -0,0 +1,300 @@ +import abc +import dataclasses +import enum +import math +from typing import Protocol + +type TabNode = tuple[float, tuple[float, ...]] +type TabWeight = tuple[float, ...] + +type Tableau = tuple[ + tuple[TabNode, ...], + TabWeight, +] +type ExtendedTableau = tuple[ + tuple[TabNode, ...], + TabWeight, + TabWeight, +] + + +def validate_tableau(tab: Tableau | ExtendedTableau, tolerance: float = 1e-15) -> None | IndexError | ValueError: + for index, node in enumerate(tab[0]): + if index != (node_len := len(node[1])): + return IndexError(f"{index=}, {node_len=}, {node=}") + if tolerance < (node_err := abs(node[0] - math.fsum(node[1]))): + return ValueError(f"{tolerance=}, {node_err=}, {node=}") + + for weight in tab[1:]: + if (node_count := len(tab[0])) != (weight_len := len(weight)): + return IndexError(f"{node_count=}, {weight_len=}, {weight=}") + if tolerance < (weight_err := abs(1 - math.fsum(weight))): + return ValueError(f"{tolerance=}, {weight_err=}, {weight=}") + + +def rk2_tableau(c1: float) -> Tableau: + "Create a generic 2nd order Tableau from a given coefficient." + return ( + ( + (0.0, ()), + (c1, (c1,)), + ), + (1 - 1 / (2 * c1), 1 / (2 * c1)), + ) + + +def rk3_tableau(c1: float, c2: float) -> Tableau: + "Create a generic 3rd order Tableau from given coefficients." + return ( + ( + (0.0, ()), + (c1, (c1,)), + (c2, (c2 / c1 * ((c2 - 3 * c1 * (1 - c1)) / (3 * c1 - 2)), -c2 / c1 * ((c2 - c1) / (3 * c1 - 2)))), + ), + ( + 1 - (3 * c1 + 3 * c2 - 2) / (6 * c1 * c2), + (3 * c2 - 2) / (6 * c1 * (c2 - c1)), + (2 - 3 * c1) / (6 * c2 * (c2 - c1)), + ), + )
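# --- Illustrative sketch (editor's note, not part of the patch) ---
# Tableaux are plain tuples, so a generated one can be checked directly.
# validate_tableau returns the error instead of raising, hence the walrus.
# rk2_tableau and validate_tableau are the in-module functions defined above.
heun = rk2_tableau(1.0)  # nodes ((0.0, ()), (1.0, (1.0,))), weights (1/2, 1/2)
if error := validate_tableau(heun):
    raise error
assert heun[1] == (0.5, 0.5)
# --- end sketch ---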
+def rk4_tableau(c1: float, c2: float) -> Tableau: + """Create a generic 4th order Tableau from 2 coefficients. + 1/2, 1/2 (Classic) is a special case and cannot be computed using this function. + https://pages.hmc.edu/ruye/MachineLearning/lectures/ch5/node10.html""" + + ### Automatically transcribed from website using QwenVL 235B Thinking + + D = 6 * c1 * c2 - 4 * (c1 + c2) + 3 + + # Compute b coefficients + b2 = (2 * c2 - 1) / (12 * c1 * (c2 - c1) * (1 - c1)) + b3 = (2 * c1 - 1) / (12 * c2 * (c1 - c2) * (1 - c2)) + b4 = D / (12 * (1 - c1) * (1 - c2)) + b1 = 1 - b2 - b3 - b4 + + # Compute a31 and a32 + a32 = c2 * (c1 - c2) / (2 * c1 * (2 * c1 - 1)) + a31 = c2 - a32 + + # Compute a41, a42, a43 + num_a42 = (4 * c2**2 - 5 * c2 - c1 + 2) * (1 - c1) + denom_a42 = 2 * c1 * (c1 - c2) * D + a42 = num_a42 / denom_a42 + + num_a43 = (2 * c1 - 1) * (1 - c1) * (1 - c2) + denom_a43 = c2 * (c1 - c2) * D + a43 = num_a43 / denom_a43 + + a41 = 1 - a42 - a43 + + stages = ( + (0.0, ()), + (c1, (c1,)), # a21 = c1 + (c2, (a31, a32)), + (1.0, (a41, a42, a43)), + ) + + b_vector = (b1, b2, b3, b4) + + return (stages, b_vector) + + +class TableauProvider[T: Tableau | ExtendedTableau](Protocol): + @abc.abstractmethod + def tableau(self) -> T: + raise NotImplementedError + + +RK1: Tableau = ( + ((0, ()),), + (1,), +) +"Euler method" + + +@dataclasses.dataclass(frozen=True) +class CustomTableau[T: Tableau | ExtendedTableau](TableauProvider[T]): + custom: T + + def tableau(self) -> T: + return self.custom + + +@dataclasses.dataclass(frozen=True) +class RK2Custom(TableauProvider): + c1: float = 1.0 + + def tableau(self) -> Tableau: + return rk2_tableau(self.c1) + + +@dataclasses.dataclass(frozen=True) +class RK3Custom(TableauProvider): + c1: float = 1 / 2 + c2: float = 1.0 + + def tableau(self) -> Tableau: + return rk3_tableau(self.c1, self.c2) + + +@dataclasses.dataclass(frozen=True) +class RK4Custom(TableauProvider): + c1: float = 1 / 3 + c2: float = 2 / 3 + + def tableau(self) -> Tableau: + return rk4_tableau(self.c1, self.c2) + + +@enum.unique +class RK2(enum.Enum): + "2nd order, 2 calls" + + Heun = rk2_tableau(1) + Mid = rk2_tableau(1 / 2) + Ralston = rk2_tableau(2 / 3) + + def tableau(self) -> Tableau: + return self.value + + +@enum.unique +class RK3(enum.Enum): + "3rd order, 3 calls" + + Kutta = rk3_tableau(1 / 2, 1) + Heun = rk3_tableau(1 / 3, 2 / 3) + Ralston = rk3_tableau(1 / 2, 3 / 4) + Wray = rk3_tableau(8 / 15, 2 / 3) + SSPRK3 = rk3_tableau(1, 1 / 2) + + def tableau(self) -> Tableau: + return self.value + + +@enum.unique +class RK4(enum.Enum): + "4th order, 4 calls" + + Classic = ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (1 / 2, (0, 1 / 2)), + (1, (0, 0, 1)), + ), + (1 / 6, 1 / 3, 1 / 3, 1 / 6), + ) + Eighth = rk4_tableau(1 / 3, 2 / 3) + Ralston = rk4_tableau(2 / 5, (14 - 3 * math.sqrt(5)) / 16) + + def tableau(self) -> Tableau: + return self.value + + +@enum.unique +class RKZ(enum.Enum): + """Tableaux provided by this enum do not have clean generic forms, and require more calls than their order.
+ Since these are rare, they are all categorized into one enum""" + + Nystrom5 = ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 5, (4 / 25, 6 / 25)), + (1, (1 / 4, -3, 15 / 4)), + (2 / 3, (2 / 27, 10 / 9, -50 / 81, 8 / 81)), + (4 / 5, (2 / 25, 12 / 25, 2 / 15, 8 / 75, 0)), + ), + (23 / 192, 0, 125 / 192, 0, -27 / 64, 125 / 192), + ) + + def tableau(self) -> Tableau: + return self.value + + +@enum.unique +class RKE2(enum.Enum): + Heun = ( + ( + (0, ()), + (1, (1,)), + ), + (1 / 2, 1 / 2), + (1, 0), + ) + Fehlberg = ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (1, (1 / 256, 255 / 256)), + ), + (1 / 512, 255 / 256, 1 / 512), + (1 / 256, 255 / 256, 0), + ) + + def tableau(self) -> ExtendedTableau: + return self.value + + +@enum.unique +class RKE3(enum.Enum): + BogackiShampine = ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (3 / 4, (0, 3 / 4)), + (1, (2 / 9, 1 / 3, 4 / 9)), + ), + (2 / 9, 1 / 3, 4 / 9, 0), + (7 / 24, 1 / 4, 1 / 3, 1 / 8), + ) + + def tableau(self) -> ExtendedTableau: + return self.value + + +@enum.unique +class RKE5(enum.Enum): + Fehlberg = ( + ( + (0, ()), + (1 / 4, (1 / 4,)), + (3 / 8, (3 / 32, 9 / 32)), + (12 / 13, (1932 / 2197, -7200 / 2197, 7296 / 2197)), + (1, (439 / 216, -8, 3680 / 513, -845 / 4104)), + (1 / 2, (-8 / 27, 2, -3544 / 2565, 1859 / 4104, -11 / 40)), + ), + (16 / 135, 0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55), + (25 / 216, 0, 1408 / 2565, 2197 / 4104, -1 / 5, 0), + ) + CashKarp = ( + ( + (0, ()), + (1 / 5, (1 / 5,)), + (3 / 10, (3 / 40, 9 / 40)), + (3 / 5, (3 / 10, -9 / 10, 6 / 5)), + (1, (-11 / 54, 5 / 2, -70 / 27, 35 / 27)), + (7 / 8, (1631 / 55296, 175 / 512, 575 / 13824, 44275 / 110592, 253 / 4096)), + ), + (37 / 378, 0, 250 / 621, 125 / 594, 0, 512 / 1771), + (2825 / 27648, 0, 18575 / 48384, 13525 / 55296, 277 / 14336, 1 / 4), + ) + DormandPrince = ( + ( + (0, ()), + (1 / 5, (1 / 5,)), + (3 / 10, (3 / 40, 9 / 40)), + (4 / 5, (44 / 45, -56 / 15, 32 / 9)), + (8 / 9, (19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729)), + (1, (9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656)), + (1, (35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84)), + ), + (35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0), + (5179 / 57600, 0, 7571 / 16695, 393 / 640, -92097 / 339200, 187 / 2100, 1 / 40), + ) + + def tableau(self) -> ExtendedTableau: + return self.value diff --git a/skrample/scheduling.py b/skrample/scheduling.py index affc9fa..515f5aa 100644 --- a/skrample/scheduling.py +++ b/skrample/scheduling.py @@ -1,19 +1,27 @@ import math from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass, replace from functools import lru_cache import numpy as np from numpy.typing import NDArray -from skrample.common import SigmaTransform, normalize, regularize, sigma_complement, sigma_polar, sigmoid +from skrample.common import FloatSchedule, SigmaTransform, normalize, regularize, sigma_complement, sigma_polar, sigmoid @lru_cache -def schedule_lru(schedule: "SkrampleSchedule", steps: int) -> NDArray[np.float64]: +def np_schedule_lru(schedule: "SkrampleSchedule", steps: int) -> NDArray[np.float64]: """Globally cached function for SkrampleSchedule.schedule_np(steps). Prefer moving SkrampleSchedule.schedule_np() outside of any loops if possible.""" - return schedule.schedule(steps) + return schedule.schedule_np(steps)
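# --- Illustrative sketch (editor's note, not part of the patch) ---
# Typical use of the cached schedule: hoist it out of the denoising loop and
# iterate the (timestep, sigma) pairs per step. The step count is arbitrary.
from skrample.scheduling import FlowShift, Linear, schedule_lru

float_schedule = schedule_lru(FlowShift(Linear(), shift=2), 20)
for step, (timestep, sigma) in enumerate(float_schedule):
    ...  # one model call + sampler step per entry
# --- end sketch ---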
+ + +@lru_cache +def schedule_lru(schedule: "SkrampleSchedule", steps: int) -> FloatSchedule: + """Globally cached function for SkrampleSchedule.schedule(steps). + Prefer moving SkrampleSchedule.schedule() outside of any loops if possible.""" + return tuple(map(tuple, np_schedule_lru(schedule, steps).tolist())) @dataclass(frozen=True) @@ -26,20 +34,33 @@ def sigma_transform(self) -> SigmaTransform: "SigmaTransform required for a given noise schedule" @abstractmethod - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: """Return the full noise schedule, timesteps stacked on top of sigmas. Excludes the trailing zero""" - def timesteps(self, steps: int) -> NDArray[np.float64]: + def timesteps_np(self, steps: int) -> NDArray[np.float64]: "Just the timesteps component as a 1-d array" - return self.schedule(steps)[:, 0] + return self.schedule_np(steps)[:, 0] - def sigmas(self, steps: int) -> NDArray[np.float64]: + def sigmas_np(self, steps: int) -> NDArray[np.float64]: "Just the sigmas component as a 1-d array" - return self.schedule(steps)[:, 1] + return self.schedule_np(steps)[:, 1] + + def schedule(self, steps: int) -> FloatSchedule: + """Return the full noise schedule, [(timestep, sigma), ...] + Excludes the trailing zero""" + return tuple(map(tuple, self.schedule_np(steps).tolist())) + + def timesteps(self, steps: int) -> Sequence[float]: + "Just the timesteps component" + return self.timesteps_np(steps).tolist() + + def sigmas(self, steps: int) -> Sequence[float]: + "Just the sigmas component" + return self.sigmas_np(steps).tolist() def __call__(self, steps: int) -> NDArray[np.float64]: - return self.schedule(steps) + return self.schedule_np(steps) @dataclass(frozen=True) @@ -101,10 +122,10 @@ def sigmas_to_timesteps(self, sigmas: NDArray[np.float64]) -> NDArray[np.float64 t = (1 - w) * low_idx + w * high_idx return t - def timesteps(self, steps: int) -> NDArray[np.float64]: + def timesteps_np(self, steps: int) -> NDArray[np.float64]: # https://arxiv.org/abs/2305.08891 Table 2 if self.uniform: - return np.linspace(self.base_timesteps - 1, 0, steps + 1, dtype=np.float64).round()[:-1] + return np.linspace(self.base_timesteps - 1, 0, steps, endpoint=False, dtype=np.float64).round() else: # They use a truncated ratio for ...reasons?
return np.flip(np.arange(0, steps, dtype=np.float64) * (self.base_timesteps // steps)).round() @@ -126,9 +147,9 @@ def alphas_cumprod(self, betas: NDArray[np.float64]) -> NDArray[np.float64]: def scaled_sigmas(self, alphas_cumprod: NDArray[np.float64]) -> NDArray[np.float64]: return ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: sigmas = self.scaled_sigmas(self.alphas_cumprod(self.betas())) - timesteps = self.timesteps(steps) + timesteps = self.timesteps_np(steps) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) return np.stack([timesteps, sigmas], axis=1) @@ -188,11 +209,11 @@ def sigma_transform(self) -> SigmaTransform: def sigmas_to_timesteps(self, sigmas: NDArray[np.float64]) -> NDArray[np.float64]: return normalize(sigmas, self.sigma_start) * self.base_timesteps - def sigmas(self, steps: int) -> NDArray[np.float64]: + def sigmas_np(self, steps: int) -> NDArray[np.float64]: return np.linspace(self.sigma_start, 0, steps, endpoint=False, dtype=np.float64) - def schedule(self, steps: int) -> NDArray[np.float64]: - sigmas = self.sigmas(steps) + def schedule_np(self, steps: int) -> NDArray[np.float64]: + sigmas = self.sigmas_np(steps) timesteps = self.sigmas_to_timesteps(sigmas) return np.stack([timesteps, sigmas], axis=1) @@ -207,7 +228,7 @@ class SigmoidCDF(Linear): cdf_scale: float = 3 "Multiply the inverse CDF output before the sigmoid function is applied" - def sigmas(self, steps: int) -> NDArray[np.float64]: + def sigmas_np(self, steps: int) -> NDArray[np.float64]: from scipy.stats import norm step_peak = 1 / (steps * math.pi / 2) @@ -303,8 +324,8 @@ def find_split[T: "ScheduleModifier"]( class NoMod(ScheduleModifier): "Does nothing. 
For generic programming against ScheduleModifier" - def schedule(self, steps: int) -> NDArray[np.float64]: - return self.base.schedule(steps) + def schedule_np(self, steps: int) -> NDArray[np.float64]: + return self.base.schedule_np(steps) @dataclass(frozen=True) @@ -312,8 +333,8 @@ class FlowShift(ScheduleModifier): shift: float = 3.0 """Amount to shift noise schedule by.""" - def schedule(self, steps: int) -> NDArray[np.float64]: - sigmas = self.base.sigmas(steps) + def schedule_np(self, steps: int) -> NDArray[np.float64]: + sigmas = self.base.sigmas_np(steps) start = sigmas.max().item() sigmas = self.shift / (self.shift + (start / sigmas - 1)) * start @@ -330,8 +351,8 @@ class Karras(ScheduleModifier): rho: float = 7.0 "Ramp power" - def schedule(self, steps: int) -> NDArray[np.float64]: - sigmas = self.base.sigmas(steps) + def schedule_np(self, steps: int) -> NDArray[np.float64]: + sigmas = self.base.sigmas_np(steps) sigma_min = sigmas[-1].item() sigma_max = sigmas[0].item() @@ -351,8 +372,8 @@ class Exponential(ScheduleModifier): rho: float = 1.0 "Ramp power" - def schedule(self, steps: int) -> NDArray[np.float64]: - sigmas = self.base.sigmas(steps) + def schedule_np(self, steps: int) -> NDArray[np.float64]: + sigmas = self.base.sigmas_np(steps) sigma_min = sigmas[-1].item() sigma_max = sigmas[0].item() @@ -371,10 +392,10 @@ class Beta(ScheduleModifier): alpha: float = 0.6 beta: float = 0.6 - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: import scipy - sigmas = self.base.sigmas(steps) + sigmas = self.base.sigmas_np(steps) sigma_min = sigmas[-1].item() sigma_max = sigmas[0].item() @@ -397,11 +418,11 @@ class Hyper(ScheduleModifier): tail: bool = True "Include the trailing end to make an S curve" - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: if abs(self.scale) <= 1e-8: - return self.base.schedule(steps) + return self.base.schedule_np(steps) - sigmas = self.base.sigmas(steps) + sigmas = self.base.sigmas_np(steps) start = sigmas[0].item() sigmas = normalize(sigmas, start) # Base -> 1..0 diff --git a/tests/diffusers_map.py b/tests/diffusers_map.py index 2d3a620..f4eb279 100644 --- a/tests/diffusers_map.py +++ b/tests/diffusers_map.py @@ -10,13 +10,15 @@ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler from testing_common import FLOW_CONFIG, SCALED_CONFIG -from skrample.common import predict_epsilon as EPSILON -from skrample.common import predict_flow as FLOW -from skrample.common import predict_velocity as VELOCITY from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling import DPM, Adams, Euler, UniPC +from skrample.sampling.models import FlowModel, NoiseModel, VelocityModel +from skrample.sampling.structured import DPM, Adams, Euler, UniPC from skrample.scheduling import Beta, Exponential, FlowShift, Karras, Linear, Scaled +EPSILON = NoiseModel() +FLOW = FlowModel() +VELOCITY = VelocityModel() + def check_wrapper(wrapper: SkrampleWrapperScheduler, scheduler: ConfigMixin, params: list[str] = []) -> None: a, b = wrapper, SkrampleWrapperScheduler.from_diffusers_config(scheduler) diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 30efb43..0621522 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -1,23 +1,51 @@ +import dataclasses +import itertools from inspect import signature +import numpy as np +import pytest import torch from 
diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler +from diffusers.schedulers.scheduling_heun_discrete import HeunDiscreteScheduler from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler from testing_common import FLOW_CONFIG, SCALED_CONFIG, compare_tensors -from skrample.common import Predictor, sigma_complement, sigma_polar -from skrample.common import predict_epsilon as EPSILON -from skrample.common import predict_flow as FLOW -from skrample.common import predict_velocity as VELOCITY -from skrample.sampling import DPM, Euler, SkrampleSampler, SKSamples, UniPC +from skrample.common import FloatSchedule, SigmaTransform, sigma_complement, sigma_polar +from skrample.sampling.functional import RKUltra +from skrample.sampling.models import DiffusionModel, FlowModel, NoiseModel, VelocityModel +from skrample.sampling.structured import DPM, Euler, SKSamples, StructuredSampler, UniPC +from skrample.sampling.tableaux import RK2 +from skrample.scheduling import SkrampleSchedule DiffusersScheduler = ( - EulerDiscreteScheduler | DPMSolverMultistepScheduler | FlowMatchEulerDiscreteScheduler | UniPCMultistepScheduler + EulerDiscreteScheduler + | DPMSolverMultistepScheduler + | FlowMatchEulerDiscreteScheduler + | FlowMatchHeunDiscreteScheduler + | UniPCMultistepScheduler ) +EPSILON = NoiseModel() +FLOW = FlowModel() +VELOCITY = VelocityModel() + + +@dataclasses.dataclass(frozen=True) +class FixedSchedule(SkrampleSchedule): + fixed_schedule: FloatSchedule + transform: SigmaTransform + + def schedule_np(self, steps: int) -> np.typing.NDArray[np.float64]: + return np.array(self.fixed_schedule, dtype=np.float64) + + @property + def sigma_transform(self) -> SigmaTransform: + return self.transform + def fake_model(t: torch.Tensor) -> torch.Tensor: t @= torch.randn(t.shape, generator=torch.Generator(t.device).manual_seed(-1), dtype=t.dtype) @@ -25,9 +53,9 @@ def fake_model(t: torch.Tensor) -> torch.Tensor: def dual_sample( - a: SkrampleSampler, + a: StructuredSampler, b: DiffusersScheduler, - predictor: Predictor, + model_transform: DiffusionModel, steps: range, mu: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -48,10 +76,10 @@ def dual_sample( if isinstance(b, FlowMatchEulerDiscreteScheduler): b_sample = b.scale_noise(sample=b_sample, timestep=timestep.unsqueeze(0), noise=initial_noise) - sigma_transform = sigma_complement else: b_sample = b.add_noise(original_samples=b_sample, noise=initial_noise, timesteps=timestep.unsqueeze(0)) - sigma_transform = sigma_polar + + sigma_transform = sigma_complement if isinstance(model_transform, FlowModel) else sigma_polar a_sample = a.merge_noise(a_sample, initial_noise, sigma.item(), sigma_transform) @@ -62,10 +90,10 @@ def dual_sample( timestep, sigma = schedule[step] - a_output = predictor( + a_output = model_transform.to_x( a_sample, fake_model(a.scale_input(a_sample, sigma.item(), sigma_transform)), sigma.item(), sigma_transform ) - sampled = a.sample(a_sample, a_output, step, schedule[:, 1].numpy(), sigma_transform, noise, prior_steps) + sampled = a.sample(a_sample, a_output, step, 
schedule.numpy().tolist(), sigma_transform, noise, prior_steps) a_sample = sampled.final prior_steps.append(sampled) @@ -83,16 +111,16 @@ def dual_sample( def compare_samplers( - a: SkrampleSampler, + a: StructuredSampler, b: DiffusersScheduler, - p: Predictor = EPSILON, + t: DiffusionModel = EPSILON, mu: float | None = None, margin: float = 1e-8, message: str = "", ) -> None: for step_range in [range(0, 2), range(0, 11), range(0, 201), range(3, 6), range(2, 23), range(31, 200)]: compare_tensors( - *dual_sample(a, b, p, step_range, mu), + *dual_sample(a, b, t, step_range, mu), message=str(step_range) + (" | " + message if message else ""), margin=margin, ) @@ -107,7 +135,7 @@ def test_euler() -> None: prediction_type=predictor[1], ), predictor[0], - message=predictor[0].__name__, + message=type(predictor[0]).__name__, ) @@ -120,7 +148,7 @@ def test_euler_ancestral() -> None: prediction_type=predictor[1], ), predictor[0], - message=predictor[0].__name__, + message=type(predictor[0]).__name__, ) @@ -145,9 +173,10 @@ def test_dpm() -> None: final_sigmas_type="zero", solver_order=order, prediction_type=predictor[1], + use_flow_sigmas=predictor[0] == FLOW, ), predictor[0], - message=f"{predictor[0].__name__} o{order} s{stochastic}", + message=f"{type(predictor[0]).__name__} o{order} s{stochastic}", ) @@ -164,7 +193,90 @@ def test_unipc() -> None: final_sigmas_type="zero", solver_order=order, prediction_type=predictor[1], + use_flow_sigmas=predictor[0] == FLOW, ), predictor[0], - message=f"{predictor[0].__name__} o{order}", + message=f"{type(predictor[0]).__name__} o{order}", ) + + +@pytest.mark.parametrize( + ("model_transform", "derivative_transform", "sigma_transform", "diffusers_scheduler", "steps"), + ( + (mt, dt, st, ds, s) + for (mt, dt, st, ds), s in itertools.product( + ( + ( + NoiseModel(), + NoiseModel(), + sigma_polar, + HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type="epsilon"), + ), + ( + VelocityModel(), + NoiseModel(), + sigma_polar, + HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type="v_prediction"), + ), + ( + FlowModel(), + FlowModel(), + sigma_complement, + FlowMatchHeunDiscreteScheduler.from_config(FLOW_CONFIG), + ), + ), + (2, 3, 30, 31, 200, 201), + ) + ), +) +def test_heun( + model_transform: DiffusionModel, + derivative_transform: DiffusionModel, + sigma_transform: SigmaTransform, + diffusers_scheduler: HeunDiscreteScheduler | FlowMatchHeunDiscreteScheduler, + steps: int, +) -> None: + diffusers_scheduler.set_timesteps(steps) + + fixed: list[tuple[float, float]] = [] + for t in zip(diffusers_scheduler.timesteps.tolist(), diffusers_scheduler.sigmas.tolist()): + if t not in fixed: + fixed.append(t) + + skrample_sampler = RKUltra( + FixedSchedule(fixed, sigma_transform), + order=2, + providers=RKUltra.providers | {2: RK2.Heun}, + derivative_transform=derivative_transform, + ) + + sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) + seed = torch.manual_seed(0) + + df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) + # df_sample = df.add_noise(sk_sample.clone(), df_noise, df.timesteps[0:1]) + + df_sample = df_noise.clone() + if isinstance(diffusers_scheduler, HeunDiscreteScheduler): + df_sample *= diffusers_scheduler.init_noise_sigma + + for t in diffusers_scheduler.timesteps: + model_input: torch.Tensor = df_sample + if isinstance(diffusers_scheduler, HeunDiscreteScheduler): + model_input = diffusers_scheduler.scale_model_input(df_sample, timestep=t) + + df_sample: torch.Tensor = 
diffusers_scheduler.step( + fake_model(model_input), # pyright: ignore [reportArgumentType] + sample=df_sample, # pyright: ignore [reportArgumentType] + timestep=t, # pyright: ignore [reportArgumentType] + )[0] + + sk_sample = skrample_sampler.generate_model( + lambda x, t, s: fake_model(x), + model_transform, + lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), + steps, + initial=sk_sample, + ) + + compare_tensors(df_sample, sk_sample, margin=1e-8) diff --git a/tests/diffusers_schedules.py b/tests/diffusers_schedules.py index a3220c2..41c0a23 100644 --- a/tests/diffusers_schedules.py +++ b/tests/diffusers_schedules.py @@ -24,13 +24,13 @@ def compare_schedules( b.set_timesteps(num_inference_steps=steps) compare_tensors( - torch.from_numpy(a.timesteps(steps)), + torch.from_numpy(a.timesteps_np(steps)), b.timesteps, f"TIMESTEPS @ {steps}", margin=ts_margin, ) compare_tensors( - torch.from_numpy(a.sigmas(steps)), + torch.from_numpy(a.sigmas_np(steps)), b.sigmas[:-1], f"SIGMAS @ {steps}", margin=sig_margin, diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index dc6d4ac..4ad3254 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -1,27 +1,51 @@ +import itertools import math import random +from collections.abc import Sequence from dataclasses import replace import numpy as np +import pytest import torch from testing_common import compare_tensors -from skrample.common import MergeStrategy, bashforth, sigma_complement, sigmoid, softmax, spowf +from skrample.common import ( + MergeStrategy, + SigmaTransform, + bashforth, + euler, + sigma_complement, + sigma_polar, + sigmoid, + softmax, + spowf, +) from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling import ( +from skrample.sampling import tableaux +from skrample.sampling.interface import StructuredFunctionalAdapter +from skrample.sampling.models import ( + DataModel, + DiffusionModel, + FlowModel, + ModelConvert, + NoiseModel, + ScaleX, + VelocityModel, +) +from skrample.sampling.structured import ( DPM, SPC, Adams, Euler, - HighOrderSampler, - SkrampleSampler, SKSamples, - StochasticSampler, + StructuredMultistep, + StructuredSampler, + StructuredStochastic, UniPC, ) -from skrample.scheduling import Beta, FlowShift, Karras, Linear, Scaled, SigmoidCDF +from skrample.scheduling import Beta, FlowShift, Karras, Linear, Scaled, ScheduleCommon, ScheduleModifier, SigmoidCDF -ALL_SAMPLERS = [ +ALL_STRUCTURED: Sequence[type[StructuredSampler]] = [ Adams, DPM, Euler, @@ -29,44 +53,122 @@ UniPC, ] -ALL_SCHEDULES = [ +ALL_SCHEDULES: Sequence[type[ScheduleCommon]] = [ Linear, Scaled, SigmoidCDF, ] -ALL_MODIFIERS = [ +ALL_MODIFIERS: Sequence[type[ScheduleModifier]] = [ Beta, FlowShift, Karras, ] +ALL_MODELS: Sequence[type[DiffusionModel]] = [ + DataModel, + NoiseModel, + FlowModel, + VelocityModel, +] + +ALL_FAKE_MODELS: Sequence[type[DiffusionModel]] = [ + ScaleX, +] + +ALL_TRANSFORMS: Sequence[SigmaTransform] = [ + sigma_complement, + sigma_polar, +] + def test_sigmas_to_timesteps() -> None: for schedule in [*(cls() for cls in ALL_SCHEDULES), Scaled(beta_scale=1)]: # base schedules - timesteps = schedule.timesteps(123) - timesteps_inv = schedule.sigmas_to_timesteps(schedule.sigmas(123)) + timesteps = schedule.timesteps_np(123) + timesteps_inv = schedule.sigmas_to_timesteps(schedule.sigmas_np(123)) compare_tensors(torch.tensor(timesteps), torch.tensor(timesteps_inv), margin=0) # shocked this rounds good
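# --- Illustrative sketch (editor's note, not part of the patch) ---
# The invariant test_model_transforms below exercises, in miniature:
# from_x inverts to_x for any DiffusionModel at a fixed sigma. Scalar floats
# stand in for tensors via the Sample generic; values are arbitrary.
_mt = VelocityModel()
_x = _mt.to_x(0.8, 0.3, 0.2, sigma_polar)
assert abs(_mt.from_x(0.8, _x, 0.2, sigma_polar) - 0.3) < 1e-12
# --- end sketch ---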
+@pytest.mark.parametrize( + ("model_type", "sigma_transform"), + itertools.product(ALL_MODELS, ALL_TRANSFORMS), +) +def test_model_transforms(model_type: type[DiffusionModel], sigma_transform: SigmaTransform) -> None: + model_transform = model_type() + sample = 0.8 + output = 0.3 + sigma = 0.2 + + x = model_transform.to_x(sample, output, sigma, sigma_transform) + o = model_transform.from_x(sample, x, sigma, sigma_transform) + assert abs(output - o) < 1e-12 + + for sigma_next in 0.05, 0: # extra 0 to validate X̂ + snr = euler( + sample, model_transform.to_x(sample, output, sigma, sigma_transform), sigma, sigma_next, sigma_transform + ) + df = model_transform.forward(sample, output, sigma, sigma_next, sigma_transform) + assert abs(snr - df) < 1e-12 + + ob = model_transform.backward(sample, df, sigma, sigma_next, sigma_transform) + assert abs(o - ob) < 1e-12 + + +@pytest.mark.parametrize( + ("model_from", "model_to", "sigma_transform", "sigma_to"), + itertools.product(ALL_MODELS, ALL_MODELS + ALL_FAKE_MODELS, ALL_TRANSFORMS, (0.05, 0.0)), +) +def test_model_convert( + model_from: type[DiffusionModel], + model_to: type[DiffusionModel], + sigma_transform: SigmaTransform, + sigma_to: float, +) -> None: + convert = ModelConvert(model_from(), model_to()) + sample = 0.8 + output = 0.3 + sigma_from = 0.2 + + def model(x: float, t: float, s: float) -> float: + return output + + x_from = convert.transform_from.forward( + sample, + model(sample, sigma_from, sigma_from), + sigma_from, + sigma_to, + sigma_transform, + ) + x_to = convert.transform_to.forward( + sample, + convert.wrap_model_call(model, sigma_transform)(sample, sigma_from, sigma_from), + sigma_from, + sigma_to, + sigma_transform, + ) + + assert abs(x_from - x_to) < 1e-12 + + def test_sampler_generics() -> None: eps = 1e-12 for sampler in [ - *(cls() for cls in ALL_SAMPLERS), - *(cls(order=cls.max_order()) for cls in ALL_SAMPLERS if issubclass(cls, HighOrderSampler)), + *(cls() for cls in ALL_STRUCTURED), + *(cls(order=cls.max_order()) for cls in ALL_STRUCTURED if issubclass(cls, StructuredMultistep)), ]: for schedule in Scaled(), FlowShift(Linear()): i, o = random.random(), random.random() prev = [SKSamples(random.random(), random.random(), random.random()) for _ in range(9)] - scalar = sampler.sample(i, o, 4, schedule.sigmas(10), schedule.sigma_transform, previous=prev).final + scalar = sampler.sample(i, o, 4, schedule.schedule(10), schedule.sigma_transform, previous=prev).final # Enforce FP64 as that should be equivalent to python scalar ndarr = sampler.sample( np.array([i], dtype=np.float64), np.array([o], dtype=np.float64), 4, - schedule.sigmas(10), + schedule.schedule(10), schedule.sigma_transform, previous=prev, # type: ignore ).final.item() @@ -75,7 +177,7 @@ def test_sampler_generics() -> None: torch.tensor([i], dtype=torch.float64), torch.tensor([o], dtype=torch.float64), 4, - schedule.sigmas(10), + schedule.schedule(10), schedule.sigma_transform, previous=prev, # type: ignore ).final.item() @@ -94,9 +196,9 @@ def test_mu_set() -> None: def test_require_previous() -> None: - samplers: list[SkrampleSampler] = [] - for cls in ALL_SAMPLERS: - if issubclass(cls, HighOrderSampler): + samplers: list[StructuredSampler] = [] + for cls in ALL_STRUCTURED: + if issubclass(cls, StructuredMultistep): + samplers.extend([cls(order=o + 1) for o in range(cls.min_order(), cls.max_order())]) else: samplers.append(cls()) @@ -115,7 +217,7 @@ def test_require_previous() -> None: sample, prediction, 31, - Linear().sigmas(100), + Linear().schedule(100), sigma_complement, None, previous, @@ -124,7
+226,7 @@ def test_require_previous() -> None: sample, prediction, 31, - Linear().sigmas(100), + Linear().schedule(100), sigma_complement, None, previous[len(previous) - sampler.require_previous :], @@ -134,9 +236,9 @@ def test_require_previous() -> None: def test_require_noise() -> None: - samplers: list[SkrampleSampler] = [] - for cls in ALL_SAMPLERS: - if issubclass(cls, StochasticSampler): + samplers: list[StructuredSampler] = [] + for cls in ALL_STRUCTURED: + if issubclass(cls, StructuredStochastic): samplers.extend([cls(add_noise=n) for n in (False, True)]) else: samplers.append(cls()) @@ -156,7 +258,7 @@ def test_require_noise() -> None: sample, prediction, 31, - Linear().sigmas(100), + Linear().schedule(100), sigma_complement, noise, previous, @@ -165,7 +267,7 @@ def test_require_noise() -> None: sample, prediction, 31, - Linear().sigmas(100), + Linear().schedule(100), sigma_complement, noise if sampler.require_noise else None, previous, @@ -177,6 +279,42 @@ def test_require_noise() -> None: assert a == b, (sampler, sampler.require_noise) +def test_functional_adapter() -> None: + def fake_model(x: float, _: float, s: float) -> float: + return x + math.sin(x) * s + + samplers: list[StructuredSampler] = [DPM(n, o) for o in range(1, 4) for n in [False, True]] + for schedule in Linear(), Scaled(): + for sampler in samplers: + for steps in [1, 3, 4, 9, 512, 999]: + sample = 1.5 + adapter = StructuredFunctionalAdapter(schedule, sampler) + noise = [random.random() for _ in range(steps)] + + rng = iter(noise) + model_transform = FlowModel() + sample_f = adapter.sample_model(sample, fake_model, model_transform, steps, rng=lambda: next(rng)) + + rng = iter(noise) + float_schedule = schedule.schedule(steps) + sample_s = sample + previous: list[SKSamples[float]] = [] + for n, (t, s) in enumerate(float_schedule): + results = sampler.sample( + sample_s, + model_transform.to_x(sample_s, fake_model(sample_s, t, s), s, schedule.sigma_transform), + n, + float_schedule, + schedule.sigma_transform, + next(rng), + tuple(previous), + ) + previous.append(results) + sample_s = results.final + + assert sample_s == sample_f, (sample_s, sample_f, sampler, schedule, steps) + + def test_bashforth() -> None: for n, coeffs in enumerate( np.array(c) for c in ((1,), (3 / 2, -1 / 2), (23 / 12, -4 / 3, 5 / 12), (55 / 24, -59 / 24, 37 / 24, -3 / 8)) @@ -184,6 +322,62 @@ def test_bashforth() -> None: assert np.allclose(coeffs, np.array(bashforth(n + 1)), atol=1e-12, rtol=1e-12) +def test_tableau_providers() -> None: + for provider in [ + tableaux.RK2, + tableaux.RK3, + tableaux.RK4, + tableaux.RKZ, + tableaux.RKE2, + tableaux.RKE3, + tableaux.RKE5, + ]: + for variant in provider: + if error := tableaux.validate_tableau(variant.tableau()): + raise error + + +def flat_tableau(t: tuple[float | tuple[float | tuple[float | tuple[float, ...], ...], ...], ...]) -> tuple[float, ...]: + return tuple(z for y in (flat_tableau(x) if isinstance(x, tuple) else (x,) for x in t) for z in y) + + +def tableau_distance(a: tableaux.Tableau, b: tableaux.Tableau) -> float: + return abs(np.subtract(flat_tableau(a), flat_tableau(b))).max().item() + + +def test_rk2_tableau() -> None: + assert ( + tableau_distance( + ( # Ralston + ( + (0.0, ()), + (2 / 3, (2 / 3,)), + ), + (1 / 4, 3 / 4), + ), + tableaux.rk2_tableau(2 / 3), + ) + < 1e-20 + ) + + +def test_rk3_tableau() -> None: + assert ( + tableau_distance( + ( # Wray + ( + (0.0, ()), + (8 / 15, (8 / 15,)), + (2 / 3, (1 / 4, 5 / 12)), + ), + (1 / 4, 0.0, 3 / 4), + ), + tableaux.rk3_tableau(8 
/ 15, 2 / 3), + ) + < 1e-15 + ) + + def test_sigmoid() -> None: items = spowf(torch.linspace(-2, 2, 9, dtype=torch.float64), 2) a = torch.sigmoid(items)