From 7a9f78ffa43adfbea84fd9e5d8dcee43498d531a Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 21 Sep 2025 01:51:29 -0700 Subject: [PATCH 01/59] First draft for FunctionalSampler Done - Rename SkrampleSampler -> StructuredSampler in structured module - Add FunctionalSampler and functional module - Add Heun sampler in functional module - Add StructuredFunctionalAdapter in interface module - Add UT for StructuredFunctionalAdapter - Add example for StructuredFunctionalAdapter and Heun - Move some types and functions around to dedupe code. Not Done - Plots for functional samplers - UT for Heun if diffusers behaves - Higher order Heun / Runge Kutta - Ideally RKX instead of just RK3 or RK4 - Functional integrations? - In *theory* it's possible to use FunctionalSinglestep through SkrampleWrapperScheduler by calling `step` directly and doing a lot of sketchy state management. Image2Image will 100% not work OOTB unless the pipeline actually respects `scheduler.order`, which I don't think all of them do. Alternatives involve monkey patches and sadness, so I'd rather not. --- examples/diffusers/wrapper.py | 2 +- examples/diffusers/wrapper_from.py | 2 +- examples/functional.py | 76 ++++++++++++ examples/{raw.py => structured.py} | 8 +- scripts/overhead.py | 4 +- scripts/plot_skrample.py | 10 +- scripts/spc.py | 12 +- skrample/common.py | 9 ++ skrample/diffusers.py | 17 +-- skrample/sampling/__init__.py | 0 skrample/sampling/functional.py | 117 ++++++++++++++++++ skrample/sampling/interface.py | 50 ++++++++ .../{sampling.py => sampling/structured.py} | 36 +++--- tests/diffusers_map.py | 2 +- tests/diffusers_samplers.py | 6 +- tests/miscellaneous.py | 54 ++++++-- 16 files changed, 343 insertions(+), 62 deletions(-) create mode 100755 examples/functional.py rename examples/{raw.py => structured.py} (92%) create mode 100644 skrample/sampling/__init__.py create mode 100644 skrample/sampling/functional.py create mode 100644 skrample/sampling/interface.py rename skrample/{sampling.py => sampling/structured.py} (95%) diff --git a/examples/diffusers/wrapper.py b/examples/diffusers/wrapper.py index 9565691..6af5184 100755 --- a/examples/diffusers/wrapper.py +++ b/examples/diffusers/wrapper.py @@ -4,7 +4,7 @@ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline import skrample.pytorch.noise as sknoise -import skrample.sampling as sampling +import skrample.sampling.structured as sampling import skrample.scheduling as scheduling from skrample.common import predict_flow from skrample.diffusers import SkrampleWrapperScheduler diff --git a/examples/diffusers/wrapper_from.py b/examples/diffusers/wrapper_from.py index e009b20..f33ad55 100755 --- a/examples/diffusers/wrapper_from.py +++ b/examples/diffusers/wrapper_from.py @@ -4,7 +4,7 @@ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline import skrample.pytorch.noise as sknoise -import skrample.sampling as sampling +import skrample.sampling.structured as sampling from skrample.diffusers import SkrampleWrapperScheduler pipe: FluxPipeline = FluxPipeline.from_pretrained( # type: ignore diff --git a/examples/functional.py b/examples/functional.py new file mode 100755 index 0000000..7912b7d --- /dev/null +++ b/examples/functional.py @@ -0,0 +1,76 @@ +#! 
/usr/bin/env python + +import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from PIL import Image +from tqdm import tqdm +from transformers.models.clip import CLIPTextModel, CLIPTokenizer + +import skrample.common +import skrample.pytorch.noise as noise +import skrample.scheduling as scheduling +from skrample.sampling import functional, structured +from skrample.sampling.interface import StructuredFunctionalAdapter + +with torch.inference_mode(): + device: torch.device = torch.device("cuda") + dtype: torch.dtype = torch.float16 + url: str = "Lykon/dreamshaper-8" + seed = torch.Generator("cpu").manual_seed(0) + steps: int = 25 + cfg: float = 3 + + schedule = scheduling.Karras(scheduling.Scaled()) + + # Equivalent to structured example + sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True)) + # Native functional example + sampler = functional.Heun(schedule) + + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer") + text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained( + url, subfolder="text_encoder", device_map=device, torch_dtype=dtype + ) + model: UNet2DConditionModel = UNet2DConditionModel.from_pretrained( # type: ignore + url, subfolder="unet", device_map=device, torch_dtype=dtype + ) + image_encoder: AutoencoderKL = AutoencoderKL.from_pretrained( # type: ignore + url, subfolder="vae", device_map=device, torch_dtype=dtype + ) + + text_embeds: torch.Tensor = text_encoder( + tokenizer( + "bright colorful fantasy art of a kitten in a field of rainbow flowers", + padding="max_length", + return_tensors="pt", + ).input_ids.to(device=device) + ).last_hidden_state + + def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: + conditioned, unconditioned = model( + x.expand([x.shape[0] * 2, *x.shape[1:]]), + t, + torch.cat([text_embeds, torch.zeros_like(text_embeds)]), + ).sample.chunk(2) + p = conditioned + (cfg - 1) * (conditioned - unconditioned) + return skrample.common.predict_epsilon(x, p, s, schedule.sigma_transform) + + if isinstance(sampler, functional.FunctionalSinglestep): + steps = sampler.adjust_steps(steps) + + sample = torch.randn([1, 4, 80, 80], generator=seed).to(dtype=dtype, device=device) + rng = noise.Random.from_inputs(sample.shape, seed) + bar = tqdm(total=steps) + sample = sampler.sample_model( + sample, + call_model, + steps, + rng=lambda: rng.generate().to(dtype=dtype, device=device), + callback=lambda _: bar.update(), + ) + + image: torch.Tensor = image_encoder.decode(sample / image_encoder.config.scaling_factor).sample[0] # type: ignore + Image.fromarray( + ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() + ).save("functional.png") diff --git a/examples/raw.py b/examples/structured.py similarity index 92% rename from examples/raw.py rename to examples/structured.py index 2bc9580..7b84538 100755 --- a/examples/raw.py +++ b/examples/structured.py @@ -8,7 +8,7 @@ from transformers.models.clip import CLIPTextModel, CLIPTokenizer import skrample.common -import skrample.sampling as sampling +import skrample.sampling.structured as structured import skrample.scheduling as scheduling with torch.inference_mode(): @@ -20,7 +20,7 @@ cfg: float = 3 schedule: scheduling.SkrampleSchedule = scheduling.Karras(scheduling.Scaled()) - sampler: sampling.SkrampleSampler = sampling.DPM(order=2, add_noise=True) + sampler: structured.StructuredSampler = 
structured.DPM(order=2, add_noise=True) predictor: skrample.common.Predictor = skrample.common.predict_epsilon tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer") @@ -43,7 +43,7 @@ ).last_hidden_state sample: torch.Tensor = torch.randn([1, 4, 80, 80], generator=seed).to(dtype=dtype, device=device) - previous: list[sampling.SKSamples[torch.Tensor]] = [] + previous: list[structured.SKSamples[torch.Tensor]] = [] for n, (timestep, sigma) in enumerate(tqdm(schedule.schedule(steps))): conditioned, unconditioned = model( @@ -71,4 +71,4 @@ image: torch.Tensor = image_encoder.decode(sample / image_encoder.config.scaling_factor).sample[0] # type: ignore Image.fromarray( ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() - ).save("raw.png") + ).save("structured.png") diff --git a/scripts/overhead.py b/scripts/overhead.py index 3bff208..a513892 100755 --- a/scripts/overhead.py +++ b/scripts/overhead.py @@ -5,12 +5,12 @@ import torch from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling import DPM +from skrample.sampling.structured import Euler from skrample.scheduling import Beta, FlowShift, SigmoidCDF def bench_wrapper() -> int: - wrapper = SkrampleWrapperScheduler(DPM(), Beta(FlowShift(SigmoidCDF()))) + wrapper = SkrampleWrapperScheduler(Euler(), Beta(FlowShift(SigmoidCDF()))) wrapper.set_timesteps(1000) clock = perf_counter_ns() diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index a432121..64914b9 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -12,7 +12,7 @@ import numpy as np from numpy.typing import NDArray -import skrample.sampling as sampling +import skrample.sampling.structured as sampling import skrample.scheduling as scheduling from skrample.common import SigmaTransform, sigma_complement, sigma_polar, spowf @@ -61,7 +61,7 @@ def colors(hue_steps: int) -> Generator[list[float]]: "polar": sigma_polar, "complement": sigma_complement, } -SAMPLERS: dict[str, sampling.SkrampleSampler] = { +SAMPLERS: dict[str, sampling.StructuredSampler] = { "euler": sampling.Euler(), "adams": sampling.Adams(), "dpm": sampling.DPM(), @@ -70,7 +70,7 @@ def colors(hue_steps: int) -> Generator[list[float]]: "spc": sampling.SPC(), } for k, v in list(SAMPLERS.items()): - if isinstance(v, sampling.HighOrderSampler): + if isinstance(v, sampling.StructuredMultistep): for o in range(1, v.max_order() + 1): if o != v.order: SAMPLERS[k + str(o)] = replace(v, order=o) @@ -156,7 +156,7 @@ def colors(hue_steps: int) -> Generator[list[float]]: schedule = scheduling.Linear(base_timesteps=10_000) - def sample_model(sampler: sampling.SkrampleSampler, schedule: NDArray[np.float64]) -> list[float]: + def sample_model(sampler: sampling.StructuredSampler, schedule: NDArray[np.float64]) -> list[float]: previous: list[sampling.SKSamples] = [] sample = 1.0 sampled_values = [sample] @@ -185,7 +185,7 @@ def sample_model(sampler: sampling.SkrampleSampler, schedule: NDArray[np.float64 for sampler in [SAMPLERS[s] for s in args.sampler]: sigmas = schedule.sigmas(args.steps) label = type(sampler).__name__ - if isinstance(sampler, sampling.HighOrderSampler) and sampler.order != type(sampler).order: + if isinstance(sampler, sampling.StructuredMultistep) and sampler.order != type(sampler).order: label += " " + str(sampler.order) plt.plot([*sigmas, 0], sample_model(sampler, sigmas), label=label, color=next(COLORS), linestyle="--") diff --git a/scripts/spc.py b/scripts/spc.py index 3e4837b..948d390 
100755 --- a/scripts/spc.py +++ b/scripts/spc.py @@ -9,7 +9,7 @@ import numpy as np from numpy.typing import NDArray -import skrample.sampling as sampling +import skrample.sampling.structured as sampling import skrample.scheduling as scheduling from skrample.common import SigmaTransform, sigma_complement, sigma_polar @@ -33,7 +33,7 @@ class Row: def sample_model( - sampler: sampling.SkrampleSampler, schedule: NDArray[np.float64], curve: int, transform: SigmaTransform + sampler: sampling.StructuredSampler, schedule: NDArray[np.float64], curve: int, transform: SigmaTransform ) -> NDArray: previous: list[sampling.SKSamples] = [] sample = 1.0 @@ -54,9 +54,9 @@ def sample_model( return np.array(sampled_values) -samplers: set[sampling.SkrampleSampler] = {sampling.Euler(), sampling.Adams(order=2), sampling.DPM(order=2)} +samplers: set[sampling.StructuredSampler] = {sampling.Euler(), sampling.Adams(order=2), sampling.DPM(order=2)} for v in samplers.copy(): - if isinstance(v, sampling.HighOrderSampler): + if isinstance(v, sampling.StructuredMultistep): for o in range(2, v.max_order() + 1): samplers.add(replace(v, order=o)) @@ -74,8 +74,8 @@ def sample_model( sampled = sample_model(spc, schedule.sigmas(h), k, t) table.append( Row( - type(pe).__name__ + (str(pe.order) if isinstance(pe, sampling.HighOrderSampler) else ""), - type(ce).__name__ + (str(ce.order) if isinstance(ce, sampling.HighOrderSampler) else ""), + type(pe).__name__ + (str(pe.order) if isinstance(pe, sampling.StructuredMultistep) else ""), + type(ce).__name__ + (str(ce.order) if isinstance(ce, sampling.StructuredMultistep) else ""), t.__name__, k, h, diff --git a/skrample/common.py b/skrample/common.py index 4f9adaa..4e81343 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -90,6 +90,15 @@ def predict_flow[T: Sample](sample: T, output: T, sigma: float, sigma_transform: return sample - sigma * output # type: ignore +def euler[T: Sample](sample: T, prediction: T, sigma: float, sigma_next: float, sigma_transform: SigmaTransform) -> T: + sigma_u, sigma_v = sigma_transform(sigma) + sigma_u_next, sigma_v_next = sigma_transform(sigma_next) + + scale = sigma_u_next / sigma_u + delta = sigma_v_next - sigma_v * scale # aka `h` or `dt` + return sample * scale + prediction * delta # type: ignore + + def safe_log(x: float) -> float: "Returns inf rather than throw an err" try: diff --git a/skrample/diffusers.py b/skrample/diffusers.py index e7f7664..fad58b1 100644 --- a/skrample/diffusers.py +++ b/skrample/diffusers.py @@ -9,7 +9,8 @@ from numpy.typing import NDArray from torch import Tensor -from skrample import sampling, scheduling +import skrample.sampling.structured as sampling +from skrample import scheduling from skrample.common import MergeStrategy, Predictor, predict_epsilon, predict_flow, predict_sample, predict_velocity from skrample.pytorch.noise import ( BatchTensorNoise, @@ -18,14 +19,14 @@ TensorNoiseProps, schedule_to_ramp, ) -from skrample.sampling import SkrampleSampler, SKSamples +from skrample.sampling.structured import SKSamples, StructuredSampler from skrample.scheduling import ScheduleCommon, ScheduleModifier, SkrampleSchedule if TYPE_CHECKING: from diffusers.configuration_utils import ConfigMixin -DIFFUSERS_CLASS_MAP: dict[str, tuple[type[SkrampleSampler], dict[str, Any]]] = { +DIFFUSERS_CLASS_MAP: dict[str, tuple[type[StructuredSampler], dict[str, Any]]] = { "DDIMScheduler": (sampling.Euler, {}), "DDPMScheduler": (sampling.DPM, {"add_noise": True, "order": 1}), "DPMSolverMultistepScheduler": (sampling.DPM, {}), @@ 
-93,7 +94,7 @@ class ParsedDiffusersConfig: "Values read from a combination of the diffusers config and provided types" - sampler: type[SkrampleSampler] + sampler: type[StructuredSampler] sampler_props: dict[str, Any] schedule: type[SkrampleSchedule] schedule_props: dict[str, Any] @@ -103,7 +104,7 @@ class ParsedDiffusersConfig: def parse_diffusers_config( config: "dict[str, Any] | ConfigMixin", - sampler: type[SkrampleSampler] | None = None, + sampler: type[StructuredSampler] | None = None, schedule: type[SkrampleSchedule] | None = None, ) -> ParsedDiffusersConfig: """Reads a diffusers scheduler or config as a set of skrample classes and properties. @@ -179,7 +180,7 @@ def attr_dict[T: Any](**kwargs: T) -> OrderedDict[str, T]: return od -def as_diffusers_config(sampler: SkrampleSampler, schedule: SkrampleSchedule, predictor: Predictor) -> dict[str, Any]: +def as_diffusers_config(sampler: StructuredSampler, schedule: SkrampleSchedule, predictor: Predictor) -> dict[str, Any]: "Converts skrample classes back into a diffusers-readable config. Not comprehensive" skrample_config = dataclasses.asdict(sampler) skrample_config["skrample_predictor"] = predictor @@ -208,7 +209,7 @@ class SkrampleWrapperScheduler[T: TensorNoiseProps | None]: Best effort approach. Most of the items presented in .config are fake, and many function inputs are ignored. A general rule of thumb is it will always prioritize the skrample properties over the incoming properties.""" - sampler: SkrampleSampler + sampler: StructuredSampler schedule: SkrampleSchedule predictor: Predictor[Tensor] = predict_epsilon noise_type: type[TensorNoiseCommon[T]] = Random # type: ignore # Unsure why? @@ -233,7 +234,7 @@ def __post_init__(self) -> None: def from_diffusers_config[N: TensorNoiseProps | None]( # pyright fails if you use the outer generic cls, config: "dict[str, Any] | ConfigMixin", - sampler: type[SkrampleSampler] | None = None, + sampler: type[StructuredSampler] | None = None, schedule: type[SkrampleSchedule] | None = None, schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] = [], predictor: Predictor[Tensor] | None = None, diff --git a/skrample/sampling/__init__.py b/skrample/sampling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py new file mode 100644 index 0000000..9320c56 --- /dev/null +++ b/skrample/sampling/functional.py @@ -0,0 +1,117 @@ +import dataclasses +import math +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +from skrample import common, scheduling + + +@dataclasses.dataclass(frozen=True) +class FunctionalSampler(ABC): + type SampleCallback[T: common.Sample] = Callable[[T], Any] + "Return is ignored" + type SampleableModel[T: common.Sample] = Callable[[T, float, float], T] + "sample, timestep, sigma" + type RNG[T: common.Sample] = Callable[[], T] + "Distribution should match model, typically normal" + + schedule: scheduling.SkrampleSchedule + + @abstractmethod + def sample_model[T: common.Sample]( + self, + sample: T, + model: SampleableModel[T], + steps: int, + include: slice = slice(None), + rng: RNG[T] | None = None, + callback: SampleCallback | None = None, + ) -> T: ... + + +@dataclasses.dataclass(frozen=True) +class FunctionalSinglestep(FunctionalSampler): + order: int = 1 + + @staticmethod + def min_order() -> int: + return 1 + + @staticmethod + @abstractmethod + def max_order() -> int: ... 
+ + def adjust_steps(self, steps: int) -> int: + "Adjust the steps to approximate an equal amount of model calls" + return round(steps / self.order) + + @abstractmethod + def step[T: common.Sample]( + self, + sample: T, + model: FunctionalSampler.SampleableModel[T], + step: int, + schedule: list[tuple[float, float]], + rng: FunctionalSampler.RNG[T] | None = None, + ) -> T: ... + + def sample_model[T: common.Sample]( + self, + sample: T, + model: FunctionalSampler.SampleableModel[T], + steps: int, + include: slice = slice(None), + rng: FunctionalSampler.RNG[T] | None = None, + callback: FunctionalSampler.SampleCallback | None = None, + ) -> T: + schedule: list[tuple[float, float]] = self.schedule.schedule(steps).tolist() + + for n in list(range(len(schedule)))[include]: + sample = self.step(sample, model, n, schedule, rng) + + if callback: + callback(sample) + + return sample + + +@dataclasses.dataclass(frozen=True) +class Heun(FunctionalSinglestep): + order: int = 2 + + @staticmethod + def max_order() -> int: + return 2 + + def adjust_steps(self, steps: int) -> int: + return math.ceil(steps / self.order) # since we skip a call on final step + + def step[T: common.Sample]( + self, + sample: T, + model: FunctionalSampler.SampleableModel[T], + step: int, + schedule: list[tuple[float, float]], + rng: FunctionalSampler.RNG[T] | None = None, + ) -> T: + prediction = model(sample, *schedule[step]) + sample_next = common.euler( + sample, + prediction, + schedule[step][1], + schedule[step + 1][1] if step + 1 < len(schedule) else 0, + self.schedule.sigma_transform, + ) + + if step + 1 < len(schedule) and self.order > 1: + prediction_next = model(sample_next, *schedule[step + 1]) + return common.euler( + sample, + (prediction + prediction_next) / 2, # type: ignore + schedule[step][1], + schedule[step + 1][1], + self.schedule.sigma_transform, + ) + else: + return sample_next diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py new file mode 100644 index 0000000..792208b --- /dev/null +++ b/skrample/sampling/interface.py @@ -0,0 +1,50 @@ +import dataclasses + +from skrample import common +from skrample.sampling import functional, structured + + +@dataclasses.dataclass(frozen=True) +class StructuredFunctionalAdapter(functional.FunctionalSampler): + sampler: structured.StructuredSampler + + def sample_model[T: common.Sample]( + self, + sample: T, + model: functional.FunctionalSampler.SampleableModel[T], + steps: int, + include: slice = slice(None), + rng: functional.FunctionalSampler.RNG[T] | None = None, + callback: functional.FunctionalSampler.SampleCallback | None = None, + ) -> T: + previous: list[structured.SKSamples[T]] = [] + schedule_np = self.schedule.schedule(steps) + schedule: list[tuple[float, float]] = schedule_np.tolist() + sigmas = schedule_np[:, 1] + del schedule_np + + for n in list(range(len(schedule)))[include]: + timestep, sigma = schedule[n] + + prediction = model(self.sampler.scale_input(sample, sigma, self.schedule.sigma_transform), timestep, sigma) + + sksamples = self.sampler.sample( + sample, + prediction, + n, + sigmas, + self.schedule.sigma_transform, + noise=rng() if rng else None, + previous=tuple(previous), + ) + + if self.sampler.require_previous > 0: + previous.append(sksamples) + previous = previous[max(len(previous) - self.sampler.require_previous, 0) :] + + sample = sksamples.final + + if callback: + callback(sample) + + return sample diff --git a/skrample/sampling.py b/skrample/sampling/structured.py similarity index 95% rename from 
skrample/sampling.py rename to skrample/sampling/structured.py index 76948de..bf98901 100644 --- a/skrample/sampling.py +++ b/skrample/sampling/structured.py @@ -5,7 +5,7 @@ import numpy as np from numpy.typing import NDArray -from skrample.common import Sample, SigmaTransform, bashforth, safe_log, softmax, spowf +from skrample.common import Sample, SigmaTransform, bashforth, euler, safe_log, softmax, spowf @dataclass(frozen=True) @@ -27,7 +27,7 @@ class SKSamples[T: Sample]: @dataclass(frozen=True) -class SkrampleSampler(ABC): +class StructuredSampler(ABC): """Generic sampler structure with basic configurables and a stateless design. Abstract class not to be used directly. @@ -100,7 +100,7 @@ def __call__[T: Sample]( @dataclass(frozen=True) -class HighOrderSampler(SkrampleSampler): +class StructuredMultistep(StructuredSampler): """Samplers inheriting this trait support order > 1, and will require `prevous` be managed and passed to function accordingly.""" @@ -134,7 +134,7 @@ def effective_order(self, step: int, schedule: NDArray, previous: tuple[SKSample @dataclass(frozen=True) -class StochasticSampler(SkrampleSampler): +class StructuredStochastic(StructuredSampler): add_noise: bool = False "Flag for whether or not to add the given noise" @@ -144,7 +144,7 @@ def require_noise(self) -> bool: @dataclass(frozen=True) -class Euler(SkrampleSampler): +class Euler(StructuredSampler): """Basic sampler, the "safe" choice.""" def sample[T: Sample]( @@ -159,23 +159,15 @@ def sample[T: Sample]( ) -> SKSamples[T]: sigma = self.get_sigma(step, sigma_schedule) sigma_next = self.get_sigma(step + 1, sigma_schedule) - - sigma_u, sigma_v = sigma_transform(sigma) - sigma_u_next, sigma_v_next = sigma_transform(sigma_next) - - scale = sigma_u_next / sigma_u - delta = sigma_v_next - sigma_v * scale # aka `h` or `dt` - final = sample * scale + prediction * delta - - return SKSamples( # type: ignore - final=final, + return SKSamples( + final=euler(sample, prediction, sigma, sigma_next, sigma_transform), prediction=prediction, sample=sample, ) @dataclass(frozen=True) -class DPM(HighOrderSampler, StochasticSampler): +class DPM(StructuredMultistep, StructuredStochastic): """Good sampler, supports basically everything. Recommended default. https://arxiv.org/abs/2211.01095 @@ -263,7 +255,7 @@ def sample[T: Sample]( @dataclass(frozen=True) -class Adams(HighOrderSampler, Euler): +class Adams(StructuredMultistep, Euler): "Higher order extension to Euler using the Adams-Bashforth coefficients on the model prediction" order: int = 2 @@ -297,7 +289,7 @@ def sample[T: Sample]( @dataclass(frozen=True) -class UniP(HighOrderSampler): +class UniP(StructuredMultistep): "Just the solver from UniPC without any correction stages." fast_solve: bool = False @@ -419,7 +411,7 @@ class UniPC(UniP): The additional correction essentially adds +1 order on top of what is set. https://arxiv.org/abs/2302.04867""" - solver: SkrampleSampler | None = None + solver: StructuredSampler | None = None "If not set, defaults to `UniSolver(order=self.order)`" @staticmethod @@ -471,13 +463,13 @@ def sample[T: Sample]( @dataclass(frozen=True) -class SPC(SkrampleSampler): +class SPC(StructuredSampler): """Simple predictor-corrector. 
Uses basic blended correction against the previous sample.""" - predictor: SkrampleSampler = Euler() + predictor: StructuredSampler = Euler() "Sampler for the current step" - corrector: SkrampleSampler = Adams(order=4) + corrector: StructuredSampler = Adams(order=4) "Sampler to correct the previous step" bias: float = 0 diff --git a/tests/diffusers_map.py b/tests/diffusers_map.py index 2d3a620..07ca2b8 100644 --- a/tests/diffusers_map.py +++ b/tests/diffusers_map.py @@ -14,7 +14,7 @@ from skrample.common import predict_flow as FLOW from skrample.common import predict_velocity as VELOCITY from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling import DPM, Adams, Euler, UniPC +from skrample.sampling.structured import DPM, Adams, Euler, UniPC from skrample.scheduling import Beta, Exponential, FlowShift, Karras, Linear, Scaled diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 30efb43..ed0a62e 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -12,7 +12,7 @@ from skrample.common import predict_epsilon as EPSILON from skrample.common import predict_flow as FLOW from skrample.common import predict_velocity as VELOCITY -from skrample.sampling import DPM, Euler, SkrampleSampler, SKSamples, UniPC +from skrample.sampling.structured import DPM, Euler, SKSamples, StructuredSampler, UniPC DiffusersScheduler = ( EulerDiscreteScheduler | DPMSolverMultistepScheduler | FlowMatchEulerDiscreteScheduler | UniPCMultistepScheduler @@ -25,7 +25,7 @@ def fake_model(t: torch.Tensor) -> torch.Tensor: def dual_sample( - a: SkrampleSampler, + a: StructuredSampler, b: DiffusersScheduler, predictor: Predictor, steps: range, @@ -83,7 +83,7 @@ def dual_sample( def compare_samplers( - a: SkrampleSampler, + a: StructuredSampler, b: DiffusersScheduler, p: Predictor = EPSILON, mu: float | None = None, diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index dc6d4ac..9d9cdb5 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -8,15 +8,16 @@ from skrample.common import MergeStrategy, bashforth, sigma_complement, sigmoid, softmax, spowf from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling import ( +from skrample.sampling.interface import StructuredFunctionalAdapter +from skrample.sampling.structured import ( DPM, SPC, Adams, Euler, - HighOrderSampler, - SkrampleSampler, SKSamples, - StochasticSampler, + StructuredMultistep, + StructuredSampler, + StructuredStochastic, UniPC, ) from skrample.scheduling import Beta, FlowShift, Karras, Linear, Scaled, SigmoidCDF @@ -53,7 +54,7 @@ def test_sampler_generics() -> None: eps = 1e-12 for sampler in [ *(cls() for cls in ALL_SAMPLERS), - *(cls(order=cls.max_order()) for cls in ALL_SAMPLERS if issubclass(cls, HighOrderSampler)), + *(cls(order=cls.max_order()) for cls in ALL_SAMPLERS if issubclass(cls, StructuredMultistep)), ]: for schedule in Scaled(), FlowShift(Linear()): i, o = random.random(), random.random() @@ -94,9 +95,9 @@ def test_mu_set() -> None: def test_require_previous() -> None: - samplers: list[SkrampleSampler] = [] + samplers: list[StructuredSampler] = [] for cls in ALL_SAMPLERS: - if issubclass(cls, HighOrderSampler): + if issubclass(cls, StructuredMultistep): samplers.extend([cls(order=o + 1) for o in range(cls.min_order(), cls.max_order())]) else: samplers.append(cls()) @@ -134,9 +135,9 @@ def test_require_previous() -> None: def test_require_noise() -> None: - samplers: list[SkrampleSampler] = [] + samplers: list[StructuredSampler] = [] for 
cls in ALL_SAMPLERS: - if issubclass(cls, StochasticSampler): + if issubclass(cls, StructuredStochastic): samplers.extend([cls(add_noise=n) for n in (False, True)]) else: samplers.append(cls()) @@ -177,6 +178,41 @@ def test_require_noise() -> None: assert a == b, (sampler, sampler.require_noise) +def test_functional_adapter() -> None: + def fake_model(x: float, _: float, s: float) -> float: + return x + math.sin(x) * s + + samplers: list[StructuredSampler] = [DPM(n, o) for o in range(1, 4) for n in [False, True]] + for schedule in Linear(), Scaled(): + for sampler in samplers: + for steps in [1, 3, 4, 9, 512, 999]: + sample = 1.5 + adapter = StructuredFunctionalAdapter(schedule, sampler) + noise = [random.random() for _ in range(steps)] + + rng = iter(noise) + sample_f = adapter.sample_model(sample, fake_model, steps, rng=lambda: next(rng)) + + rng = iter(noise) + schedule_np = schedule.schedule(steps) + sample_s = sample + previous: list[SKSamples[float]] = [] + for n, (t, s) in enumerate(schedule_np.tolist()): + results = sampler.sample( + sample_s, + fake_model(sample_s, t, s), + n, + schedule_np[:, 1], + schedule.sigma_transform, + next(rng), + tuple(previous), + ) + previous.append(results) + sample_s = results.final + + assert sample_s == sample_f, (sample_s, sample_f, sampler, schedule, steps) + + def test_bashforth() -> None: for n, coeffs in enumerate( np.array(c) for c in ((1,), (3 / 2, -1 / 2), (23 / 12, -4 / 3, 5 / 12), (55 / 24, -59 / 24, 37 / 24, -3 / 8)) From bc15d7475941842a37e943133d9c7f656ee67cab Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 21 Sep 2025 02:06:57 -0700 Subject: [PATCH 02/59] Split FunctionalHigher from FunctionalSinglestep --- skrample/sampling/functional.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 9320c56..d5147f3 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -31,7 +31,7 @@ def sample_model[T: common.Sample]( @dataclasses.dataclass(frozen=True) -class FunctionalSinglestep(FunctionalSampler): +class FunctionalHigher(FunctionalSampler): order: int = 1 @staticmethod @@ -46,6 +46,9 @@ def adjust_steps(self, steps: int) -> int: "Adjust the steps to approximate an equal amount of model calls" return round(steps / self.order) + +@dataclasses.dataclass(frozen=True) +class FunctionalSinglestep(FunctionalSampler): @abstractmethod def step[T: common.Sample]( self, @@ -77,7 +80,7 @@ def sample_model[T: common.Sample]( @dataclasses.dataclass(frozen=True) -class Heun(FunctionalSinglestep): +class Heun(FunctionalHigher, FunctionalSinglestep): order: int = 2 @staticmethod From a6a0923921c5e3fb4653ee20e8daabc7e2c8ee75 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 22 Sep 2025 00:12:50 -0700 Subject: [PATCH 03/59] Plot functional samplers, upgrade Heun -> RungeKutta --- examples/functional.py | 2 +- scripts/plot_skrample.py | 73 ++++++++++++++++++----------- skrample/sampling/functional.py | 82 ++++++++++++++++++++++++--------- 3 files changed, 107 insertions(+), 50 deletions(-) diff --git a/examples/functional.py b/examples/functional.py index 7912b7d..aafccd6 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -26,7 +26,7 @@ # Equivalent to structured example sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True)) # Native functional example - sampler = functional.Heun(schedule) + sampler = functional.RungeKutta(schedule, 4) tokenizer: CLIPTokenizer = 
CLIPTokenizer.from_pretrained(url, subfolder="tokenizer") text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained( diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 64914b9..0bd939c 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -12,9 +12,10 @@ import numpy as np from numpy.typing import NDArray -import skrample.sampling.structured as sampling import skrample.scheduling as scheduling from skrample.common import SigmaTransform, sigma_complement, sigma_polar, spowf +from skrample.sampling import functional, structured +from skrample.sampling.interface import StructuredFunctionalAdapter OKLAB_XYZ_M1 = np.array( [ @@ -61,16 +62,17 @@ def colors(hue_steps: int) -> Generator[list[float]]: "polar": sigma_polar, "complement": sigma_complement, } -SAMPLERS: dict[str, sampling.StructuredSampler] = { - "euler": sampling.Euler(), - "adams": sampling.Adams(), - "dpm": sampling.DPM(), - "unip": sampling.UniP(), - "unipc": sampling.UniPC(), - "spc": sampling.SPC(), +SAMPLERS: dict[str, structured.StructuredSampler | functional.FunctionalSampler] = { + "euler": structured.Euler(), + "adams": structured.Adams(), + "dpm": structured.DPM(), + "unip": structured.UniP(), + "unipc": structured.UniPC(), + "spc": structured.SPC(), + "rk": functional.RungeKutta(scheduling.Linear()), } for k, v in list(SAMPLERS.items()): - if isinstance(v, sampling.StructuredMultistep): + if isinstance(v, structured.StructuredMultistep | functional.FunctionalHigher): for o in range(1, v.max_order() + 1): if o != v.order: SAMPLERS[k + str(o)] = replace(v, order=o) @@ -154,30 +156,42 @@ def colors(hue_steps: int) -> Generator[list[float]]: plt.ylabel("Sample") plt.title("Skrample Samplers") - schedule = scheduling.Linear(base_timesteps=10_000) + schedule = scheduling.Linear(base_timesteps=10_000, custom_transform=TRANSFORMS[args.transform]) + + def sample_model(sampler: structured.StructuredSampler | functional.FunctionalSampler, steps: int) -> list[float]: + if isinstance(sampler, structured.StructuredSampler): + sampler = StructuredFunctionalAdapter(schedule, sampler) + else: + sampler = replace(sampler, schedule=schedule) - def sample_model(sampler: sampling.StructuredSampler, schedule: NDArray[np.float64]) -> list[float]: - previous: list[sampling.SKSamples] = [] sample = 1.0 sampled_values = [sample] - for step, sigma in enumerate(schedule): - result = sampler.sample( - sample=sample, - prediction=math.sin(sigma * args.curve), - step=step, - sigma_schedule=schedule, - sigma_transform=TRANSFORMS[args.transform], - previous=tuple(previous), - noise=random(), - ) - previous.append(result) - sample = result.final - sampled_values.append(sample) + + if isinstance(sampler, functional.FunctionalHigher) and False: + adjusted = sampler.adjust_steps(steps) + else: + adjusted = steps + + sampler.sample_model( + sample=sample, + model=lambda sample, timestep, sigma: math.sin(sigma * args.curve), + steps=adjusted, + rng=random, + callback=lambda x: sampled_values.append(x), + ) + + # if isinstance(sampler, functional.FunctionalHigher): + # sampled_values = np.interp( + # np.linspace(0, 1, steps + 1), + # np.linspace(0, 1, len(sampled_values)), + # np.array(sampled_values), + # ).tolist() + return sampled_values plt.plot( [*schedule.sigmas(schedule.base_timesteps), 0], - sample_model(sampling.Euler(), schedule.sigmas(schedule.base_timesteps)), + sample_model(structured.Euler(), schedule.base_timesteps), label="Reference", color=next(COLORS), ) @@ -185,9 +199,12 @@ def sample_model(sampler: 
sampling.StructuredSampler, schedule: NDArray[np.float for sampler in [SAMPLERS[s] for s in args.sampler]: sigmas = schedule.sigmas(args.steps) label = type(sampler).__name__ - if isinstance(sampler, sampling.StructuredMultistep) and sampler.order != type(sampler).order: + if ( + isinstance(sampler, structured.StructuredMultistep | functional.FunctionalHigher) + and sampler.order != type(sampler).order + ): label += " " + str(sampler.order) - plt.plot([*sigmas, 0], sample_model(sampler, sigmas), label=label, color=next(COLORS), linestyle="--") + plt.plot([*sigmas, 0], sample_model(sampler, args.steps), label=label, color=next(COLORS), linestyle="--") elif args.command == "schedules": plt.xlabel("Step") diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index d5147f3..8bb68c2 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -49,6 +49,9 @@ def adjust_steps(self, steps: int) -> int: @dataclasses.dataclass(frozen=True) class FunctionalSinglestep(FunctionalSampler): + def step_increment(self) -> int: + return 1 + @abstractmethod def step[T: common.Sample]( self, @@ -68,10 +71,10 @@ def sample_model[T: common.Sample]( rng: FunctionalSampler.RNG[T] | None = None, callback: FunctionalSampler.SampleCallback | None = None, ) -> T: - schedule: list[tuple[float, float]] = self.schedule.schedule(steps).tolist() + schedule: list[tuple[float, float]] = self.schedule.schedule(steps * self.step_increment()).tolist() - for n in list(range(len(schedule)))[include]: - sample = self.step(sample, model, n, schedule, rng) + for n in list(range(steps))[include]: + sample = self.step(sample, model, n * self.step_increment(), schedule, rng) if callback: callback(sample) @@ -80,16 +83,19 @@ def sample_model[T: common.Sample]( @dataclasses.dataclass(frozen=True) -class Heun(FunctionalHigher, FunctionalSinglestep): +class RungeKutta(FunctionalHigher, FunctionalSinglestep): order: int = 2 @staticmethod def max_order() -> int: - return 2 + return 4 def adjust_steps(self, steps: int) -> int: return math.ceil(steps / self.order) # since we skip a call on final step + def step_increment(self) -> int: + return 2 if self.order > 2 else 1 + def step[T: common.Sample]( self, sample: T, @@ -98,23 +104,57 @@ def step[T: common.Sample]( schedule: list[tuple[float, float]], rng: FunctionalSampler.RNG[T] | None = None, ) -> T: - prediction = model(sample, *schedule[step]) - sample_next = common.euler( - sample, - prediction, - schedule[step][1], - schedule[step + 1][1] if step + 1 < len(schedule) else 0, - self.schedule.sigma_transform, - ) - - if step + 1 < len(schedule) and self.order > 1: - prediction_next = model(sample_next, *schedule[step + 1]) - return common.euler( + step_next = step + self.step_increment() + + def euler_kt2(k: T, t2: int) -> T: + return common.euler(sample, k, schedule[step][1], schedule[t2][1], self.schedule.sigma_transform) + + K1: T = model(sample, *schedule[step]) + + if self.order > 2 and step_next < len(schedule): + assert (step + step_next) % 2 == 0 + step_mid = (step + step_next) // 2 + + S1: T = euler_kt2(K1, step_mid) + + K2: T = model(S1, *schedule[step_mid]) + + if self.order > 3: + S2: T = euler_kt2(K2, step_mid) + + K3: T = model(S2, *schedule[step_mid]) + S3: T = euler_kt2(K3, step_next) + + K4: T = model(S3, *schedule[step_next]) + return euler_kt2( + (K1 + 2 * K2 + 2 * K3 + K4) / 6, # type: ignore + step_next, + ) + else: + S2: T = euler_kt2( + -K1 + 2 * K2, # type: ignore + step_next, + ) + + K3: T = model(S2, 
*schedule[step_next]) + return euler_kt2( + (K1 + 4 * K2 + K3) / 6, # type: ignore + step_next, + ) + else: + S1: T = common.euler( sample, - (prediction + prediction_next) / 2, # type: ignore + K1, schedule[step][1], - schedule[step + 1][1], + schedule[step_next][1] if step_next < len(schedule) else 0, self.schedule.sigma_transform, ) - else: - return sample_next + + if step_next < len(schedule) and self.order > 1: + K2: T = model(S1, *schedule[step_next]) + return euler_kt2( + (K1 + K2) / 2, # type: ignore + step_next, + ) + else: + return S1 From d9394ea864a210afbfda283e9b7d3829270a7a4e Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 22 Sep 2025 01:28:57 -0700 Subject: [PATCH 04/59] Rewrite RungeKutta to deduplicate code --- skrample/sampling/functional.py | 74 +++++++++++++++------------------ 1 file changed, 33 insertions(+), 41 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 8bb68c2..01e483c 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -106,55 +106,47 @@ def step[T: common.Sample]( ) -> T: step_next = step + self.step_increment() - def euler_kt2(k: T, t2: int) -> T: - return common.euler(sample, k, schedule[step][1], schedule[t2][1], self.schedule.sigma_transform) - - K1: T = model(sample, *schedule[step]) - - if self.order > 2 and step_next < len(schedule): + stages: tuple[tuple[None | tuple[float, ...], int, int], ...] + effective_order = self.order if step_next < len(schedule) else 1 + if effective_order >= 3: assert (step + step_next) % 2 == 0 step_mid = (step + step_next) // 2 - - S1: T = euler_kt2(K1, step_mid) - - K2: T = model(S1, *schedule[step_mid]) - - if self.order > 3: - S2: T = euler_kt2(K2, step_mid) - - K3: T = model(S2, *schedule[step_mid]) - S3: T = euler_kt2(K3, step_next) - - K4: T = model(S3, *schedule[step_next]) - return euler_kt2( - (K1 + 2 * K2 + 2 * K3 + K4) / 6, # type: ignore - step_next, + if effective_order >= 4: # RK4 + stages = ( + (None, step, step_mid), + (None, step_mid, step_mid), + (None, step_mid, step_next), + ((1 / 6, 2 / 6, 2 / 6, 1 / 6), step_next, step_next), ) - else: - S2: T = euler_kt2( - -K1 + 2 * K2, # type: ignore - step_next, - ) - - K3: T = model(S2, *schedule[step_next]) - return euler_kt2( - (K1 + 4 * K2 + K3) / 6, # type: ignore - step_next, + else: # RK3 + stages = ( + (None, step, step_mid), + ((-1, 2), step_mid, step_next), + ((1 / 6, 4 / 6, 1 / 6), step_next, step_next), ) + elif effective_order >= 2: # Heun / RK2 + stages = ( + (None, step, step_next), + ((1 / 2, 1 / 2), step_next, step_next), + ) else: - S1: T = common.euler( + return common.euler( sample, - K1, + model(sample, *schedule[step]), schedule[step][1], schedule[step_next][1] if step_next < len(schedule) else 0, self.schedule.sigma_transform, ) - if step_next < len(schedule) and self.order > 1: - K2: T = model(S1, *schedule[step_next]) - return euler_kt2( - (K1 + K2) / 2, # type: ignore - step_next, - ) - else: - return S1 + Xn: T = sample + k_terms: list[T] = [] + for coeffs, model_t, sample_t in stages: + k_terms.append(model(Xn, *schedule[model_t])) + Xn = common.euler( + sample, + math.sumprod(k_terms, coeffs) if coeffs else k_terms[-1], # type: ignore + schedule[step][1], + schedule[sample_t][1] if step_next < len(schedule) else 0, + self.schedule.sigma_transform, + ) + return Xn From 99315bff1f314fdb52a0f850d87800cb9efffd81 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 25 Sep 2025 19:37:33 -0700 Subject: [PATCH 05/59] Rewrite RungeKutta again to use tableau 
format for stages --- skrample/sampling/functional.py | 81 ++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 01e483c..eb55f05 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -84,6 +84,10 @@ def sample_model[T: common.Sample]( @dataclasses.dataclass(frozen=True) class RungeKutta(FunctionalHigher, FunctionalSinglestep): + type Stage = tuple[int, tuple[float, ...]] + type Final = tuple[float, ...] + type Tableau = tuple[tuple[Stage, ...], Final] + order: int = 2 @staticmethod @@ -106,47 +110,62 @@ def step[T: common.Sample]( ) -> T: step_next = step + self.step_increment() - stages: tuple[tuple[None | tuple[float, ...], int, int], ...] + tableau: RungeKutta.Tableau effective_order = self.order if step_next < len(schedule) else 1 if effective_order >= 3: assert (step + step_next) % 2 == 0 step_mid = (step + step_next) // 2 if effective_order >= 4: # RK4 - stages = ( - (None, step, step_mid), - (None, step_mid, step_mid), - (None, step_mid, step_next), - ((1 / 6, 2 / 6, 2 / 6, 1 / 6), step_next, step_next), + tableau = ( + ( + (step, ()), + (step_mid, (1,)), + (step_mid, (0, 1)), + (step_next, (0, 0, 1)), + ), + (1 / 6, 2 / 6, 2 / 6, 1 / 6), ) else: # RK3 - stages = ( - (None, step, step_mid), - ((-1, 2), step_mid, step_next), - ((1 / 6, 4 / 6, 1 / 6), step_next, step_next), + tableau = ( + ( + (step, ()), + (step_mid, (1,)), + (step_next, (-1, 2)), + ), + (1 / 6, 4 / 6, 1 / 6), ) elif effective_order >= 2: # Heun / RK2 - stages = ( - (None, step, step_next), - ((1 / 2, 1 / 2), step_next, step_next), + tableau = ( + ( + (step, ()), + (step_next, (1,)), + ), + (1 / 2, 1 / 2), ) - else: - return common.euler( - sample, - model(sample, *schedule[step]), - schedule[step][1], - schedule[step_next][1] if step_next < len(schedule) else 0, - self.schedule.sigma_transform, + else: # Euler / RK1 + tableau = ( + ((step, ()),), + (1,), ) - Xn: T = sample k_terms: list[T] = [] - for coeffs, model_t, sample_t in stages: - k_terms.append(model(Xn, *schedule[model_t])) - Xn = common.euler( - sample, - math.sumprod(k_terms, coeffs) if coeffs else k_terms[-1], # type: ignore - schedule[step][1], - schedule[sample_t][1] if step_next < len(schedule) else 0, - self.schedule.sigma_transform, - ) - return Xn + for istep, icoeffs in tableau[0]: + if icoeffs: + combined: T = common.euler( + sample, + math.sumprod(k_terms, icoeffs), # type: ignore + schedule[step][1], + schedule[istep][1] if istep < len(schedule) else 0, + self.schedule.sigma_transform, + ) + else: + combined = sample + k_terms.append(model(combined, *schedule[istep])) + + return common.euler( + sample, + math.sumprod(k_terms, tableau[1]), # type: ignore + schedule[step][1], + schedule[step_next][1] if step_next < len(schedule) else 0, + self.schedule.sigma_transform, + ) From 9c98a3cc48861426d3083221b93883f072042af1 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 25 Sep 2025 22:42:30 -0700 Subject: [PATCH 06/59] plot_skrample.py: improve fake model algorithm --- scripts/plot_skrample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 0bd939c..5561cbf 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -107,7 +107,7 @@ def colors(hue_steps: int) -> Generator[list[float]]: # Samplers parser_sampler = subparsers.add_parser("samplers") -parser_sampler.add_argument("--curve", "-k", type=int, 
default=10) +parser_sampler.add_argument("--curve", "-k", type=int, default=30) parser_sampler.add_argument("--transform", "-t", type=str, choices=list(TRANSFORMS.keys()), default="polar") parser_sampler.add_argument( "--sampler", @@ -174,7 +174,7 @@ def sample_model(sampler: structured.StructuredSampler | functional.FunctionalSa sampler.sample_model( sample=sample, - model=lambda sample, timestep, sigma: math.sin(sigma * args.curve), + model=lambda sample, timestep, sigma: sample + math.sin(sigma * args.curve), steps=adjusted, rng=random, callback=lambda x: sampled_values.append(x), From 857f96764c9ea2a38d07db1ff566619a2ddf28b7 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 25 Sep 2025 22:42:49 -0700 Subject: [PATCH 07/59] Use real tableau values normalized by sum --- skrample/sampling/functional.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index eb55f05..c3efa2e 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -119,8 +119,8 @@ def step[T: common.Sample]( tableau = ( ( (step, ()), - (step_mid, (1,)), - (step_mid, (0, 1)), + (step_mid, (1 / 2,)), + (step_mid, (0, 1 / 2)), (step_next, (0, 0, 1)), ), (1 / 6, 2 / 6, 2 / 6, 1 / 6), @@ -129,7 +129,7 @@ def step[T: common.Sample]( tableau = ( ( (step, ()), - (step_mid, (1,)), + (step_mid, (1 / 2,)), (step_next, (-1, 2)), ), (1 / 6, 4 / 6, 1 / 6), @@ -153,9 +153,9 @@ def step[T: common.Sample]( if icoeffs: combined: T = common.euler( sample, - math.sumprod(k_terms, icoeffs), # type: ignore + math.sumprod(k_terms, icoeffs) / math.fsum(icoeffs), # type: ignore schedule[step][1], - schedule[istep][1] if istep < len(schedule) else 0, + schedule[istep][1], self.schedule.sigma_transform, ) else: From 00573109b2da20f503bba654610d873c76c45f1a Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 26 Sep 2025 01:04:34 -0700 Subject: [PATCH 08/59] RungeKutta use linear schedule interpolation for steps --- skrample/sampling/functional.py | 66 +++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index c3efa2e..48ed302 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -4,6 +4,8 @@ from collections.abc import Callable from typing import Any +import numpy as np + from skrample import common, scheduling @@ -49,9 +51,6 @@ def adjust_steps(self, steps: int) -> int: @dataclasses.dataclass(frozen=True) class FunctionalSinglestep(FunctionalSampler): - def step_increment(self) -> int: - return 1 - @abstractmethod def step[T: common.Sample]( self, @@ -71,10 +70,10 @@ def sample_model[T: common.Sample]( rng: FunctionalSampler.RNG[T] | None = None, callback: FunctionalSampler.SampleCallback | None = None, ) -> T: - schedule: list[tuple[float, float]] = self.schedule.schedule(steps * self.step_increment()).tolist() + schedule: list[tuple[float, float]] = self.schedule.schedule(steps).tolist() for n in list(range(steps))[include]: - sample = self.step(sample, model, n * self.step_increment(), schedule, rng) + sample = self.step(sample, model, n, schedule, rng) if callback: callback(sample) @@ -84,7 +83,7 @@ def sample_model[T: common.Sample]( @dataclasses.dataclass(frozen=True) class RungeKutta(FunctionalHigher, FunctionalSinglestep): - type Stage = tuple[int, tuple[float, ...]] + type Stage = tuple[float, tuple[float, ...]] type Final = tuple[float, ...] 
type Tableau = tuple[tuple[Stage, ...], Final] @@ -97,8 +96,24 @@ def max_order() -> int: def adjust_steps(self, steps: int) -> int: return math.ceil(steps / self.order) # since we skip a call on final step - def step_increment(self) -> int: - return 2 if self.order > 2 else 1 + @staticmethod + def fractional_step( + schedule: list[tuple[float, float]], + current: int, + idx: tuple[float, ...], + ) -> tuple[tuple[float, float], ...]: + schedule_np = np.array([*schedule, (0, 0)]) + idx_np = np.array(idx) / len(schedule) + current / len(schedule) + scale = np.linspace(0, 1, len(schedule_np)) + + # TODO (beinszeii): better 2d interpolate + result = tuple( + zip( + (np.interp(idx_np, scale, schedule_np[:, 0])).tolist(), + (np.interp(idx_np, scale, schedule_np[:, 1])).tolist(), + ) + ) + return result def step[T: common.Sample]( self, @@ -108,64 +123,61 @@ def step[T: common.Sample]( schedule: list[tuple[float, float]], rng: FunctionalSampler.RNG[T] | None = None, ) -> T: - step_next = step + self.step_increment() - tableau: RungeKutta.Tableau - effective_order = self.order if step_next < len(schedule) else 1 + effective_order = self.order if step + 1 < len(schedule) else 1 if effective_order >= 3: - assert (step + step_next) % 2 == 0 - step_mid = (step + step_next) // 2 if effective_order >= 4: # RK4 tableau = ( ( - (step, ()), - (step_mid, (1 / 2,)), - (step_mid, (0, 1 / 2)), - (step_next, (0, 0, 1)), + (0, ()), + (1 / 2, (1 / 2,)), + (1 / 2, (0, 1 / 2)), + (1, (0, 0, 1)), ), (1 / 6, 2 / 6, 2 / 6, 1 / 6), ) else: # RK3 tableau = ( ( - (step, ()), - (step_mid, (1 / 2,)), - (step_next, (-1, 2)), + (0, ()), + (1 / 2, (1 / 2,)), + (1, (-1, 2)), ), (1 / 6, 4 / 6, 1 / 6), ) elif effective_order >= 2: # Heun / RK2 tableau = ( ( - (step, ()), - (step_next, (1,)), + (0, ()), + (1, (1,)), ), (1 / 2, 1 / 2), ) else: # Euler / RK1 tableau = ( - ((step, ()),), + ((0, ()),), (1,), ) k_terms: list[T] = [] - for istep, icoeffs in tableau[0]: + fractions = self.fractional_step(schedule, step, tuple(f[0] for f in tableau[0])) + for frac_sc, icoeffs in zip(fractions, (t[1] for t in tableau[0]), strict=True): if icoeffs: combined: T = common.euler( sample, math.sumprod(k_terms, icoeffs) / math.fsum(icoeffs), # type: ignore schedule[step][1], - schedule[istep][1], + frac_sc[1], self.schedule.sigma_transform, ) else: combined = sample - k_terms.append(model(combined, *schedule[istep])) + k_terms.append(model(combined, *frac_sc)) return common.euler( sample, math.sumprod(k_terms, tableau[1]), # type: ignore schedule[step][1], - schedule[step_next][1] if step_next < len(schedule) else 0, + schedule[step + 1][1] if step + 1 < len(schedule) else 0, self.schedule.sigma_transform, ) From 3291c732406b5d6f528fc3b9cfd8a79e4ec326d9 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 26 Sep 2025 02:09:52 -0700 Subject: [PATCH 09/59] Add every RungeKutta tableau from wikipedia --- examples/functional.py | 2 +- scripts/plot_skrample.py | 4 +- skrample/sampling/functional.py | 200 ++++++++++++++++++++++++++------ 3 files changed, 170 insertions(+), 36 deletions(-) diff --git a/examples/functional.py b/examples/functional.py index aafccd6..917af09 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -26,7 +26,7 @@ # Equivalent to structured example sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True)) # Native functional example - sampler = functional.RungeKutta(schedule, 4) + sampler = functional.RKUltra(schedule, 4) tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, 
subfolder="tokenizer") text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained( diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 5561cbf..5d7a5d0 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -69,7 +69,7 @@ def colors(hue_steps: int) -> Generator[list[float]]: "unip": structured.UniP(), "unipc": structured.UniPC(), "spc": structured.SPC(), - "rk": functional.RungeKutta(scheduling.Linear()), + "rk": functional.RKUltra(scheduling.Linear()), } for k, v in list(SAMPLERS.items()): if isinstance(v, structured.StructuredMultistep | functional.FunctionalHigher): @@ -102,7 +102,7 @@ def colors(hue_steps: int) -> Generator[list[float]]: # Common parser = ArgumentParser() parser.add_argument("file", type=Path) -parser.add_argument("--steps", "-s", type=int, default=20) +parser.add_argument("--steps", "-s", type=int, default=25) subparsers = parser.add_subparsers(dest="command") # Samplers diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 48ed302..b3ffd57 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -1,4 +1,5 @@ import dataclasses +import enum import math from abc import ABC, abstractmethod from collections.abc import Callable @@ -82,12 +83,168 @@ def sample_model[T: common.Sample]( @dataclasses.dataclass(frozen=True) -class RungeKutta(FunctionalHigher, FunctionalSinglestep): - type Stage = tuple[float, tuple[float, ...]] - type Final = tuple[float, ...] - type Tableau = tuple[tuple[Stage, ...], Final] +class RKUltra(FunctionalHigher, FunctionalSinglestep): + "Implements almost every single method from https://en.wikipedia.org/wiki/List_of_Runge–Kutta_methods" # noqa: RUF002 + + type Tableau = tuple[ + tuple[ + tuple[float, tuple[float, ...]], + ..., + ], + tuple[float, ...], + ] + + @enum.unique + class RK2(enum.StrEnum): + Heun = enum.auto() + Mid = enum.auto() + Ralston = enum.auto() + + def tableau(self) -> "RKUltra.Tableau": + match self: + case self.Heun: + return ( + ( + (0, ()), + (1, (1,)), + ), + (1 / 2, 1 / 2), + ) + case self.Mid: + return ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + ), + (0, 1), + ) + case self.Ralston: + return ( + ( + (0, ()), + (2 / 3, (2 / 3,)), + ), + (1 / 4, 3 / 4), + ) + + @enum.unique + class RK3(enum.StrEnum): + Kutta = enum.auto() + Heun = enum.auto() + Ralston = enum.auto() + Wray = enum.auto() + SSPRK3 = enum.auto() + + def tableau(self) -> "RKUltra.Tableau": + match self: + case self.Kutta: + return ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (1, (-1, 2)), + ), + (1 / 6, 2 / 3, 1 / 6), + ) + case self.Heun: + return ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 3, (0, 2 / 3)), + ), + (1 / 4, 0, 3 / 4), + ) + case self.Ralston: + return ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (3 / 4, (0, 3 / 4)), + ), + (2 / 9, 1 / 3, 4 / 9), + ) + case self.Wray: + return ( + ( + (0, ()), + (8 / 15, (8 / 15,)), + (2 / 3, (1 / 4, 5 / 12)), + ), + (1 / 4, 0, 3 / 4), + ) + case self.SSPRK3: + return ( + ( + (0, ()), + (1, (1,)), + (1 / 2, (1 / 4, 1 / 4)), + ), + (1 / 6, 1 / 6, 2 / 3), + ) + + @enum.unique + class RK4(enum.StrEnum): + Classic = enum.auto() + Eighth = enum.auto() + Ralston = enum.auto() + + def tableau(self) -> "RKUltra.Tableau": + match self: + case self.Classic: + return ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (1 / 2, (0, 1 / 2)), + (1, (0, 0, 1)), + ), + (1 / 6, 1 / 3, 1 / 3, 1 / 6), + ) + case self.Eighth: + return ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 3, (-1 / 3, 1)), + (1, (1, -1, 1)), + ), + (1 / 8, 3 / 8, 3 / 8, 1 / 8), + ) + case 
self.Ralston: + sq5: float = math.sqrt(5) + return ( + ( + (0, ()), + (2 / 5, (2 / 5,)), + ( + (14 - 3 * sq5) / 16, + ( + (-2889 + 1428 * sq5) / 1024, + (3785 - 1620 * sq5) / 1024, + ), + ), + ( + 1, + ( + (-3365 + 2094 * sq5) / 6040, + (-975 - 3046 * sq5) / 2552, + (467040 + 203968 * sq5) / 240845, + ), + ), + ), + ( + (263 + 24 * sq5) / 1812, + (125 - 1000 * sq5) / 3828, + (3426304 + 1661952 * sq5) / 5924787, + (30 - 4 * sq5) / 123, + ), + ) order: int = 2 + rk2: RK2 = RK2.Ralston + rk3: RK3 = RK3.Ralston + rk4: RK4 = RK4.Ralston @staticmethod def max_order() -> int: @@ -123,36 +280,13 @@ def step[T: common.Sample]( schedule: list[tuple[float, float]], rng: FunctionalSampler.RNG[T] | None = None, ) -> T: - tableau: RungeKutta.Tableau effective_order = self.order if step + 1 < len(schedule) else 1 - if effective_order >= 3: - if effective_order >= 4: # RK4 - tableau = ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - (1 / 2, (0, 1 / 2)), - (1, (0, 0, 1)), - ), - (1 / 6, 2 / 6, 2 / 6, 1 / 6), - ) - else: # RK3 - tableau = ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - (1, (-1, 2)), - ), - (1 / 6, 4 / 6, 1 / 6), - ) - elif effective_order >= 2: # Heun / RK2 - tableau = ( - ( - (0, ()), - (1, (1,)), - ), - (1 / 2, 1 / 2), - ) + if effective_order >= 4: + tableau = self.rk4.tableau() + elif effective_order >= 3: + tableau = self.rk3.tableau() + elif effective_order >= 2: + tableau = self.rk2.tableau() else: # Euler / RK1 tableau = ( ((0, ()),), From 7f85797bd28a95e79e8d30e0746e803c985c0dff Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 26 Sep 2025 02:19:37 -0700 Subject: [PATCH 10/59] Add RKUltra tableau test --- tests/miscellaneous.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 9d9cdb5..8ebed8d 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -8,6 +8,7 @@ from skrample.common import MergeStrategy, bashforth, sigma_complement, sigmoid, softmax, spowf from skrample.diffusers import SkrampleWrapperScheduler +from skrample.sampling.functional import RKUltra from skrample.sampling.interface import StructuredFunctionalAdapter from skrample.sampling.structured import ( DPM, @@ -220,6 +221,20 @@ def test_bashforth() -> None: assert np.allclose(coeffs, np.array(bashforth(n + 1)), atol=1e-12, rtol=1e-12) +def test_tableau() -> None: + for order in [RKUltra.RK2, RKUltra.RK3, RKUltra.RK4]: + variant: RKUltra.RK2 | RKUltra.RK3 | RKUltra.RK4 + for variant in order: + tab = variant.tableau() + + for stage in tab[0]: + stage_err = abs(stage[0] - math.fsum(stage[1])) + assert stage_err < 1e-15, (variant, stage) + + final_err = abs(1 - math.fsum(variant.tableau()[1])) + assert final_err < 1e-15, variant + + def test_sigmoid() -> None: items = spowf(torch.linspace(-2, 2, 9, dtype=torch.float64), 2) a = torch.sigmoid(items) From c82e35e1187f20171cb2e1a7012b0a86ac8bec36 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 26 Sep 2025 02:40:19 -0700 Subject: [PATCH 11/59] RKUltra do not skip non-zero intermediates --- skrample/sampling/functional.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index b3ffd57..93e06b2 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -280,12 +280,11 @@ def step[T: common.Sample]( schedule: list[tuple[float, float]], rng: FunctionalSampler.RNG[T] | None = None, ) -> T: - effective_order = self.order if step + 1 < len(schedule) else 1 - if effective_order >= 4: + if 
self.order >= 4: tableau = self.rk4.tableau() - elif effective_order >= 3: + elif self.order >= 3: tableau = self.rk3.tableau() - elif effective_order >= 2: + elif self.order >= 2: tableau = self.rk2.tableau() else: # Euler / RK1 tableau = ( @@ -306,7 +305,9 @@ def step[T: common.Sample]( ) else: combined = sample - k_terms.append(model(combined, *frac_sc)) + + # Do not call model on timestep = 0 or sigma = 0 + k_terms.append(model(combined, *frac_sc) if not any(abs(v) < 1e-8 for v in frac_sc) else combined) return common.euler( sample, From 059c87449a8ad0529e53752f1c7416b95cca0776 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 28 Sep 2025 15:44:22 -0700 Subject: [PATCH 12/59] Add RKUltra.RK5.Nystrom tableau --- skrample/sampling/functional.py | 26 ++++++++++++++++++++++++-- tests/miscellaneous.py | 4 ++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 93e06b2..8107c4d 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -241,14 +241,34 @@ def tableau(self) -> "RKUltra.Tableau": ), ) + @enum.unique + class RK5(enum.StrEnum): + Nystrom = enum.auto() + + def tableau(self) -> "RKUltra.Tableau": + match self: + case self.Nystrom: + return ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 5, (4 / 25, 6 / 25)), + (1, (1 / 4, -3, 15 / 4)), + (2 / 3, (2 / 27, 10 / 9, -50 / 81, 8 / 81)), + (4 / 5, (2 / 25, 12 / 25, 2 / 15, 8 / 75, 0)), + ), + (23 / 192, 0, 125 / 192, 0, -27 / 64, 125 / 192), + ) + order: int = 2 rk2: RK2 = RK2.Ralston rk3: RK3 = RK3.Ralston rk4: RK4 = RK4.Ralston + rk5: RK5 = RK5.Nystrom @staticmethod def max_order() -> int: - return 4 + return 5 def adjust_steps(self, steps: int) -> int: return math.ceil(steps / self.order) # since we skip a call on final step @@ -280,7 +300,9 @@ def step[T: common.Sample]( schedule: list[tuple[float, float]], rng: FunctionalSampler.RNG[T] | None = None, ) -> T: - if self.order >= 4: + if self.order >= 5: + tableau = self.rk5.tableau() + elif self.order >= 4: tableau = self.rk4.tableau() elif self.order >= 3: tableau = self.rk3.tableau() diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 8ebed8d..7c1effd 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -222,8 +222,8 @@ def test_bashforth() -> None: def test_tableau() -> None: - for order in [RKUltra.RK2, RKUltra.RK3, RKUltra.RK4]: - variant: RKUltra.RK2 | RKUltra.RK3 | RKUltra.RK4 + for order in [RKUltra.RK2, RKUltra.RK3, RKUltra.RK4, RKUltra.RK5]: + variant: RKUltra.RK2 | RKUltra.RK3 | RKUltra.RK4 | RKUltra.RK5 for variant in order: tab = variant.tableau() From 0a0935254f37a1ef36af7579ea8c0243511726f2 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 28 Sep 2025 16:31:08 -0700 Subject: [PATCH 13/59] Tweak RKUltra.adjust_steps() to better handle different tableaus --- skrample/sampling/functional.py | 48 ++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 8107c4d..0cff4d9 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -270,8 +270,32 @@ def tableau(self) -> "RKUltra.Tableau": def max_order() -> int: return 5 + def tableau(self, order: int | None = None) -> Tableau: + if order is None: + order = self.order + + if order >= 5: + return self.rk5.tableau() + elif order >= 4: + return self.rk4.tableau() + elif order >= 3: + return self.rk3.tableau() + elif order >= 2: + return self.rk2.tableau() 
+ else: # Euler / RK1 + return ( + ((0, ()),), + (1,), + ) + def adjust_steps(self, steps: int) -> int: - return math.ceil(steps / self.order) # since we skip a call on final step + stages = self.tableau()[0] + calls = len(stages) + + # Add back the skipped calls on penultimate T + adjusted = steps / calls + sum(abs(1 - f[0]) < 1e-8 for f in stages) / calls + + return max(round(adjusted), 1) @staticmethod def fractional_step( @@ -300,23 +324,11 @@ def step[T: common.Sample]( schedule: list[tuple[float, float]], rng: FunctionalSampler.RNG[T] | None = None, ) -> T: - if self.order >= 5: - tableau = self.rk5.tableau() - elif self.order >= 4: - tableau = self.rk4.tableau() - elif self.order >= 3: - tableau = self.rk3.tableau() - elif self.order >= 2: - tableau = self.rk2.tableau() - else: # Euler / RK1 - tableau = ( - ((0, ()),), - (1,), - ) - + stages, composite = self.tableau() k_terms: list[T] = [] - fractions = self.fractional_step(schedule, step, tuple(f[0] for f in tableau[0])) - for frac_sc, icoeffs in zip(fractions, (t[1] for t in tableau[0]), strict=True): + fractions = self.fractional_step(schedule, step, tuple(f[0] for f in stages)) + + for frac_sc, icoeffs in zip(fractions, (t[1] for t in stages), strict=True): if icoeffs: combined: T = common.euler( sample, @@ -333,7 +345,7 @@ def step[T: common.Sample]( return common.euler( sample, - math.sumprod(k_terms, tableau[1]), # type: ignore + math.sumprod(k_terms, composite), # type: ignore schedule[step][1], schedule[step + 1][1] if step + 1 < len(schedule) else 0, self.schedule.sigma_transform, From 9f7cab9db5a9200af0be5470dd3de91e47d47481 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 28 Sep 2025 16:42:10 -0700 Subject: [PATCH 14/59] RKUltra add `custom_tableau` field --- skrample/sampling/functional.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 0cff4d9..54dbf1b 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -261,17 +261,27 @@ def tableau(self) -> "RKUltra.Tableau": ) order: int = 2 + rk2: RK2 = RK2.Ralston + "2nd order methods" rk3: RK3 = RK3.Ralston + "3rd order methods" rk4: RK4 = RK4.Ralston + "4th order methods" rk5: RK5 = RK5.Nystrom + "5th order methods" + + custom_tableau: Tableau | None = None + "If set, will use this Butcher tableau instead of picking method based on `RKUltra.order`" @staticmethod def max_order() -> int: return 5 def tableau(self, order: int | None = None) -> Tableau: - if order is None: + if self.custom_tableau is not None: + return self.custom_tableau + elif order is None: order = self.order if order >= 5: From b692567014b5da84cb481ffef2b202915da961c0 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 28 Sep 2025 20:54:53 -0700 Subject: [PATCH 15/59] Add examples/diffusers/functional.py --- examples/diffusers/functional.py | 96 ++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100755 examples/diffusers/functional.py diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py new file mode 100755 index 0000000..6316fbc --- /dev/null +++ b/examples/diffusers/functional.py @@ -0,0 +1,96 @@ +#! 
/usr/bin/env python + +from typing import ClassVar + +import torch +from diffusers.modular_pipelines.components_manager import ComponentsManager +from diffusers.modular_pipelines.flux.denoise import FluxDenoiseStep, FluxLoopDenoiser +from diffusers.modular_pipelines.flux.modular_blocks import TEXT2IMAGE_BLOCKS +from diffusers.modular_pipelines.flux.modular_pipeline import FluxModularPipeline +from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks +from tqdm import tqdm + +import skrample.sampling.functional as sampling +import skrample.scheduling as scheduling +from skrample.common import predict_flow +from skrample.diffusers import SkrampleWrapperScheduler +from skrample.sampling.structured import Euler + +model_id = "black-forest-labs/FLUX.1-dev" + +blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) + +schedule = scheduling.FlowShift(scheduling.Linear(), shift=2) +sampler = sampling.RKUltra(schedule, order=4) +wrapper = SkrampleWrapperScheduler(sampler=Euler(), schedule=schedule, predictor=predict_flow, allow_dynamic=False) + + +class FunctionalDenoise(FluxDenoiseStep): + # Exclude the after_denoise block + block_classes: ClassVar[list[type[ModularPipelineBlocks]]] = [FluxLoopDenoiser] + block_names: ClassVar[list[str]] = ["denoiser"] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state["num_inference_steps"] = sampler.adjust_steps(block_state["num_inference_steps"]) + progress = tqdm(total=block_state["num_inference_steps"]) + + i = 0 + + def call_model(sample: torch.Tensor, timestep: float, sigma: float) -> torch.Tensor: + nonlocal i, components, block_state, progress + block_state["latents"] = sample + components, block_state = self.loop_step( + components, + block_state, # type: ignore + i=i, + t=sample.new_tensor([timestep] * len(sample)), + ) + return wrapper.predictor( + sample, + block_state["noise_pred"], # type: ignore + sigma, + schedule.sigma_transform, + ) + + def sample_callback(_: torch.Tensor) -> None: + nonlocal i + i += 1 + progress.update() + + block_state["latents"] = sampler.sample_model( + sample=block_state["latents"], + model=call_model, + steps=block_state["num_inference_steps"], + callback=sample_callback, + ) + + self.set_block_state(state, block_state) # type: ignore + return components, state # type: ignore + + +blocks.sub_blocks["denoise"] = FunctionalDenoise() + +cm = ComponentsManager() +cm.enable_auto_cpu_offload() +pipe = blocks.init_pipeline(components_manager=cm) +pipe.load_components(["text_encoder"], repo=model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) +pipe.load_components(["tokenizer"], repo=model_id, subfolder="tokenizer") +pipe.load_components(["text_encoder_2"], repo=model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16) +pipe.load_components(["tokenizer_2"], repo=model_id, subfolder="tokenizer_2") +pipe.load_components(["transformer"], repo=model_id, subfolder="transformer", torch_dtype=torch.bfloat16) +pipe.load_components(["vae"], repo=model_id, subfolder="vae", torch_dtype=torch.bfloat16) + +pipe.register_components(scheduler=wrapper) + + +pipe( # type: ignore + prompt="sharp, high dynamic range photograph of a kitten on a beach of rainbow pebbles", + generator=torch.Generator("cpu").manual_seed(42), + width=1024, + height=1024, + num_inference_steps=20, + guidance_scale=2.5, 
+).get("images")[0].save("diffusers_functional.png") From 805f44ab17f46f95a3ba39b74815c77bd30cbde4 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 28 Sep 2025 20:59:14 -0700 Subject: [PATCH 16/59] Min diffusers == 0.35 for examples/diffusers/functional.py --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2a32793..e82ebc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ scripts = ["skrample[all]", "matplotlib>=3.10.1"] test = [ "skrample[scripts]", "accelerate>=1.3", - "diffusers>=0.32", + "diffusers>=0.35", "protobuf>=5.29", "pyright>=1.1.400", "pytest-xdist>=3.6.1", From 355021d18df4e9f434662372ffdee27386fb025a Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 4 Oct 2025 19:38:26 -0700 Subject: [PATCH 17/59] Add FunctionalSampler.generate_model(), some cleanup --- examples/functional.py | 12 ++++----- skrample/common.py | 5 ++++ skrample/sampling/functional.py | 46 ++++++++++++++++++++++++++++----- skrample/sampling/interface.py | 7 +++-- skrample/sampling/structured.py | 5 ++-- skrample/scheduling.py | 2 +- 6 files changed, 57 insertions(+), 20 deletions(-) diff --git a/examples/functional.py b/examples/functional.py index 917af09..2872912 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -56,16 +56,14 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: p = conditioned + (cfg - 1) * (conditioned - unconditioned) return skrample.common.predict_epsilon(x, p, s, schedule.sigma_transform) - if isinstance(sampler, functional.FunctionalSinglestep): + if isinstance(sampler, functional.FunctionalHigher): steps = sampler.adjust_steps(steps) - sample = torch.randn([1, 4, 80, 80], generator=seed).to(dtype=dtype, device=device) - rng = noise.Random.from_inputs(sample.shape, seed) + rng = noise.Random.from_inputs((1, 4, 80, 80), seed) bar = tqdm(total=steps) - sample = sampler.sample_model( - sample, - call_model, - steps, + sample = sampler.generate_model( + model=call_model, + steps=steps, rng=lambda: rng.generate().to(dtype=dtype, device=device), callback=lambda _: bar.update(), ) diff --git a/skrample/common.py b/skrample/common.py index 4e81343..488c81e 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -99,6 +99,11 @@ def euler[T: Sample](sample: T, prediction: T, sigma: float, sigma_next: float, return sample * scale + prediction * delta # type: ignore +def merge_noise[T: Sample](sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_u, sigma_v = sigma_transform(sigma) + return sample * sigma_v + noise * sigma_u # type: ignore + + def safe_log(x: float) -> float: "Returns inf rather than throw an err" try: diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 54dbf1b..cfeb31d 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -8,21 +8,25 @@ import numpy as np from skrample import common, scheduling +from skrample.common import Sample, SigmaTransform @dataclasses.dataclass(frozen=True) class FunctionalSampler(ABC): - type SampleCallback[T: common.Sample] = Callable[[T], Any] + type SampleCallback[T: Sample] = Callable[[T], Any] "Return is ignored" - type SampleableModel[T: common.Sample] = Callable[[T, float, float], T] + type SampleableModel[T: Sample] = Callable[[T, float, float], T] "sample, timestep, sigma" - type RNG[T: common.Sample] = Callable[[], T] + type RNG[T: Sample] = Callable[[], T] "Distribution should match model, typically normal" schedule: 
scheduling.SkrampleSchedule + def merge_noise[T: Sample](self, sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: + return common.merge_noise(sample, noise, sigma, sigma_transform) + @abstractmethod - def sample_model[T: common.Sample]( + def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], @@ -31,6 +35,34 @@ def sample_model[T: common.Sample]( rng: RNG[T] | None = None, callback: SampleCallback | None = None, ) -> T: ... + """Runs the noisy sample through the model for a given range `include` of total steps. + Calls callback every step with sampled value.""" + + def generate_model[T: Sample]( + self, + model: SampleableModel[T], + rng: RNG[T], + steps: int, + include: slice = slice(None), + initial: T | None = None, + callback: SampleCallback | None = None, + ) -> T: + """Equivalent to `sample_model` except the noise is handled automatically + rather than being pre-added to the initial value""" + + if initial is None and include.start is None: # Short circuit for common case + sample: T = rng() + else: + sigmas = scheduling.schedule_lru(self.schedule, steps)[:, 1] + sample: T = self.merge_noise( + 0 if initial is None else initial, # type: ignore + rng(), + sigmas[include.start or 0].item(), + self.schedule.sigma_transform, + ) / self.merge_noise(0.0, 1.0, sigmas[0].item(), self.schedule.sigma_transform) + # Rescale sample by initial sigma. Mostly just to handle quirks with Scaled + + return self.sample_model(sample, model, steps, include, rng, callback) @dataclasses.dataclass(frozen=True) @@ -53,7 +85,7 @@ def adjust_steps(self, steps: int) -> int: @dataclasses.dataclass(frozen=True) class FunctionalSinglestep(FunctionalSampler): @abstractmethod - def step[T: common.Sample]( + def step[T: Sample]( self, sample: T, model: FunctionalSampler.SampleableModel[T], @@ -62,7 +94,7 @@ def step[T: common.Sample]( rng: FunctionalSampler.RNG[T] | None = None, ) -> T: ... 
- def sample_model[T: common.Sample]( + def sample_model[T: Sample]( self, sample: T, model: FunctionalSampler.SampleableModel[T], @@ -326,7 +358,7 @@ def fractional_step( ) return result - def step[T: common.Sample]( + def step[T: Sample]( self, sample: T, model: FunctionalSampler.SampleableModel[T], diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index 792208b..9a96d94 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -1,6 +1,6 @@ import dataclasses -from skrample import common +from skrample.common import Sample, SigmaTransform from skrample.sampling import functional, structured @@ -8,7 +8,10 @@ class StructuredFunctionalAdapter(functional.FunctionalSampler): sampler: structured.StructuredSampler - def sample_model[T: common.Sample]( + def merge_noise[T: Sample](self, sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: + return self.sampler.merge_noise(sample, noise, sigma, sigma_transform) + + def sample_model[T: Sample]( self, sample: T, model: functional.FunctionalSampler.SampleableModel[T], diff --git a/skrample/sampling/structured.py b/skrample/sampling/structured.py index bf98901..2d1b0d8 100644 --- a/skrample/sampling/structured.py +++ b/skrample/sampling/structured.py @@ -5,7 +5,7 @@ import numpy as np from numpy.typing import NDArray -from skrample.common import Sample, SigmaTransform, bashforth, euler, safe_log, softmax, spowf +from skrample.common import Sample, SigmaTransform, bashforth, euler, merge_noise, safe_log, softmax, spowf @dataclass(frozen=True) @@ -75,8 +75,7 @@ def scale_input[T: Sample](self, sample: T, sigma: float, sigma_transform: Sigma return sample def merge_noise[T: Sample](self, sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: - sigma_u, sigma_v = sigma_transform(sigma) - return sample * sigma_v + noise * sigma_u # type: ignore + return merge_noise(sample, noise, sigma, sigma_transform) def __call__[T: Sample]( self, diff --git a/skrample/scheduling.py b/skrample/scheduling.py index affc9fa..3672b6e 100644 --- a/skrample/scheduling.py +++ b/skrample/scheduling.py @@ -104,7 +104,7 @@ def sigmas_to_timesteps(self, sigmas: NDArray[np.float64]) -> NDArray[np.float64 def timesteps(self, steps: int) -> NDArray[np.float64]: # # https://arxiv.org/abs/2305.08891 Table 2 if self.uniform: - return np.linspace(self.base_timesteps - 1, 0, steps + 1, dtype=np.float64).round()[:-1] + return np.linspace(self.base_timesteps - 1, 0, steps, endpoint=False, dtype=np.float64).round() else: # They use a truncated ratio for ...reasons? 
return np.flip(np.arange(0, steps, dtype=np.float64) * (self.base_timesteps // steps)).round() From a2d5ebb04ac6ac2ba30abfd63543bf4451bf4764 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 6 Oct 2025 03:46:23 -0700 Subject: [PATCH 18/59] Basic adaptive samplers --- examples/diffusers/functional.py | 29 ++++-- examples/functional.py | 6 +- scripts/plot_skrample.py | 42 ++++----- skrample/sampling/functional.py | 152 ++++++++++++++++++++++++++++++- skrample/sampling/interface.py | 2 +- 5 files changed, 197 insertions(+), 34 deletions(-) diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py index 6316fbc..cb43416 100755 --- a/examples/diffusers/functional.py +++ b/examples/diffusers/functional.py @@ -10,19 +10,29 @@ from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks from tqdm import tqdm -import skrample.sampling.functional as sampling import skrample.scheduling as scheduling from skrample.common import predict_flow from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling.structured import Euler +from skrample.sampling import functional, structured +from skrample.sampling.interface import StructuredFunctionalAdapter model_id = "black-forest-labs/FLUX.1-dev" blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) schedule = scheduling.FlowShift(scheduling.Linear(), shift=2) -sampler = sampling.RKUltra(schedule, order=4) -wrapper = SkrampleWrapperScheduler(sampler=Euler(), schedule=schedule, predictor=predict_flow, allow_dynamic=False) +wrapper = SkrampleWrapperScheduler( + sampler=structured.Euler(), schedule=schedule, predictor=predict_flow, allow_dynamic=False +) + +# Equivalent to structured example +sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True)) +# Native functional example +sampler = functional.RKUltra(schedule, 4) +# Dynamic model calls +sampler = functional.FastHeun(schedule) +# Dynamic step sizes +sampler = functional.AdaptiveHeun(schedule) class FunctionalDenoise(FluxDenoiseStep): @@ -34,7 +44,8 @@ class FunctionalDenoise(FluxDenoiseStep): def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state["num_inference_steps"] = sampler.adjust_steps(block_state["num_inference_steps"]) + if isinstance(sampler, functional.FunctionalHigher): + block_state["num_inference_steps"] = sampler.adjust_steps(block_state["num_inference_steps"]) progress = tqdm(total=block_state["num_inference_steps"]) i = 0 @@ -55,10 +66,10 @@ def call_model(sample: torch.Tensor, timestep: float, sigma: float) -> torch.Ten schedule.sigma_transform, ) - def sample_callback(_: torch.Tensor) -> None: + def sample_callback(x: torch.Tensor, n: int, t: float, s: float) -> None: nonlocal i - i += 1 - progress.update() + progress.update(n + 1 - progress.n) + i = n + 1 block_state["latents"] = sampler.sample_model( sample=block_state["latents"], @@ -91,6 +102,6 @@ def sample_callback(_: torch.Tensor) -> None: generator=torch.Generator("cpu").manual_seed(42), width=1024, height=1024, - num_inference_steps=20, + num_inference_steps=25, guidance_scale=2.5, ).get("images")[0].save("diffusers_functional.png") diff --git a/examples/functional.py b/examples/functional.py index 2872912..49ebc31 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -27,6 +27,10 @@ sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True)) # Native 
functional example sampler = functional.RKUltra(schedule, 4) + # Dynamic model calls + sampler = functional.FastHeun(schedule) + # Dynamic step sizes + sampler = functional.AdaptiveHeun(schedule) tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer") text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained( @@ -65,7 +69,7 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: model=call_model, steps=steps, rng=lambda: rng.generate().to(dtype=dtype, device=device), - callback=lambda _: bar.update(), + callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), ) image: torch.Tensor = image_encoder.decode(sample / image_encoder.config.scaling_factor).sample[0] # type: ignore diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 5d7a5d0..a208e82 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -1,7 +1,7 @@ #! /usr/bin/env python import math -from argparse import ArgumentParser +from argparse import ArgumentParser, BooleanOptionalAction from collections.abc import Generator from dataclasses import replace from pathlib import Path @@ -70,10 +70,12 @@ def colors(hue_steps: int) -> Generator[list[float]]: "unipc": structured.UniPC(), "spc": structured.SPC(), "rk": functional.RKUltra(scheduling.Linear()), + "fheun": functional.FastHeun(scheduling.Linear()), + "aheun": functional.AdaptiveHeun(scheduling.Linear()), } for k, v in list(SAMPLERS.items()): if isinstance(v, structured.StructuredMultistep | functional.FunctionalHigher): - for o in range(1, v.max_order() + 1): + for o in range(v.min_order(), v.max_order() + 1): if o != v.order: SAMPLERS[k + str(o)] = replace(v, order=o) @@ -107,6 +109,7 @@ def colors(hue_steps: int) -> Generator[list[float]]: # Samplers parser_sampler = subparsers.add_parser("samplers") +parser_sampler.add_argument("--adjust", type=bool, default=True, action=BooleanOptionalAction) parser_sampler.add_argument("--curve", "-k", type=int, default=30) parser_sampler.add_argument("--transform", "-t", type=str, choices=list(TRANSFORMS.keys()), default="polar") parser_sampler.add_argument( @@ -158,7 +161,9 @@ def colors(hue_steps: int) -> Generator[list[float]]: schedule = scheduling.Linear(base_timesteps=10_000, custom_transform=TRANSFORMS[args.transform]) - def sample_model(sampler: structured.StructuredSampler | functional.FunctionalSampler, steps: int) -> list[float]: + def sample_model( + sampler: structured.StructuredSampler | functional.FunctionalSampler, steps: int + ) -> tuple[list[float], list[float]]: if isinstance(sampler, structured.StructuredSampler): sampler = StructuredFunctionalAdapter(schedule, sampler) else: @@ -166,8 +171,16 @@ def sample_model(sampler: structured.StructuredSampler | functional.FunctionalSa sample = 1.0 sampled_values = [sample] + sigmas = [0.0] - if isinstance(sampler, functional.FunctionalHigher) and False: + def callback(x: float, n: int, t: float, s: float) -> None: + nonlocal sampled_values, sigmas + sampled_values.append(x) + sigmas.insert(-1, s) + + if isinstance(sampler, functional.AdaptiveHeun) and args.adjust: + adjusted = schedule.base_timesteps + elif isinstance(sampler, functional.FunctionalHigher) and args.adjust: adjusted = sampler.adjust_steps(steps) else: adjusted = steps @@ -177,34 +190,21 @@ def sample_model(sampler: structured.StructuredSampler | functional.FunctionalSa model=lambda sample, timestep, sigma: sample + math.sin(sigma * args.curve), steps=adjusted, rng=random, - callback=lambda x: sampled_values.append(x), + callback=callback, ) 
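+ # The callback above records each intermediate sample and its sigma as the solver advances,
+ # so the plot can pair every sampled value with the sigma it was produced at.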
- # if isinstance(sampler, functional.FunctionalHigher): - # sampled_values = np.interp( - # np.linspace(0, 1, steps + 1), - # np.linspace(0, 1, len(sampled_values)), - # np.array(sampled_values), - # ).tolist() - - return sampled_values + return sigmas, sampled_values - plt.plot( - [*schedule.sigmas(schedule.base_timesteps), 0], - sample_model(structured.Euler(), schedule.base_timesteps), - label="Reference", - color=next(COLORS), - ) + plt.plot(*sample_model(structured.Euler(), schedule.base_timesteps), label="Reference", color=next(COLORS)) for sampler in [SAMPLERS[s] for s in args.sampler]: - sigmas = schedule.sigmas(args.steps) label = type(sampler).__name__ if ( isinstance(sampler, structured.StructuredMultistep | functional.FunctionalHigher) and sampler.order != type(sampler).order ): label += " " + str(sampler.order) - plt.plot([*sigmas, 0], sample_model(sampler, args.steps), label=label, color=next(COLORS), linestyle="--") + plt.plot(*sample_model(sampler, args.steps), label=label, color=next(COLORS), linestyle="--") elif args.command == "schedules": plt.xlabel("Step") diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index cfeb31d..5f4c819 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -13,7 +13,7 @@ @dataclasses.dataclass(frozen=True) class FunctionalSampler(ABC): - type SampleCallback[T: Sample] = Callable[[T], Any] + type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any] "Return is ignored" type SampleableModel[T: Sample] = Callable[[T, float, float], T] "sample, timestep, sigma" @@ -35,6 +35,7 @@ def sample_model[T: Sample]( rng: RNG[T] | None = None, callback: SampleCallback | None = None, ) -> T: ... + """Runs the noisy sample through the model for a given range `include` of total steps. 
Calls callback every step with sampled value.""" @@ -109,11 +110,27 @@ def sample_model[T: Sample]( sample = self.step(sample, model, n, schedule, rng) if callback: - callback(sample) + callback(sample, n, *schedule[n] if n < len(schedule) else (0, 0)) return sample +@dataclasses.dataclass(frozen=True) +class FunctionalAdaptive(FunctionalSampler): + type Evaluator[T: Sample] = Callable[[T, T], float] + + @staticmethod + def mse[T: Sample](a: T, b: T) -> float: + error: T = abs(a - b) ** 2 # type: ignore + if isinstance(error, float | int): + return error + else: + return error.mean().item() + + evaluator: Evaluator = mse + threshold: float = 1e-2 + + @dataclasses.dataclass(frozen=True) class RKUltra(FunctionalHigher, FunctionalSinglestep): "Implements almost every single method from https://en.wikipedia.org/wiki/List_of_Runge–Kutta_methods" # noqa: RUF002 @@ -392,3 +409,134 @@ def step[T: Sample]( schedule[step + 1][1] if step + 1 < len(schedule) else 0, self.schedule.sigma_transform, ) + + +@dataclasses.dataclass(frozen=True) +class FastHeun(FunctionalAdaptive, FunctionalSinglestep, FunctionalHigher): + order: int = 2 + + threshold: float = 5e-2 + + @staticmethod + def min_order() -> int: + return 2 + + @staticmethod + def max_order() -> int: + return 2 + + def adjust_steps(self, steps: int) -> int: + return round(steps * 0.75 + 0.25) + + def step[T: Sample]( + self, + sample: T, + model: FunctionalSampler.SampleableModel[T], + step: int, + schedule: list[tuple[float, float]], + rng: FunctionalSampler.RNG[T] | None = None, + ) -> T: + sigma = schedule[step][1] + sigma_next = schedule[step + 1][1] if step + 1 < len(schedule) else 0 + + sigma_u, sigma_v = self.schedule.sigma_transform(sigma) + sigma_u_next, sigma_v_next = self.schedule.sigma_transform(sigma_next) + + scale = sigma_u_next / sigma_u + dt = sigma_v_next - sigma_v * scale + + k1 = model(sample, *schedule[step]) + result: T = sample * scale + k1 * dt # type: ignore + + # Multiplying by step size here is kind of an asspull, but so is this whole solver so... 
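+ # Only spend the second model call on the Heun correction when the Euler update's
+ # relative change exceeds the dt-scaled threshold; otherwise keep the plain Euler result.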
+ if ( + step + 1 < len(schedule) + and self.evaluator(sample, result) / max(self.evaluator(0, result), 1e-16) > self.threshold * dt + ): + k2 = model(result, *schedule[step + 1]) + result: T = sample * scale + (k1 + k2) / 2 * dt # type: ignore + + return result + + +@dataclasses.dataclass(frozen=True) +class AdaptiveHeun(FunctionalAdaptive, FunctionalHigher): + order: int = 2 + + threshold: float = 1e-3 + + initial: float = 1 / 50 + maximum: float = 1 / 4 + adaption: float = 0.3 + + @staticmethod + def min_order() -> int: + return 2 + + @staticmethod + def max_order() -> int: + return 2 + + def adjust_steps(self, steps: int) -> int: + return steps + + def sample_model[T: Sample]( + self, + sample: T, + model: FunctionalSampler.SampleableModel[T], + steps: int, + include: slice = slice(None), + rng: FunctionalSampler.RNG[T] | None = None, + callback: FunctionalSampler.SampleCallback | None = None, + ) -> T: + epsilon: float = 1e-16 # lgtm + step_size: int = max(round(steps * self.initial), 1) + + schedule: list[tuple[float, float]] = self.schedule.schedule(steps).tolist() + + indices: list[int] = list(range(steps))[include] + step: int = indices[0] + + while step <= indices[-1]: + step_next = min(step + step_size, indices[-1] + 1) + + sigma = schedule[step][1] + sigma_next = schedule[step_next][1] if step_next < len(schedule) else 0 + + sigma_u, sigma_v = self.schedule.sigma_transform(sigma) + sigma_u_next, sigma_v_next = self.schedule.sigma_transform(sigma_next) + + scale = sigma_u_next / sigma_u + dt = sigma_v_next - sigma_v * scale + + k1 = model(sample, *schedule[step]) + euler_step: T = sample * scale + k1 * dt # type: ignore + + if step_next < len(schedule): + k2 = model(euler_step, *schedule[step_next]) + heun_step: T = sample * scale + (k1 + k2) / 2 * dt # type: ignore + + sample = heun_step + + # The schedule is *also* trying to predict the step size, so we need the next step size too + sigma_next2 = schedule[step_next + step_size][1] if step_next + step_size < len(schedule) else 0 + sigma_u_next2, sigma_v_next2 = self.schedule.sigma_transform(sigma_next2) + dt2 = sigma_v_next2 - sigma_v_next * (sigma_u_next2 / sigma_u_next) + + # Normalize against pure error + error = self.evaluator(euler_step, heun_step) / max(self.evaluator(0, heun_step), epsilon) + # Offset adjustment by dt2 / dt to account for non-linearity + # Basically if we want a 50% larger step but the next dt will already be 25% larger, + # we should only set a 20% larger step ie 1.5 / 1.25 + # Really this could be iterated to contrast dt2/dt and thresh/error until they're 100% matched but eh + adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / (dt2 / dt) + step_size = max(round(min(step_size * adjustment, steps * self.maximum)), 1) + else: + sample = euler_step + + if callback: + callback(sample, step, *schedule[step] if step < len(schedule) else (0, 0)) + + step = step_next + + return sample diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index 9a96d94..8291bc4 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -48,6 +48,6 @@ def sample_model[T: Sample]( sample = sksamples.final if callback: - callback(sample) + callback(sample, n, *schedule[n] if n < len(schedule) else (0, 0)) return sample From 8bcc08b3eb24bdf91d8a4643ea1c4c8190c6e12b Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Tue, 7 Oct 2025 23:41:16 -0700 Subject: [PATCH 19/59] Rework AdaptiveHeun -> RKMoire --- examples/diffusers/functional.py | 2 +- 
examples/functional.py | 2 +- scripts/plot_skrample.py | 6 +- skrample/sampling/functional.py | 182 +++++++++++++++++++++++-------- tests/miscellaneous.py | 13 ++- 5 files changed, 152 insertions(+), 53 deletions(-) diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py index cb43416..e8d56d7 100755 --- a/examples/diffusers/functional.py +++ b/examples/diffusers/functional.py @@ -32,7 +32,7 @@ # Dynamic model calls sampler = functional.FastHeun(schedule) # Dynamic step sizes -sampler = functional.AdaptiveHeun(schedule) +sampler = functional.RKMoire(schedule) class FunctionalDenoise(FluxDenoiseStep): diff --git a/examples/functional.py b/examples/functional.py index 49ebc31..3bb7152 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -30,7 +30,7 @@ # Dynamic model calls sampler = functional.FastHeun(schedule) # Dynamic step sizes - sampler = functional.AdaptiveHeun(schedule) + sampler = functional.RKMoire(schedule) tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer") text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained( diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index a208e82..9501090 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -69,9 +69,9 @@ def colors(hue_steps: int) -> Generator[list[float]]: "unip": structured.UniP(), "unipc": structured.UniPC(), "spc": structured.SPC(), - "rk": functional.RKUltra(scheduling.Linear()), + "rku": functional.RKUltra(scheduling.Linear()), + "rkm": functional.RKMoire(scheduling.Linear()), "fheun": functional.FastHeun(scheduling.Linear()), - "aheun": functional.AdaptiveHeun(scheduling.Linear()), } for k, v in list(SAMPLERS.items()): if isinstance(v, structured.StructuredMultistep | functional.FunctionalHigher): @@ -178,7 +178,7 @@ def callback(x: float, n: int, t: float, s: float) -> None: sampled_values.append(x) sigmas.insert(-1, s) - if isinstance(sampler, functional.AdaptiveHeun) and args.adjust: + if isinstance(sampler, functional.RKMoire) and args.adjust: adjusted = schedule.base_timesteps elif isinstance(sampler, functional.FunctionalHigher) and args.adjust: adjusted = sampler.adjust_steps(steps) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 5f4c819..270219c 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -11,6 +11,25 @@ from skrample.common import Sample, SigmaTransform +def fractional_step( + schedule: list[tuple[float, float]], + current: int, + idx: tuple[float, ...], +) -> tuple[tuple[float, float], ...]: + schedule_np = np.array([*schedule, (0, 0)]) + idx_np = np.array(idx) / len(schedule) + current / len(schedule) + scale = np.linspace(0, 1, len(schedule_np)) + + # TODO (beinszeii): better 2d interpolate + result = tuple( + zip( + (np.interp(idx_np, scale, schedule_np[:, 0])).tolist(), + (np.interp(idx_np, scale, schedule_np[:, 1])).tolist(), + ) + ) + return result + + @dataclasses.dataclass(frozen=True) class FunctionalSampler(ABC): type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any] @@ -356,25 +375,6 @@ def adjust_steps(self, steps: int) -> int: return max(round(adjusted), 1) - @staticmethod - def fractional_step( - schedule: list[tuple[float, float]], - current: int, - idx: tuple[float, ...], - ) -> tuple[tuple[float, float], ...]: - schedule_np = np.array([*schedule, (0, 0)]) - idx_np = np.array(idx) / len(schedule) + current / len(schedule) - scale = np.linspace(0, 1, len(schedule_np)) - - # TODO (beinszeii): better 
2d interpolate - result = tuple( - zip( - (np.interp(idx_np, scale, schedule_np[:, 0])).tolist(), - (np.interp(idx_np, scale, schedule_np[:, 1])).tolist(), - ) - ) - return result - def step[T: Sample]( self, sample: T, @@ -385,7 +385,7 @@ def step[T: Sample]( ) -> T: stages, composite = self.tableau() k_terms: list[T] = [] - fractions = self.fractional_step(schedule, step, tuple(f[0] for f in stages)) + fractions = fractional_step(schedule, step, tuple(f[0] for f in stages)) for frac_sc, icoeffs in zip(fractions, (t[1] for t in stages), strict=True): if icoeffs: @@ -460,7 +460,16 @@ def step[T: Sample]( @dataclasses.dataclass(frozen=True) -class AdaptiveHeun(FunctionalAdaptive, FunctionalHigher): +class RKMoire(FunctionalAdaptive, FunctionalHigher): + type ExtendedTableau = tuple[ + tuple[ + tuple[float, tuple[float, ...]], + ..., + ], + tuple[float, ...], + tuple[float, ...], + ] + order: int = 2 threshold: float = 1e-3 @@ -469,17 +478,74 @@ class AdaptiveHeun(FunctionalAdaptive, FunctionalHigher): maximum: float = 1 / 4 adaption: float = 0.3 + rescale_init: bool = True + "Scale initial by a tableau's model evals." + + @enum.unique + class RKE2(enum.StrEnum): + Heun = enum.auto() + # Fehlberg = enum.auto() + + def tableau(self) -> "RKMoire.ExtendedTableau": + match self: + case self.Heun: + return ( + ( + (0, ()), + (1, (1,)), + ), + (1 / 2, 1 / 2), + (1, 0), + ) + + @enum.unique + class RKE5(enum.StrEnum): + Fehlberg = enum.auto() + # CashKarp = enum.auto() + # DormandPrince = enum.auto() + + def tableau(self) -> "RKMoire.ExtendedTableau": + match self: + case self.Fehlberg: + return ( + ( + (0, ()), + (1 / 4, (1 / 4,)), + (3 / 8, (3 / 32, 9 / 32)), + (12 / 13, (1932 / 2197, -7200 / 2197, 7296 / 2197)), + (1, (439 / 216, -8, 3680 / 513, -845 / 4104)), + (1 / 2, (-8 / 27, 2, -3544 / 2565, 1859 / 4104, -11 / 40)), + ), + (16 / 135, 0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55), + (25 / 216, 0, 1408 / 2565, 2197 / 4104, -1 / 5, 0), + ) + + custom_tableau: ExtendedTableau | None = None + rke2: RKE2 = RKE2.Heun + rke5: RKE5 = RKE5.Fehlberg + @staticmethod def min_order() -> int: return 2 @staticmethod def max_order() -> int: - return 2 + return 5 def adjust_steps(self, steps: int) -> int: return steps + def tableau(self, order: int | None = None) -> ExtendedTableau: + if self.custom_tableau is not None: + return self.custom_tableau + elif order is None: + order = self.order + + if order >= 5: + return self.rke5.tableau() + else: + return self.rke2.tableau() + def sample_model[T: Sample]( self, sample: T, @@ -489,8 +555,14 @@ def sample_model[T: Sample]( rng: FunctionalSampler.RNG[T] | None = None, callback: FunctionalSampler.SampleCallback | None = None, ) -> T: + stages, comp_high, comp_low = self.tableau() + + initial = self.initial + if self.rescale_init: + initial *= len(stages) / 2 # Heun is base so / 2 + + step_size: int = max(round(steps * initial), 1) epsilon: float = 1e-16 # lgtm - step_size: int = max(round(steps * self.initial), 1) schedule: list[tuple[float, float]] = self.schedule.schedule(steps).tolist() @@ -500,42 +572,66 @@ def sample_model[T: Sample]( while step <= indices[-1]: step_next = min(step + step_size, indices[-1] + 1) - sigma = schedule[step][1] - sigma_next = schedule[step_next][1] if step_next < len(schedule) else 0 - - sigma_u, sigma_v = self.schedule.sigma_transform(sigma) - sigma_u_next, sigma_v_next = self.schedule.sigma_transform(sigma_next) + k_terms: list[T] = [] + fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in stages)) + + 
for frac_sc, icoeffs in zip(fractions, (t[1] for t in stages), strict=True): + if icoeffs: + combined: T = common.euler( + sample, + math.sumprod(k_terms, icoeffs) / math.fsum(icoeffs), # type: ignore + schedule[step][1], + frac_sc[1], + self.schedule.sigma_transform, + ) + else: + combined = sample - scale = sigma_u_next / sigma_u - dt = sigma_v_next - sigma_v * scale + # Do not call model on timestep = 0 or sigma = 0 + k_terms.append(model(combined, *frac_sc) if not any(abs(v) < 1e-8 for v in frac_sc) else combined) - k1 = model(sample, *schedule[step]) - euler_step: T = sample * scale + k1 * dt # type: ignore + sample_high: T = common.euler( + sample, + math.sumprod(k_terms, comp_high), # type: ignore + schedule[step][1], + schedule[step_next][1] if step_next < len(schedule) else 0, + self.schedule.sigma_transform, + ) if step_next < len(schedule): - k2 = model(euler_step, *schedule[step_next]) - heun_step: T = sample * scale + (k1 + k2) / 2 * dt # type: ignore - - sample = heun_step - - # The schedule is *also* trying to predict the step size, so we need the next step size too + sigma = schedule[step][1] + sigma_next = schedule[step_next][1] if step_next < len(schedule) else 0 sigma_next2 = schedule[step_next + step_size][1] if step_next + step_size < len(schedule) else 0 + + sigma_u, sigma_v = self.schedule.sigma_transform(sigma) + sigma_u_next, sigma_v_next = self.schedule.sigma_transform(sigma_next) sigma_u_next2, sigma_v_next2 = self.schedule.sigma_transform(sigma_next2) + + dt = sigma_v_next - sigma_v * (sigma_u_next / sigma_u) dt2 = sigma_v_next2 - sigma_v_next * (sigma_u_next2 / sigma_u_next) + dt1x2 = dt2 / dt + + sample_low: T = common.euler( + sample, + math.sumprod(k_terms, comp_low), # type: ignore + schedule[step][1], + schedule[step_next][1] if step_next < len(schedule) else 0, + self.schedule.sigma_transform, + ) # Normalize against pure error - error = self.evaluator(euler_step, heun_step) / max(self.evaluator(0, heun_step), epsilon) + error = self.evaluator(sample_low, sample_high) / max(self.evaluator(0, sample_high), epsilon) # Offset adjustment by dt2 / dt to account for non-linearity # Basically if we want a 50% larger step but the next dt will already be 25% larger, # we should only set a 20% larger step ie 1.5 / 1.25 # Really this could be iterated to contrast dt2/dt and thresh/error until they're 100% matched but eh - adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / (dt2 / dt) + adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / dt1x2 step_size = max(round(min(step_size * adjustment, steps * self.maximum)), 1) - else: - sample = euler_step + + sample = sample_high if callback: - callback(sample, step, *schedule[step] if step < len(schedule) else (0, 0)) + callback(sample, step_next - 1, *schedule[step] if step < len(schedule) else (0, 0)) step = step_next diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 7c1effd..fb2790a 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -8,7 +8,7 @@ from skrample.common import MergeStrategy, bashforth, sigma_complement, sigmoid, softmax, spowf from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling.functional import RKUltra +from skrample.sampling.functional import RKMoire, RKUltra from skrample.sampling.interface import StructuredFunctionalAdapter from skrample.sampling.structured import ( DPM, @@ -222,18 +222,21 @@ def test_bashforth() -> None: def test_tableau() -> None: - for order in [RKUltra.RK2, RKUltra.RK3, 
RKUltra.RK4, RKUltra.RK5]: - variant: RKUltra.RK2 | RKUltra.RK3 | RKUltra.RK4 | RKUltra.RK5 + for order in [RKUltra.RK2, RKUltra.RK3, RKUltra.RK4, RKUltra.RK5, RKMoire.RKE2, RKMoire.RKE5]: for variant in order: - tab = variant.tableau() + tab: RKUltra.Tableau | RKMoire.ExtendedTableau = variant.tableau() for stage in tab[0]: stage_err = abs(stage[0] - math.fsum(stage[1])) assert stage_err < 1e-15, (variant, stage) - final_err = abs(1 - math.fsum(variant.tableau()[1])) + final_err = abs(1 - math.fsum(tab[1])) assert final_err < 1e-15, variant + if len(tab) > 2: + low_err = abs(1 - math.fsum(tab[2])) + assert low_err < 1e-15, variant + def test_sigmoid() -> None: items = spowf(torch.linspace(-2, 2, 9, dtype=torch.float64), 2) From 16f27605900c3effc28b2504da7cdf3a6a7988ba Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 11 Oct 2025 17:56:51 -0700 Subject: [PATCH 20/59] Move tableaux definitions into separate module --- skrample/sampling/functional.py | 282 ++++---------------------------- skrample/sampling/tableaux.py | 249 ++++++++++++++++++++++++++++ tests/miscellaneous.py | 8 +- 3 files changed, 282 insertions(+), 257 deletions(-) create mode 100644 skrample/sampling/tableaux.py diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 270219c..1eefe13 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -1,5 +1,4 @@ import dataclasses -import enum import math from abc import ABC, abstractmethod from collections.abc import Callable @@ -10,6 +9,8 @@ from skrample import common, scheduling from skrample.common import Sample, SigmaTransform +from . import tableaux + def fractional_step( schedule: list[tuple[float, float]], @@ -154,217 +155,34 @@ def mse[T: Sample](a: T, b: T) -> float: class RKUltra(FunctionalHigher, FunctionalSinglestep): "Implements almost every single method from https://en.wikipedia.org/wiki/List_of_Runge–Kutta_methods" # noqa: RUF002 - type Tableau = tuple[ - tuple[ - tuple[float, tuple[float, ...]], - ..., - ], - tuple[float, ...], - ] - - @enum.unique - class RK2(enum.StrEnum): - Heun = enum.auto() - Mid = enum.auto() - Ralston = enum.auto() - - def tableau(self) -> "RKUltra.Tableau": - match self: - case self.Heun: - return ( - ( - (0, ()), - (1, (1,)), - ), - (1 / 2, 1 / 2), - ) - case self.Mid: - return ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - ), - (0, 1), - ) - case self.Ralston: - return ( - ( - (0, ()), - (2 / 3, (2 / 3,)), - ), - (1 / 4, 3 / 4), - ) - - @enum.unique - class RK3(enum.StrEnum): - Kutta = enum.auto() - Heun = enum.auto() - Ralston = enum.auto() - Wray = enum.auto() - SSPRK3 = enum.auto() - - def tableau(self) -> "RKUltra.Tableau": - match self: - case self.Kutta: - return ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - (1, (-1, 2)), - ), - (1 / 6, 2 / 3, 1 / 6), - ) - case self.Heun: - return ( - ( - (0, ()), - (1 / 3, (1 / 3,)), - (2 / 3, (0, 2 / 3)), - ), - (1 / 4, 0, 3 / 4), - ) - case self.Ralston: - return ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - (3 / 4, (0, 3 / 4)), - ), - (2 / 9, 1 / 3, 4 / 9), - ) - case self.Wray: - return ( - ( - (0, ()), - (8 / 15, (8 / 15,)), - (2 / 3, (1 / 4, 5 / 12)), - ), - (1 / 4, 0, 3 / 4), - ) - case self.SSPRK3: - return ( - ( - (0, ()), - (1, (1,)), - (1 / 2, (1 / 4, 1 / 4)), - ), - (1 / 6, 1 / 6, 2 / 3), - ) - - @enum.unique - class RK4(enum.StrEnum): - Classic = enum.auto() - Eighth = enum.auto() - Ralston = enum.auto() - - def tableau(self) -> "RKUltra.Tableau": - match self: - case self.Classic: - return ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - (1 / 2, (0, 
1 / 2)), - (1, (0, 0, 1)), - ), - (1 / 6, 1 / 3, 1 / 3, 1 / 6), - ) - case self.Eighth: - return ( - ( - (0, ()), - (1 / 3, (1 / 3,)), - (2 / 3, (-1 / 3, 1)), - (1, (1, -1, 1)), - ), - (1 / 8, 3 / 8, 3 / 8, 1 / 8), - ) - case self.Ralston: - sq5: float = math.sqrt(5) - return ( - ( - (0, ()), - (2 / 5, (2 / 5,)), - ( - (14 - 3 * sq5) / 16, - ( - (-2889 + 1428 * sq5) / 1024, - (3785 - 1620 * sq5) / 1024, - ), - ), - ( - 1, - ( - (-3365 + 2094 * sq5) / 6040, - (-975 - 3046 * sq5) / 2552, - (467040 + 203968 * sq5) / 240845, - ), - ), - ), - ( - (263 + 24 * sq5) / 1812, - (125 - 1000 * sq5) / 3828, - (3426304 + 1661952 * sq5) / 5924787, - (30 - 4 * sq5) / 123, - ), - ) - - @enum.unique - class RK5(enum.StrEnum): - Nystrom = enum.auto() - - def tableau(self) -> "RKUltra.Tableau": - match self: - case self.Nystrom: - return ( - ( - (0, ()), - (1 / 3, (1 / 3,)), - (2 / 5, (4 / 25, 6 / 25)), - (1, (1 / 4, -3, 15 / 4)), - (2 / 3, (2 / 27, 10 / 9, -50 / 81, 8 / 81)), - (4 / 5, (2 / 25, 12 / 25, 2 / 15, 8 / 75, 0)), - ), - (23 / 192, 0, 125 / 192, 0, -27 / 64, 125 / 192), - ) - order: int = 2 - rk2: RK2 = RK2.Ralston - "2nd order methods" - rk3: RK3 = RK3.Ralston - "3rd order methods" - rk4: RK4 = RK4.Ralston - "4th order methods" - rk5: RK5 = RK5.Nystrom - "5th order methods" + providers: tuple[tableaux.TableauProvider | tableaux.ExtendedTableauProvider, ...] = ( + tableaux.RK2.Ralston, + tableaux.RK3.Ralston, + tableaux.RK4.Ralston, + tableaux.RK5.Nystrom, + ) + """Providers for a given order, starting from 2. + Order 1 is always the Euler method.""" - custom_tableau: Tableau | None = None + custom_tableau: tableaux.Tableau | tableaux.ExtendedTableau | None = None "If set, will use this Butcher tableau instead of picking method based on `RKUltra.order`" @staticmethod def max_order() -> int: return 5 - def tableau(self, order: int | None = None) -> Tableau: + def tableau(self, order: int | None = None) -> tableaux.Tableau: if self.custom_tableau is not None: - return self.custom_tableau + return self.custom_tableau[:2] elif order is None: order = self.order - if order >= 5: - return self.rk5.tableau() - elif order >= 4: - return self.rk4.tableau() - elif order >= 3: - return self.rk3.tableau() - elif order >= 2: - return self.rk2.tableau() + if order >= 2 and (morder := len(self.providers)): + return self.providers[min(order - 2, morder - 1)].tableau()[:2] else: # Euler / RK1 - return ( - ((0, ()),), - (1,), - ) + return tableaux.RK1 def adjust_steps(self, steps: int) -> int: stages = self.tableau()[0] @@ -461,17 +279,15 @@ def step[T: Sample]( @dataclasses.dataclass(frozen=True) class RKMoire(FunctionalAdaptive, FunctionalHigher): - type ExtendedTableau = tuple[ - tuple[ - tuple[float, tuple[float, ...]], - ..., - ], - tuple[float, ...], - tuple[float, ...], - ] - order: int = 2 + providers: tuple[tableaux.ExtendedTableauProvider, ...] = ( + tableaux.RKE2.Heun, + tableaux.RKE2.Heun, + tableaux.RKE2.Heun, + tableaux.RKE5.Fehlberg, + ) + threshold: float = 1e-3 initial: float = 1 / 50 @@ -481,48 +297,8 @@ class RKMoire(FunctionalAdaptive, FunctionalHigher): rescale_init: bool = True "Scale initial by a tableau's model evals." 
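+ # e.g. the six-stage Fehlberg pair then starts with a 3x larger initial step than the two-stage Heun pair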
- @enum.unique - class RKE2(enum.StrEnum): - Heun = enum.auto() - # Fehlberg = enum.auto() - - def tableau(self) -> "RKMoire.ExtendedTableau": - match self: - case self.Heun: - return ( - ( - (0, ()), - (1, (1,)), - ), - (1 / 2, 1 / 2), - (1, 0), - ) - - @enum.unique - class RKE5(enum.StrEnum): - Fehlberg = enum.auto() - # CashKarp = enum.auto() - # DormandPrince = enum.auto() - - def tableau(self) -> "RKMoire.ExtendedTableau": - match self: - case self.Fehlberg: - return ( - ( - (0, ()), - (1 / 4, (1 / 4,)), - (3 / 8, (3 / 32, 9 / 32)), - (12 / 13, (1932 / 2197, -7200 / 2197, 7296 / 2197)), - (1, (439 / 216, -8, 3680 / 513, -845 / 4104)), - (1 / 2, (-8 / 27, 2, -3544 / 2565, 1859 / 4104, -11 / 40)), - ), - (16 / 135, 0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55), - (25 / 216, 0, 1408 / 2565, 2197 / 4104, -1 / 5, 0), - ) - - custom_tableau: ExtendedTableau | None = None - rke2: RKE2 = RKE2.Heun - rke5: RKE5 = RKE5.Fehlberg + custom_tableau: tableaux.ExtendedTableau | None = None + "If set, will use this Butcher tableau instead of picking method based on `RKUltra.order`" @staticmethod def min_order() -> int: @@ -535,16 +311,16 @@ def max_order() -> int: def adjust_steps(self, steps: int) -> int: return steps - def tableau(self, order: int | None = None) -> ExtendedTableau: + def tableau(self, order: int | None = None) -> tableaux.ExtendedTableau: if self.custom_tableau is not None: return self.custom_tableau elif order is None: order = self.order - if order >= 5: - return self.rke5.tableau() + if order >= 2 and (morder := len(self.providers)): + return self.providers[min(order - 2, morder - 1)].tableau() else: - return self.rke2.tableau() + return tableaux.RKE2.Heun.tableau() def sample_model[T: Sample]( self, diff --git a/skrample/sampling/tableaux.py b/skrample/sampling/tableaux.py new file mode 100644 index 0000000..983bf03 --- /dev/null +++ b/skrample/sampling/tableaux.py @@ -0,0 +1,249 @@ +import abc +import enum +import math +from typing import Protocol + +type TabNode = tuple[float, tuple[float, ...]] +type TabWeight = tuple[float, ...] 
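+# A TabNode is (c_i, (a_i1, ..., a_i,i-1)): the stage's node fraction and its coupling coefficients.
+# A TabWeight is the row of b_i weights that combines the stage derivatives into an update;
+# an ExtendedTableau carries a second weight row for the embedded lower-order estimate used for error control.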
+ + +type Tableau = tuple[ + tuple[TabNode, ...], + TabWeight, +] + +type ExtendedTableau = tuple[ + tuple[TabNode, ...], + TabWeight, + TabWeight, +] + + +class TableauProvider(Protocol): + @abc.abstractmethod + def tableau(self) -> Tableau: + raise NotImplementedError + + +class ExtendedTableauProvider(Protocol): + @abc.abstractmethod + def tableau(self) -> ExtendedTableau: + raise NotImplementedError + + +RK1: Tableau = ( + ((0, ()),), + (1,), +) +"Euler method" + + +@enum.unique +class RK2(enum.StrEnum): + Heun = enum.auto() + Mid = enum.auto() + Ralston = enum.auto() + + def tableau(self) -> Tableau: + match self: + case self.Heun: + return ( + ( + (0, ()), + (1, (1,)), + ), + (1 / 2, 1 / 2), + ) + case self.Mid: + return ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + ), + (0, 1), + ) + case self.Ralston: + return ( + ( + (0, ()), + (2 / 3, (2 / 3,)), + ), + (1 / 4, 3 / 4), + ) + + +@enum.unique +class RK3(enum.StrEnum): + Kutta = enum.auto() + Heun = enum.auto() + Ralston = enum.auto() + Wray = enum.auto() + SSPRK3 = enum.auto() + + def tableau(self) -> Tableau: + match self: + case self.Kutta: + return ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (1, (-1, 2)), + ), + (1 / 6, 2 / 3, 1 / 6), + ) + case self.Heun: + return ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 3, (0, 2 / 3)), + ), + (1 / 4, 0, 3 / 4), + ) + case self.Ralston: + return ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (3 / 4, (0, 3 / 4)), + ), + (2 / 9, 1 / 3, 4 / 9), + ) + case self.Wray: + return ( + ( + (0, ()), + (8 / 15, (8 / 15,)), + (2 / 3, (1 / 4, 5 / 12)), + ), + (1 / 4, 0, 3 / 4), + ) + case self.SSPRK3: + return ( + ( + (0, ()), + (1, (1,)), + (1 / 2, (1 / 4, 1 / 4)), + ), + (1 / 6, 1 / 6, 2 / 3), + ) + + +@enum.unique +class RK4(enum.StrEnum): + Classic = enum.auto() + Eighth = enum.auto() + Ralston = enum.auto() + + def tableau(self) -> Tableau: + match self: + case self.Classic: + return ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (1 / 2, (0, 1 / 2)), + (1, (0, 0, 1)), + ), + (1 / 6, 1 / 3, 1 / 3, 1 / 6), + ) + case self.Eighth: + return ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 3, (-1 / 3, 1)), + (1, (1, -1, 1)), + ), + (1 / 8, 3 / 8, 3 / 8, 1 / 8), + ) + case self.Ralston: + sq5: float = math.sqrt(5) + return ( + ( + (0, ()), + (2 / 5, (2 / 5,)), + ( + (14 - 3 * sq5) / 16, + ( + (-2889 + 1428 * sq5) / 1024, + (3785 - 1620 * sq5) / 1024, + ), + ), + ( + 1, + ( + (-3365 + 2094 * sq5) / 6040, + (-975 - 3046 * sq5) / 2552, + (467040 + 203968 * sq5) / 240845, + ), + ), + ), + ( + (263 + 24 * sq5) / 1812, + (125 - 1000 * sq5) / 3828, + (3426304 + 1661952 * sq5) / 5924787, + (30 - 4 * sq5) / 123, + ), + ) + + +@enum.unique +class RK5(enum.StrEnum): + Nystrom = enum.auto() + + def tableau(self) -> Tableau: + match self: + case self.Nystrom: + return ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 5, (4 / 25, 6 / 25)), + (1, (1 / 4, -3, 15 / 4)), + (2 / 3, (2 / 27, 10 / 9, -50 / 81, 8 / 81)), + (4 / 5, (2 / 25, 12 / 25, 2 / 15, 8 / 75, 0)), + ), + (23 / 192, 0, 125 / 192, 0, -27 / 64, 125 / 192), + ) + + +@enum.unique +class RKE2(enum.StrEnum): + Heun = enum.auto() + # Fehlberg = enum.auto() + + def tableau(self) -> ExtendedTableau: + match self: + case self.Heun: + return ( + ( + (0, ()), + (1, (1,)), + ), + (1 / 2, 1 / 2), + (1, 0), + ) + + +@enum.unique +class RKE5(enum.StrEnum): + Fehlberg = enum.auto() + # CashKarp = enum.auto() + # DormandPrince = enum.auto() + + def tableau(self) -> ExtendedTableau: + match self: + case self.Fehlberg: + return ( + ( + (0, ()), + (1 / 4, (1 / 4,)), + (3 / 8, (3 / 32, 9 / 32)), + (12 / 13, 
(1932 / 2197, -7200 / 2197, 7296 / 2197)), + (1, (439 / 216, -8, 3680 / 513, -845 / 4104)), + (1 / 2, (-8 / 27, 2, -3544 / 2565, 1859 / 4104, -11 / 40)), + ), + (16 / 135, 0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55), + (25 / 216, 0, 1408 / 2565, 2197 / 4104, -1 / 5, 0), + ) diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index fb2790a..facff5a 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -8,7 +8,7 @@ from skrample.common import MergeStrategy, bashforth, sigma_complement, sigmoid, softmax, spowf from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling.functional import RKMoire, RKUltra +from skrample.sampling import tableaux from skrample.sampling.interface import StructuredFunctionalAdapter from skrample.sampling.structured import ( DPM, @@ -222,9 +222,9 @@ def test_bashforth() -> None: def test_tableau() -> None: - for order in [RKUltra.RK2, RKUltra.RK3, RKUltra.RK4, RKUltra.RK5, RKMoire.RKE2, RKMoire.RKE5]: - for variant in order: - tab: RKUltra.Tableau | RKMoire.ExtendedTableau = variant.tableau() + for provider in [tableaux.RK2, tableaux.RK3, tableaux.RK4, tableaux.RK5, tableaux.RKE2, tableaux.RKE5]: + for variant in provider: + tab: tableaux.Tableau | tableaux.ExtendedTableau = variant.tableau() for stage in tab[0]: stage_err = abs(stage[0] - math.fsum(stage[1])) From 3c841d0f9a8d82ac6789cb3d851136d77a5c4f77 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 11 Oct 2025 18:29:19 -0700 Subject: [PATCH 21/59] Deduplicate tableau sampling code --- skrample/sampling/functional.py | 114 +++++++++++++++----------------- 1 file changed, 52 insertions(+), 62 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 1eefe13..c38a69e 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -31,6 +31,46 @@ def fractional_step( return result +def step_tableau[T: Sample]( + tableau: tableaux.Tableau | tableaux.ExtendedTableau, + sample: T, + model: "FunctionalSampler.SampleableModel[T]", + step: int, + schedule: list[tuple[float, float]], + transform: SigmaTransform, + step_size: int = 1, +) -> tuple[T, ...]: + nodes, weights = tableau[0], tableau[1:] + k_terms: list[T] = [] + fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in nodes)) + + for frac_sc, icoeffs in zip(fractions, (t[1] for t in nodes), strict=True): + if icoeffs: + combined: T = common.euler( + sample, + math.sumprod(k_terms, icoeffs) / math.fsum(icoeffs), # type: ignore + schedule[step][1], + frac_sc[1], + transform, + ) + else: + combined = sample + + # Do not call model on timestep = 0 or sigma = 0 + k_terms.append(model(combined, *frac_sc) if not any(abs(v) < 1e-8 for v in frac_sc) else combined) + + return tuple( + common.euler( + sample, + math.sumprod(k_terms, w), # type: ignore + schedule[step][1], + schedule[step + step_size][1] if step + step_size < len(schedule) else 0, + transform, + ) + for w in weights + ) + + @dataclasses.dataclass(frozen=True) class FunctionalSampler(ABC): type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any] @@ -201,32 +241,7 @@ def step[T: Sample]( schedule: list[tuple[float, float]], rng: FunctionalSampler.RNG[T] | None = None, ) -> T: - stages, composite = self.tableau() - k_terms: list[T] = [] - fractions = fractional_step(schedule, step, tuple(f[0] for f in stages)) - - for frac_sc, icoeffs in zip(fractions, (t[1] for t in stages), strict=True): - if icoeffs: - combined: T = common.euler( - sample, - math.sumprod(k_terms, 
icoeffs) / math.fsum(icoeffs), # type: ignore - schedule[step][1], - frac_sc[1], - self.schedule.sigma_transform, - ) - else: - combined = sample - - # Do not call model on timestep = 0 or sigma = 0 - k_terms.append(model(combined, *frac_sc) if not any(abs(v) < 1e-8 for v in frac_sc) else combined) - - return common.euler( - sample, - math.sumprod(k_terms, composite), # type: ignore - schedule[step][1], - schedule[step + 1][1] if step + 1 < len(schedule) else 0, - self.schedule.sigma_transform, - ) + return step_tableau(self.tableau(), sample, model, step, schedule, self.schedule.sigma_transform)[0] @dataclasses.dataclass(frozen=True) @@ -331,11 +346,11 @@ def sample_model[T: Sample]( rng: FunctionalSampler.RNG[T] | None = None, callback: FunctionalSampler.SampleCallback | None = None, ) -> T: - stages, comp_high, comp_low = self.tableau() + tab = self.tableau() initial = self.initial if self.rescale_init: - initial *= len(stages) / 2 # Heun is base so / 2 + initial *= len(tab[0]) / 2 # Heun is base so / 2 step_size: int = max(round(steps * initial), 1) epsilon: float = 1e-16 # lgtm @@ -348,33 +363,11 @@ def sample_model[T: Sample]( while step <= indices[-1]: step_next = min(step + step_size, indices[-1] + 1) - k_terms: list[T] = [] - fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in stages)) - - for frac_sc, icoeffs in zip(fractions, (t[1] for t in stages), strict=True): - if icoeffs: - combined: T = common.euler( - sample, - math.sumprod(k_terms, icoeffs) / math.fsum(icoeffs), # type: ignore - schedule[step][1], - frac_sc[1], - self.schedule.sigma_transform, - ) - else: - combined = sample - - # Do not call model on timestep = 0 or sigma = 0 - k_terms.append(model(combined, *frac_sc) if not any(abs(v) < 1e-8 for v in frac_sc) else combined) - - sample_high: T = common.euler( - sample, - math.sumprod(k_terms, comp_high), # type: ignore - schedule[step][1], - schedule[step_next][1] if step_next < len(schedule) else 0, - self.schedule.sigma_transform, - ) - if step_next < len(schedule): + sample_high, sample_low = step_tableau( + tab, sample, model, step, schedule, self.schedule.sigma_transform, step_size + ) + sigma = schedule[step][1] sigma_next = schedule[step_next][1] if step_next < len(schedule) else 0 sigma_next2 = schedule[step_next + step_size][1] if step_next + step_size < len(schedule) else 0 @@ -387,14 +380,6 @@ def sample_model[T: Sample]( dt2 = sigma_v_next2 - sigma_v_next * (sigma_u_next2 / sigma_u_next) dt1x2 = dt2 / dt - sample_low: T = common.euler( - sample, - math.sumprod(k_terms, comp_low), # type: ignore - schedule[step][1], - schedule[step_next][1] if step_next < len(schedule) else 0, - self.schedule.sigma_transform, - ) - # Normalize against pure error error = self.evaluator(sample_low, sample_high) / max(self.evaluator(0, sample_high), epsilon) # Offset adjustment by dt2 / dt to account for non-linearity @@ -404,6 +389,11 @@ def sample_model[T: Sample]( adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / dt1x2 step_size = max(round(min(step_size * adjustment, steps * self.maximum)), 1) + else: # Save the extra euler call since the 2nd weight isn't used + sample_high = step_tableau( + tab[:2], sample, model, step, schedule, self.schedule.sigma_transform, step_size + )[0] + sample = sample_high if callback: From 8ebf7be570c8bf5115f232bd2197b9a492ff2f55 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 11 Oct 2025 18:30:05 -0700 Subject: [PATCH 22/59] Fix a docstring --- skrample/sampling/functional.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index c38a69e..8dbe356 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -313,7 +313,7 @@ class RKMoire(FunctionalAdaptive, FunctionalHigher): "Scale initial by a tableau's model evals." custom_tableau: tableaux.ExtendedTableau | None = None - "If set, will use this Butcher tableau instead of picking method based on `RKUltra.order`" + "If set, will use this Butcher tableau instead of picking method based on `RKMoire.order`" @staticmethod def min_order() -> int: From 835dd995ad51c52968729f275ce4923d54039ad0 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 11 Oct 2025 18:56:20 -0700 Subject: [PATCH 23/59] Make tableau validations standalone && more comprehensive --- skrample/sampling/tableaux.py | 14 ++++++++++++++ tests/miscellaneous.py | 14 ++------------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/skrample/sampling/tableaux.py b/skrample/sampling/tableaux.py index 983bf03..29cd490 100644 --- a/skrample/sampling/tableaux.py +++ b/skrample/sampling/tableaux.py @@ -19,6 +19,20 @@ ] +def validate_tableau(tab: Tableau | ExtendedTableau, tolerance: float = 1e-15) -> None | IndexError | ValueError: + for index, node in enumerate(tab[0]): + if index != (node_len := len(node[1])): + return IndexError(f"{index=}, {node_len=}, {node=}") + if tolerance < (node_err := abs(node[0] - math.fsum(node[1]))): + return ValueError(f"{tolerance=}, {node_err=}, {node=}") + + for weight in tab[1:]: + if (node_count := len(tab[0])) != (weight_len := len(weight)): + return IndexError(f"{node_count=}, {weight_len=}, {weight=}") + if tolerance < (weight_err := abs(1 - math.fsum(weight))): + return ValueError(f"{tolerance=}, {weight_err=}, {weight=}") + + class TableauProvider(Protocol): @abc.abstractmethod def tableau(self) -> Tableau: diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index facff5a..39f4b31 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -224,18 +224,8 @@ def test_bashforth() -> None: def test_tableau() -> None: for provider in [tableaux.RK2, tableaux.RK3, tableaux.RK4, tableaux.RK5, tableaux.RKE2, tableaux.RKE5]: for variant in provider: - tab: tableaux.Tableau | tableaux.ExtendedTableau = variant.tableau() - - for stage in tab[0]: - stage_err = abs(stage[0] - math.fsum(stage[1])) - assert stage_err < 1e-15, (variant, stage) - - final_err = abs(1 - math.fsum(tab[1])) - assert final_err < 1e-15, variant - - if len(tab) > 2: - low_err = abs(1 - math.fsum(tab[2])) - assert low_err < 1e-15, variant + if error := tableaux.validate_tableau(variant.tableau()): + raise error def test_sigmoid() -> None: From fbc69909c324d77e38fd40c64bd4fa17e59e0f75 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 11 Oct 2025 20:03:22 -0700 Subject: [PATCH 24/59] Add generic RK2 and RK3 tableau generators --- skrample/sampling/tableaux.py | 104 ++++++++++++++-------------------- tests/miscellaneous.py | 43 +++++++++++++- 2 files changed, 85 insertions(+), 62 deletions(-) diff --git a/skrample/sampling/tableaux.py b/skrample/sampling/tableaux.py index 29cd490..2ad635a 100644 --- a/skrample/sampling/tableaux.py +++ b/skrample/sampling/tableaux.py @@ -33,6 +33,41 @@ def validate_tableau(tab: Tableau | ExtendedTableau, tolerance: float = 1e-15) - return ValueError(f"{tolerance=}, {weight_err=}, {weight=}") +def rk2_tableau(alpha: float) -> Tableau: + "Create a generic 2nd order Tableau from a given alpha value." 
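+    # Butcher array for the generic explicit 2nd-order family:
+    #   c = (0, alpha), a21 = alpha, b = (1 - 1/(2*alpha), 1/(2*alpha)),
+    # so rk2_tableau(1) is Heun, rk2_tableau(1 / 2) the midpoint rule, and rk2_tableau(2 / 3) Ralston.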
+ alpha_w = 1 / (2 * alpha) + return ( + ( + (0.0, ()), + (alpha, (alpha,)), + ), + (1 - alpha_w, alpha_w), + ) + + +def rk3_tableau(alpha: float, beta: float) -> Tableau: + "Create a generic 3rd order Tableau from a given alpha and beta values." + + return ( + ( + (0.0, ()), + (alpha, (alpha,)), + ( + beta, + ( + beta / alpha * ((beta - 3 * alpha * (1 - alpha)) / (3 * alpha - 2)), + -beta / alpha * ((beta - alpha) / (3 * alpha - 2)), + ), + ), + ), + ( + 1 - (3 * alpha + 3 * beta - 2) / (6 * alpha * beta), + (3 * beta - 2) / (6 * alpha * (beta - alpha)), + (2 - 3 * alpha) / (6 * beta * (beta - alpha)), + ), + ) + + class TableauProvider(Protocol): @abc.abstractmethod def tableau(self) -> Tableau: @@ -61,29 +96,11 @@ class RK2(enum.StrEnum): def tableau(self) -> Tableau: match self: case self.Heun: - return ( - ( - (0, ()), - (1, (1,)), - ), - (1 / 2, 1 / 2), - ) + return rk2_tableau(1) case self.Mid: - return ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - ), - (0, 1), - ) + return rk2_tableau(1 / 2) case self.Ralston: - return ( - ( - (0, ()), - (2 / 3, (2 / 3,)), - ), - (1 / 4, 3 / 4), - ) + return rk2_tableau(2 / 3) @enum.unique @@ -97,50 +114,15 @@ class RK3(enum.StrEnum): def tableau(self) -> Tableau: match self: case self.Kutta: - return ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - (1, (-1, 2)), - ), - (1 / 6, 2 / 3, 1 / 6), - ) + return rk3_tableau(1 / 2, 1) case self.Heun: - return ( - ( - (0, ()), - (1 / 3, (1 / 3,)), - (2 / 3, (0, 2 / 3)), - ), - (1 / 4, 0, 3 / 4), - ) + return rk3_tableau(1 / 3, 2 / 3) case self.Ralston: - return ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - (3 / 4, (0, 3 / 4)), - ), - (2 / 9, 1 / 3, 4 / 9), - ) + return rk3_tableau(1 / 2, 3 / 4) case self.Wray: - return ( - ( - (0, ()), - (8 / 15, (8 / 15,)), - (2 / 3, (1 / 4, 5 / 12)), - ), - (1 / 4, 0, 3 / 4), - ) + return rk3_tableau(8 / 15, 2 / 3) case self.SSPRK3: - return ( - ( - (0, ()), - (1, (1,)), - (1 / 2, (1 / 4, 1 / 4)), - ), - (1 / 6, 1 / 6, 2 / 3), - ) + return rk3_tableau(1, 1 / 2) @enum.unique diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 39f4b31..f5e51a1 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -221,13 +221,54 @@ def test_bashforth() -> None: assert np.allclose(coeffs, np.array(bashforth(n + 1)), atol=1e-12, rtol=1e-12) -def test_tableau() -> None: +def test_tableau_providers() -> None: for provider in [tableaux.RK2, tableaux.RK3, tableaux.RK4, tableaux.RK5, tableaux.RKE2, tableaux.RKE5]: for variant in provider: if error := tableaux.validate_tableau(variant.tableau()): raise error +def flat_tableau(t: tuple[float | tuple[float | tuple[float | tuple[float, ...], ...], ...], ...]) -> tuple[float, ...]: + return tuple(z for y in (flat_tableau(x) if isinstance(x, tuple) else (x,) for x in t) for z in y) + + +def tableau_distance(a: tableaux.Tableau, b: tableaux.Tableau) -> float: + return abs(np.subtract(flat_tableau(a), flat_tableau(b))).max().item() + + +def test_rk2_tableau() -> None: + assert ( + tableau_distance( + ( # Ralston + ( + (0.0, ()), + (2 / 3, (2 / 3,)), + ), + (1 / 4, 3 / 4), + ), + tableaux.rk2_tableau(2 / 3), + ) + < 1e-20 + ) + + +def test_rk3_tableau() -> None: + assert ( + tableau_distance( + ( # Wray + ( + (0.0, ()), + (8 / 15, (8 / 15,)), + (2 / 3, (1 / 4, 5 / 12)), + ), + (1 / 4, 0.0, 3 / 4), + ), + tableaux.rk3_tableau(8 / 15, 2 / 3), + ) + < 1e-15 + ) + + def test_sigmoid() -> None: items = spowf(torch.linspace(-2, 2, 9, dtype=torch.float64), 2) a = torch.sigmoid(items) From 6d4337306e64906488f8b82d54f9a44881e240f2 Mon Sep 17 00:00:00 2001 
From: Beinsezii Date: Sun, 12 Oct 2025 13:42:37 -0700 Subject: [PATCH 25/59] Make step_tableau() epsilon customizable --- skrample/sampling/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 8dbe356..4d6094b 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -39,6 +39,7 @@ def step_tableau[T: Sample]( schedule: list[tuple[float, float]], transform: SigmaTransform, step_size: int = 1, + epsilon: float = 1e-8 ) -> tuple[T, ...]: nodes, weights = tableau[0], tableau[1:] k_terms: list[T] = [] fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in nodes)) @@ -57,7 +58,7 @@ def step_tableau[T: Sample]( combined = sample # Do not call model on timestep = 0 or sigma = 0 - k_terms.append(model(combined, *frac_sc) if not any(abs(v) < 1e-8 for v in frac_sc) else combined) + k_terms.append(model(combined, *frac_sc) if not any(abs(v) < epsilon for v in frac_sc) else combined) return tuple( common.euler( From 28d0cd5d428f9f70e13130f1a2612201997ac728 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 12 Oct 2025 16:13:17 -0700 Subject: [PATCH 26/59] Streamline higher order tableau logic --- scripts/plot_skrample.py | 2 +- skrample/sampling/functional.py | 53 ++++++++++++++++----------------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 9501090..61884d4 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -75,7 +75,7 @@ def colors(hue_steps: int) -> Generator[list[float]]: } for k, v in list(SAMPLERS.items()): if isinstance(v, structured.StructuredMultistep | functional.FunctionalHigher): - for o in range(v.min_order(), v.max_order() + 1): + for o in range(v.min_order(), min(v.max_order() + 1, 9)): if o != v.order: SAMPLERS[k + str(o)] = replace(v, order=o) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 4d6094b..a93b125 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -39,7 +39,7 @@ def step_tableau[T: Sample]( schedule: list[tuple[float, float]], transform: SigmaTransform, step_size: int = 1, - epsilon: float = 1e-8 + epsilon: float = 1e-8, ) -> tuple[T, ...]: nodes, weights = tableau[0], tableau[1:] k_terms: list[T] = [] fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in nodes)) @@ -189,7 +189,9 @@ def mse[T: Sample](a: T, b: T) -> float: return error.mean().item() evaluator: Evaluator = mse + "Function used to measure error of two samples" threshold: float = 1e-2 + "Target error threshold for a given evaluation" @dataclasses.dataclass(frozen=True) @@ -198,30 +200,27 @@ class RKUltra(FunctionalHigher, FunctionalSinglestep): order: int = 2 - providers: tuple[tableaux.TableauProvider | tableaux.ExtendedTableauProvider, ...] = ( - tableaux.RK2.Ralston, - tableaux.RK3.Ralston, - tableaux.RK4.Ralston, - tableaux.RK5.Nystrom, + providers: dict[int, tableaux.TableauProvider | tableaux.ExtendedTableauProvider] = dataclasses.field( + default_factory=lambda: { + 2: tableaux.RK2.Ralston, + 3: tableaux.RK3.Ralston, + 4: tableaux.RK4.Ralston, + 5: tableaux.RK5.Nystrom, + } ) """Providers for a given order, starting from 2.
Order 1 is always the Euler method.""" - custom_tableau: tableaux.Tableau | tableaux.ExtendedTableau | None = None - "If set, will use this Butcher tableau instead of picking method based on `RKUltra.order`" - @staticmethod def max_order() -> int: - return 5 + return 99 def tableau(self, order: int | None = None) -> tableaux.Tableau: - if self.custom_tableau is not None: - return self.custom_tableau[:2] - elif order is None: + if order is None: order = self.order - if order >= 2 and (morder := len(self.providers)): - return self.providers[min(order - 2, morder - 1)].tableau()[:2] + if order >= 2 and (morder := max(o for o in self.providers.keys() if o <= order)): + return self.providers[morder].tableau()[:2] else: # Euler / RK1 return tableaux.RK1 @@ -297,40 +296,40 @@ def step[T: Sample]( class RKMoire(FunctionalAdaptive, FunctionalHigher): order: int = 2 - providers: tuple[tableaux.ExtendedTableauProvider, ...] = ( - tableaux.RKE2.Heun, - tableaux.RKE2.Heun, - tableaux.RKE2.Heun, - tableaux.RKE5.Fehlberg, + providers: dict[int, tableaux.ExtendedTableauProvider] = dataclasses.field( + default_factory=lambda: { + 2: tableaux.RKE2.Heun, + 5: tableaux.RKE5.Fehlberg, + } ) + """Providers for a given order, starting from 2. + Falls back to RKE2.Heun""" threshold: float = 1e-3 initial: float = 1 / 50 + "Percent of schedule to take as an initial step." maximum: float = 1 / 4 + "Percent of schedule to take as a maximum step." adaption: float = 0.3 + "How fast to adjust step size in relation to error" rescale_init: bool = True "Scale initial by a tableau's model evals." - custom_tableau: tableaux.ExtendedTableau | None = None - "If set, will use this Butcher tableau instead of picking method based on `RKMoire.order`" - @staticmethod def min_order() -> int: return 2 @staticmethod def max_order() -> int: - return 5 + return 99 def adjust_steps(self, steps: int) -> int: return steps def tableau(self, order: int | None = None) -> tableaux.ExtendedTableau: - if self.custom_tableau is not None: - return self.custom_tableau - elif order is None: + if order is None: order = self.order if order >= 2 and (morder := len(self.providers)): From 31306452663eb86c964c65a52914238b79842a65 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 12 Oct 2025 16:16:10 -0700 Subject: [PATCH 27/59] Fix RKMoire tableau provider fetch --- skrample/sampling/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index a93b125..ec1feeb 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -332,8 +332,8 @@ def tableau(self, order: int | None = None) -> tableaux.ExtendedTableau: if order is None: order = self.order - if order >= 2 and (morder := len(self.providers)): - return self.providers[min(order - 2, morder - 1)].tableau() + if order >= 2 and (morder := max(o for o in self.providers.keys() if o <= order)): + return self.providers[morder].tableau() else: return tableaux.RKE2.Heun.tableau() From eace1ad256c00798e6304a944617dffc6b4ec1fe Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 12 Oct 2025 16:58:56 -0700 Subject: [PATCH 28/59] Simplify tableaux providers --- skrample/sampling/functional.py | 4 +- skrample/sampling/tableaux.py | 251 +++++++++++++++----------------- 2 files changed, 120 insertions(+), 135 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index ec1feeb..dbaa5d9 100644 --- a/skrample/sampling/functional.py +++ 
b/skrample/sampling/functional.py @@ -200,7 +200,7 @@ class RKUltra(FunctionalHigher, FunctionalSinglestep): order: int = 2 - providers: dict[int, tableaux.TableauProvider | tableaux.ExtendedTableauProvider] = dataclasses.field( + providers: dict[int, tableaux.TableauProvider[tableaux.Tableau | tableaux.ExtendedTableau]] = dataclasses.field( default_factory=lambda: { 2: tableaux.RK2.Ralston, 3: tableaux.RK3.Ralston, @@ -296,7 +296,7 @@ def step[T: Sample]( class RKMoire(FunctionalAdaptive, FunctionalHigher): order: int = 2 - providers: dict[int, tableaux.ExtendedTableauProvider] = dataclasses.field( + providers: dict[int, tableaux.TableauProvider[tableaux.ExtendedTableau]] = dataclasses.field( default_factory=lambda: { 2: tableaux.RKE2.Heun, 5: tableaux.RKE5.Fehlberg, diff --git a/skrample/sampling/tableaux.py b/skrample/sampling/tableaux.py index 2ad635a..2643ffa 100644 --- a/skrample/sampling/tableaux.py +++ b/skrample/sampling/tableaux.py @@ -1,4 +1,5 @@ import abc +import dataclasses import enum import math from typing import Protocol @@ -68,15 +69,9 @@ def rk3_tableau(alpha: float, beta: float) -> Tableau: ) -class TableauProvider(Protocol): +class TableauProvider[T: Tableau | ExtendedTableau](Protocol): @abc.abstractmethod - def tableau(self) -> Tableau: - raise NotImplementedError - - -class ExtendedTableauProvider(Protocol): - @abc.abstractmethod - def tableau(self) -> ExtendedTableau: + def tableau(self) -> T: raise NotImplementedError @@ -87,159 +82,149 @@ def tableau(self) -> ExtendedTableau: "Euler method" +@dataclasses.dataclass(frozen=True) +class CustomTableau[T: Tableau | ExtendedTableau](TableauProvider[T]): + custom: T + + def tableau(self) -> T: + return self.custom + + +@dataclasses.dataclass(frozen=True) +class RK2Custom(TableauProvider): + alpha: float = 1.0 + + def tableau(self) -> Tableau: + return rk2_tableau(self.alpha) + + +@dataclasses.dataclass(frozen=True) +class RK3Custom(TableauProvider): + alpha: float = 1 / 2 + beta: float = 1.0 + + def tableau(self) -> Tableau: + return rk3_tableau(self.alpha, self.beta) + + @enum.unique -class RK2(enum.StrEnum): - Heun = enum.auto() - Mid = enum.auto() - Ralston = enum.auto() +class RK2(enum.Enum): + Heun = rk2_tableau(1) + Mid = rk2_tableau(1 / 2) + Ralston = rk2_tableau(2 / 3) def tableau(self) -> Tableau: - match self: - case self.Heun: - return rk2_tableau(1) - case self.Mid: - return rk2_tableau(1 / 2) - case self.Ralston: - return rk2_tableau(2 / 3) + return self.value @enum.unique -class RK3(enum.StrEnum): - Kutta = enum.auto() - Heun = enum.auto() - Ralston = enum.auto() - Wray = enum.auto() - SSPRK3 = enum.auto() +class RK3(enum.Enum): + Kutta = rk3_tableau(1 / 2, 1) + Heun = rk3_tableau(1 / 3, 2 / 3) + Ralston = rk3_tableau(1 / 2, 3 / 4) + Wray = rk3_tableau(8 / 15, 2 / 3) + SSPRK3 = rk3_tableau(1, 1 / 2) def tableau(self) -> Tableau: - match self: - case self.Kutta: - return rk3_tableau(1 / 2, 1) - case self.Heun: - return rk3_tableau(1 / 3, 2 / 3) - case self.Ralston: - return rk3_tableau(1 / 2, 3 / 4) - case self.Wray: - return rk3_tableau(8 / 15, 2 / 3) - case self.SSPRK3: - return rk3_tableau(1, 1 / 2) + return self.value @enum.unique -class RK4(enum.StrEnum): - Classic = enum.auto() - Eighth = enum.auto() - Ralston = enum.auto() +class RK4(enum.Enum): + Classic = ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (1 / 2, (0, 1 / 2)), + (1, (0, 0, 1)), + ), + (1 / 6, 1 / 3, 1 / 3, 1 / 6), + ) + Eighth = ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 3, (-1 / 3, 1)), + (1, (1, -1, 1)), + ), + (1 / 8, 3 / 8, 3 / 8, 1 / 
8), + ) + Ralston = ( + ( + (0, ()), + (2 / 5, (2 / 5,)), + ((14 - 3 * math.sqrt(5)) / 16, ((-2889 + 1428 * math.sqrt(5)) / 1024, (3785 - 1620 * math.sqrt(5)) / 1024)), + ( + 1, + ( + (-3365 + 2094 * math.sqrt(5)) / 6040, + (-975 - 3046 * math.sqrt(5)) / 2552, + (467040 + 203968 * math.sqrt(5)) / 240845, + ), + ), + ), + ( + (263 + 24 * math.sqrt(5)) / 1812, + (125 - 1000 * math.sqrt(5)) / 3828, + (3426304 + 1661952 * math.sqrt(5)) / 5924787, + (30 - 4 * math.sqrt(5)) / 123, + ), + ) def tableau(self) -> Tableau: - match self: - case self.Classic: - return ( - ( - (0, ()), - (1 / 2, (1 / 2,)), - (1 / 2, (0, 1 / 2)), - (1, (0, 0, 1)), - ), - (1 / 6, 1 / 3, 1 / 3, 1 / 6), - ) - case self.Eighth: - return ( - ( - (0, ()), - (1 / 3, (1 / 3,)), - (2 / 3, (-1 / 3, 1)), - (1, (1, -1, 1)), - ), - (1 / 8, 3 / 8, 3 / 8, 1 / 8), - ) - case self.Ralston: - sq5: float = math.sqrt(5) - return ( - ( - (0, ()), - (2 / 5, (2 / 5,)), - ( - (14 - 3 * sq5) / 16, - ( - (-2889 + 1428 * sq5) / 1024, - (3785 - 1620 * sq5) / 1024, - ), - ), - ( - 1, - ( - (-3365 + 2094 * sq5) / 6040, - (-975 - 3046 * sq5) / 2552, - (467040 + 203968 * sq5) / 240845, - ), - ), - ), - ( - (263 + 24 * sq5) / 1812, - (125 - 1000 * sq5) / 3828, - (3426304 + 1661952 * sq5) / 5924787, - (30 - 4 * sq5) / 123, - ), - ) + return self.value @enum.unique -class RK5(enum.StrEnum): - Nystrom = enum.auto() +class RK5(enum.Enum): + Nystrom = ( + ( + (0, ()), + (1 / 3, (1 / 3,)), + (2 / 5, (4 / 25, 6 / 25)), + (1, (1 / 4, -3, 15 / 4)), + (2 / 3, (2 / 27, 10 / 9, -50 / 81, 8 / 81)), + (4 / 5, (2 / 25, 12 / 25, 2 / 15, 8 / 75, 0)), + ), + (23 / 192, 0, 125 / 192, 0, -27 / 64, 125 / 192), + ) def tableau(self) -> Tableau: - match self: - case self.Nystrom: - return ( - ( - (0, ()), - (1 / 3, (1 / 3,)), - (2 / 5, (4 / 25, 6 / 25)), - (1, (1 / 4, -3, 15 / 4)), - (2 / 3, (2 / 27, 10 / 9, -50 / 81, 8 / 81)), - (4 / 5, (2 / 25, 12 / 25, 2 / 15, 8 / 75, 0)), - ), - (23 / 192, 0, 125 / 192, 0, -27 / 64, 125 / 192), - ) + return self.value @enum.unique -class RKE2(enum.StrEnum): - Heun = enum.auto() +class RKE2(enum.Enum): + Heun = ( + ( + (0, ()), + (1, (1,)), + ), + (1 / 2, 1 / 2), + (1, 0), + ) # Fehlberg = enum.auto() def tableau(self) -> ExtendedTableau: - match self: - case self.Heun: - return ( - ( - (0, ()), - (1, (1,)), - ), - (1 / 2, 1 / 2), - (1, 0), - ) + return self.value @enum.unique -class RKE5(enum.StrEnum): - Fehlberg = enum.auto() +class RKE5(enum.Enum): + Fehlberg = ( + ( + (0, ()), + (1 / 4, (1 / 4,)), + (3 / 8, (3 / 32, 9 / 32)), + (12 / 13, (1932 / 2197, -7200 / 2197, 7296 / 2197)), + (1, (439 / 216, -8, 3680 / 513, -845 / 4104)), + (1 / 2, (-8 / 27, 2, -3544 / 2565, 1859 / 4104, -11 / 40)), + ), + (16 / 135, 0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55), + (25 / 216, 0, 1408 / 2565, 2197 / 4104, -1 / 5, 0), + ) # CashKarp = enum.auto() # DormandPrince = enum.auto() def tableau(self) -> ExtendedTableau: - match self: - case self.Fehlberg: - return ( - ( - (0, ()), - (1 / 4, (1 / 4,)), - (3 / 8, (3 / 32, 9 / 32)), - (12 / 13, (1932 / 2197, -7200 / 2197, 7296 / 2197)), - (1, (439 / 216, -8, 3680 / 513, -845 / 4104)), - (1 / 2, (-8 / 27, 2, -3544 / 2565, 1859 / 4104, -11 / 40)), - ), - (16 / 135, 0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55), - (25 / 216, 0, 1408 / 2565, 2197 / 4104, -1 / 5, 0), - ) + return self.value From 00bd28bf3ce8e40648d9f464585f91061dd01927 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 12 Oct 2025 17:00:44 -0700 Subject: [PATCH 29/59] clear whitespace --- skrample/sampling/tableaux.py | 6 +----- 1 file 
changed, 1 insertion(+), 5 deletions(-) diff --git a/skrample/sampling/tableaux.py b/skrample/sampling/tableaux.py index 2643ffa..9ac0211 100644 --- a/skrample/sampling/tableaux.py +++ b/skrample/sampling/tableaux.py @@ -7,12 +7,10 @@ type TabNode = tuple[float, tuple[float, ...]] type TabWeight = tuple[float, ...] - type Tableau = tuple[ tuple[TabNode, ...], TabWeight, ] - type ExtendedTableau = tuple[ tuple[TabNode, ...], TabWeight, @@ -36,19 +34,17 @@ def validate_tableau(tab: Tableau | ExtendedTableau, tolerance: float = 1e-15) - def rk2_tableau(alpha: float) -> Tableau: "Create a generic 2nd order Tableau from a given alpha value." - alpha_w = 1 / (2 * alpha) return ( ( (0.0, ()), (alpha, (alpha,)), ), - (1 - alpha_w, alpha_w), + (1 - 1 / (2 * alpha), 1 / (2 * alpha)), ) def rk3_tableau(alpha: float, beta: float) -> Tableau: "Create a generic 3rd order Tableau from a given alpha and beta values." - return ( ( (0.0, ()), From 9dca150b6809940cbc7198e09699385098db43a6 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 12 Oct 2025 17:40:20 -0700 Subject: [PATCH 30/59] Fix `providers` fields not being accessible at a class level --- skrample/common.py | 3 +++ skrample/sampling/functional.py | 23 +++++++++++++---------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/skrample/common.py b/skrample/common.py index 488c81e..7668d1e 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -3,6 +3,7 @@ from collections.abc import Callable from functools import lru_cache from itertools import repeat +from types import MappingProxyType from typing import TYPE_CHECKING import numpy as np @@ -23,6 +24,8 @@ type Predictor[S: Sample] = Callable[[S, S, float, SigmaTransform], S] "sample, output, sigma, sigma_transform" +type DictOrProxy[T, U] = MappingProxyType[T, U] | dict[T, U] # Mapping does not implement __or__ + @enum.unique class MergeStrategy(enum.StrEnum): # str for easy UI options diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index dbaa5d9..245f33f 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -2,12 +2,13 @@ import math from abc import ABC, abstractmethod from collections.abc import Callable +from types import MappingProxyType from typing import Any import numpy as np from skrample import common, scheduling -from skrample.common import Sample, SigmaTransform +from skrample.common import DictOrProxy, Sample, SigmaTransform from . import tableaux @@ -200,13 +201,15 @@ class RKUltra(FunctionalHigher, FunctionalSinglestep): order: int = 2 - providers: dict[int, tableaux.TableauProvider[tableaux.Tableau | tableaux.ExtendedTableau]] = dataclasses.field( - default_factory=lambda: { - 2: tableaux.RK2.Ralston, - 3: tableaux.RK3.Ralston, - 4: tableaux.RK4.Ralston, - 5: tableaux.RK5.Nystrom, - } + providers: DictOrProxy[int, tableaux.TableauProvider[tableaux.Tableau | tableaux.ExtendedTableau]] = ( + MappingProxyType( + { + 2: tableaux.RK2.Ralston, + 3: tableaux.RK3.Ralston, + 4: tableaux.RK4.Ralston, + 5: tableaux.RK5.Nystrom, + } + ) ) """Providers for a given order, starting from 2. 
Order 1 is always the Euler method.""" @@ -296,8 +299,8 @@ def step[T: Sample]( class RKMoire(FunctionalAdaptive, FunctionalHigher): order: int = 2 - providers: dict[int, tableaux.TableauProvider[tableaux.ExtendedTableau]] = dataclasses.field( - default_factory=lambda: { + providers: DictOrProxy[int, tableaux.TableauProvider[tableaux.ExtendedTableau]] = MappingProxyType( + { 2: tableaux.RKE2.Heun, 5: tableaux.RKE5.Fehlberg, } From b4dafdb5fcabef2d801af55f0aed4e99e6e940ac Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 12 Oct 2025 18:09:53 -0700 Subject: [PATCH 31/59] Add RKMoire.rescale_max --- skrample/sampling/functional.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 245f33f..a4a24db 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -319,6 +319,8 @@ class RKMoire(FunctionalAdaptive, FunctionalHigher): rescale_init: bool = True "Scale initial by a tableau's model evals." + rescale_max: bool = False + "Scale maximum by a tableau's model evals." @staticmethod def min_order() -> int: @@ -352,8 +354,11 @@ def sample_model[T: Sample]( tab = self.tableau() initial = self.initial + maximum = self.maximum if self.rescale_init: initial *= len(tab[0]) / 2 # Heun is base so / 2 + if self.rescale_max: + maximum *= len(tab[0]) / 2 # Heun is base so / 2 step_size: int = max(round(steps * initial), 1) epsilon: float = 1e-16 # lgtm @@ -390,7 +395,7 @@ def sample_model[T: Sample]( # we should only set a 20% larger step ie 1.5 / 1.25 # Really this could be iterated to contrast dt2/dt and thresh/error until they're 100% matched but eh adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / dt1x2 - step_size = max(round(min(step_size * adjustment, steps * self.maximum)), 1) + step_size = max(round(min(step_size * adjustment, steps * maximum)), 1) else: # Save the extra euler call since the 2nd weight isn't used sample_high = step_tableau( From 530e58227e1d7178446cb7c3904e10fc5cb84902 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 12 Oct 2025 19:10:44 -0700 Subject: [PATCH 32/59] Deduplicate lots of scheduling / euler code --- skrample/common.py | 36 ++++++++++++-- skrample/sampling/functional.py | 85 +++++++++++++-------------------- skrample/sampling/interface.py | 8 ++-- 3 files changed, 71 insertions(+), 58 deletions(-) diff --git a/skrample/common.py b/skrample/common.py index 7668d1e..fa24811 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -1,6 +1,6 @@ import enum import math -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import lru_cache from itertools import repeat from types import MappingProxyType @@ -25,6 +25,13 @@ "sample, output, sigma, sigma_transform" type DictOrProxy[T, U] = MappingProxyType[T, U] | dict[T, U] # Mapping does not implement __or__ +"Simple union type for a possibly immutable dictionary" + +type FloatSchedule = Sequence[tuple[float, float]] +"Sequence of timestep, sigma" + +type RNG[T: Sample] = Callable[[], T] +"Distribution should match model, typically normal" @enum.unique @@ -93,13 +100,36 @@ def predict_flow[T: Sample](sample: T, output: T, sigma: float, sigma_transform: return sample - sigma * output # type: ignore -def euler[T: Sample](sample: T, prediction: T, sigma: float, sigma_next: float, sigma_transform: SigmaTransform) -> T: +def scaled_delta(sigma: float, sigma_next: float, sigma_transform: SigmaTransform) -> tuple[float, float]: + 
"Returns delta (h) and scale factor to perform the euler method." sigma_u, sigma_v = sigma_transform(sigma) sigma_u_next, sigma_v_next = sigma_transform(sigma_next) scale = sigma_u_next / sigma_u delta = sigma_v_next - sigma_v * scale # aka `h` or `dt` - return sample * scale + prediction * delta # type: ignore + return delta, scale + + +def euler[T: Sample](sample: T, prediction: T, sigma: float, sigma_next: float, sigma_transform: SigmaTransform) -> T: + "Perform the euler method using scaled_delta" + # Returns delta, scale so prediction is first + return math.sumprod((prediction, sample), scaled_delta(sigma, sigma_next, sigma_transform)) # type: ignore + + +def scaled_delta_step( + step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform, step_size: int = 1 +) -> tuple[float, float]: + """Returns delta (h) and scale factor to perform the euler method. + If step + step_size > len(schedule), assumes the next timestep and sigma are zero""" + step_next = step + step_size + return scaled_delta(schedule[step][1], schedule[step_next][1] if step_next < len(schedule) else 0, sigma_transform) + + +def euler_step[T: Sample]( + sample: T, prediction: T, step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform, step_size: int = 1 +) -> T: + "Perform the euler method using scaled_delta_step" + return math.sumprod((prediction, sample), scaled_delta_step(step, schedule, sigma_transform, step_size)) # type: ignore def merge_noise[T: Sample](sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index a4a24db..e006365 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -8,13 +8,18 @@ import numpy as np from skrample import common, scheduling -from skrample.common import DictOrProxy, Sample, SigmaTransform +from skrample.common import RNG, DictOrProxy, FloatSchedule, Sample, SigmaTransform from . 
import tableaux +type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any] +"Return is ignored" +type SampleableModel[T: Sample] = Callable[[T, float, float], T] +"sample, timestep, sigma" + def fractional_step( - schedule: list[tuple[float, float]], + schedule: FloatSchedule, current: int, idx: tuple[float, ...], ) -> tuple[tuple[float, float], ...]: @@ -35,9 +40,9 @@ def fractional_step( def step_tableau[T: Sample]( tableau: tableaux.Tableau | tableaux.ExtendedTableau, sample: T, - model: "FunctionalSampler.SampleableModel[T]", + model: SampleableModel[T], step: int, - schedule: list[tuple[float, float]], + schedule: FloatSchedule, transform: SigmaTransform, step_size: int = 1, epsilon: float = 1e-8, @@ -62,12 +67,13 @@ def step_tableau[T: Sample]( k_terms.append(model(combined, *frac_sc) if not any(abs(v) < epsilon for v in frac_sc) else combined) return tuple( - common.euler( + common.euler_step( sample, math.sumprod(k_terms, w), # type: ignore - schedule[step][1], - schedule[step + step_size][1] if step + step_size < len(schedule) else 0, + step, + schedule, transform, + step_size, ) for w in weights ) @@ -75,13 +81,6 @@ def step_tableau[T: Sample]( @dataclasses.dataclass(frozen=True) class FunctionalSampler(ABC): - type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any] - "Return is ignored" - type SampleableModel[T: Sample] = Callable[[T, float, float], T] - "sample, timestep, sigma" - type RNG[T: Sample] = Callable[[], T] - "Distribution should match model, typically normal" - schedule: scheduling.SkrampleSchedule def merge_noise[T: Sample](self, sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: @@ -151,22 +150,22 @@ class FunctionalSinglestep(FunctionalSampler): def step[T: Sample]( self, sample: T, - model: FunctionalSampler.SampleableModel[T], + model: SampleableModel[T], step: int, - schedule: list[tuple[float, float]], - rng: FunctionalSampler.RNG[T] | None = None, + schedule: FloatSchedule, + rng: RNG[T] | None = None, ) -> T: ... 
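    # sample_model() below simply walks the schedule and calls step() once per index;
    # singlestep samplers therefore need no history between steps, unlike the structured
    # multistep samplers, which thread `previous` results through every call.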
def sample_model[T: Sample]( self, sample: T, - model: FunctionalSampler.SampleableModel[T], + model: SampleableModel[T], steps: int, include: slice = slice(None), - rng: FunctionalSampler.RNG[T] | None = None, - callback: FunctionalSampler.SampleCallback | None = None, + rng: RNG[T] | None = None, + callback: SampleCallback | None = None, ) -> T: - schedule: list[tuple[float, float]] = self.schedule.schedule(steps).tolist() + schedule: FloatSchedule = self.schedule.schedule(steps).tolist() for n in list(range(steps))[include]: sample = self.step(sample, model, n, schedule, rng) @@ -239,10 +238,10 @@ def adjust_steps(self, steps: int) -> int: def step[T: Sample]( self, sample: T, - model: FunctionalSampler.SampleableModel[T], + model: SampleableModel[T], step: int, - schedule: list[tuple[float, float]], - rng: FunctionalSampler.RNG[T] | None = None, + schedule: FloatSchedule, + rng: RNG[T] | None = None, ) -> T: return step_tableau(self.tableau(), sample, model, step, schedule, self.schedule.sigma_transform)[0] @@ -267,19 +266,12 @@ def adjust_steps(self, steps: int) -> int: def step[T: Sample]( self, sample: T, - model: FunctionalSampler.SampleableModel[T], + model: SampleableModel[T], step: int, - schedule: list[tuple[float, float]], - rng: FunctionalSampler.RNG[T] | None = None, + schedule: FloatSchedule, + rng: RNG[T] | None = None, ) -> T: - sigma = schedule[step][1] - sigma_next = schedule[step + 1][1] if step + 1 < len(schedule) else 0 - - sigma_u, sigma_v = self.schedule.sigma_transform(sigma) - sigma_u_next, sigma_v_next = self.schedule.sigma_transform(sigma_next) - - scale = sigma_u_next / sigma_u - dt = sigma_v_next - sigma_v * scale + dt, scale = common.scaled_delta_step(step, schedule, self.schedule.sigma_transform) k1 = model(sample, *schedule[step]) result: T = sample * scale + k1 * dt # type: ignore @@ -345,11 +337,11 @@ def tableau(self, order: int | None = None) -> tableaux.ExtendedTableau: def sample_model[T: Sample]( self, sample: T, - model: FunctionalSampler.SampleableModel[T], + model: SampleableModel[T], steps: int, include: slice = slice(None), - rng: FunctionalSampler.RNG[T] | None = None, - callback: FunctionalSampler.SampleCallback | None = None, + rng: RNG[T] | None = None, + callback: SampleCallback | None = None, ) -> T: tab = self.tableau() @@ -363,7 +355,7 @@ def sample_model[T: Sample]( step_size: int = max(round(steps * initial), 1) epsilon: float = 1e-16 # lgtm - schedule: list[tuple[float, float]] = self.schedule.schedule(steps).tolist() + schedule: FloatSchedule = self.schedule.schedule(steps).tolist() indices: list[int] = list(range(steps))[include] step: int = indices[0] @@ -376,17 +368,8 @@ def sample_model[T: Sample]( tab, sample, model, step, schedule, self.schedule.sigma_transform, step_size ) - sigma = schedule[step][1] - sigma_next = schedule[step_next][1] if step_next < len(schedule) else 0 - sigma_next2 = schedule[step_next + step_size][1] if step_next + step_size < len(schedule) else 0 - - sigma_u, sigma_v = self.schedule.sigma_transform(sigma) - sigma_u_next, sigma_v_next = self.schedule.sigma_transform(sigma_next) - sigma_u_next2, sigma_v_next2 = self.schedule.sigma_transform(sigma_next2) - - dt = sigma_v_next - sigma_v * (sigma_u_next / sigma_u) - dt2 = sigma_v_next2 - sigma_v_next * (sigma_u_next2 / sigma_u_next) - dt1x2 = dt2 / dt + delta = common.scaled_delta_step(step, schedule, self.schedule.sigma_transform, step_size)[0] + delta_next = common.scaled_delta_step(step_next, schedule, self.schedule.sigma_transform, step_size)[0] # 
Normalize against pure error error = self.evaluator(sample_low, sample_high) / max(self.evaluator(0, sample_high), epsilon) @@ -394,7 +377,7 @@ def sample_model[T: Sample]( # Basically if we want a 50% larger step but the next dt will already be 25% larger, # we should only set a 20% larger step ie 1.5 / 1.25 # Really this could be iterated to contrast dt2/dt and thresh/error until they're 100% matched but eh - adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / dt1x2 + adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / (delta_next / delta) step_size = max(round(min(step_size * adjustment, steps * maximum)), 1) else: # Save the extra euler call since the 2nd weight isn't used diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index 8291bc4..7f632a9 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -1,6 +1,6 @@ import dataclasses -from skrample.common import Sample, SigmaTransform +from skrample.common import RNG, Sample, SigmaTransform from skrample.sampling import functional, structured @@ -14,11 +14,11 @@ def merge_noise[T: Sample](self, sample: T, noise: T, sigma: float, sigma_transf def sample_model[T: Sample]( self, sample: T, - model: functional.FunctionalSampler.SampleableModel[T], + model: functional.SampleableModel[T], steps: int, include: slice = slice(None), - rng: functional.FunctionalSampler.RNG[T] | None = None, - callback: functional.FunctionalSampler.SampleCallback | None = None, + rng: RNG[T] | None = None, + callback: functional.SampleCallback | None = None, ) -> T: previous: list[structured.SKSamples[T]] = [] schedule_np = self.schedule.schedule(steps) From 2d859925d7c01f70f42a6013bdbf667dac843e1c Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 16 Oct 2025 20:16:15 -0700 Subject: [PATCH 33/59] Unify multiple types across structured, functional, scheduling --- examples/structured.py | 5 +- scripts/plot_skrample.py | 2 +- scripts/spc.py | 12 ++--- skrample/common.py | 6 +++ skrample/diffusers.py | 20 ++++++-- skrample/sampling/functional.py | 6 +-- skrample/sampling/interface.py | 9 ++-- skrample/sampling/structured.py | 84 +++++++++++++-------------------- skrample/scheduling.py | 46 +++++++++++------- tests/diffusers_samplers.py | 2 +- tests/miscellaneous.py | 20 ++++---- 11 files changed, 111 insertions(+), 101 deletions(-) diff --git a/examples/structured.py b/examples/structured.py index 7b84538..09f2d9f 100755 --- a/examples/structured.py +++ b/examples/structured.py @@ -44,8 +44,9 @@ sample: torch.Tensor = torch.randn([1, 4, 80, 80], generator=seed).to(dtype=dtype, device=device) previous: list[structured.SKSamples[torch.Tensor]] = [] + float_schedule = schedule.schedule(steps) - for n, (timestep, sigma) in enumerate(tqdm(schedule.schedule(steps))): + for n, (timestep, sigma) in enumerate(tqdm(float_schedule)): conditioned, unconditioned = model( sample.expand([sample.shape[0] * 2, *sample.shape[1:]]), timestep, @@ -59,7 +60,7 @@ sample=sample, prediction=prediction, step=n, - sigma_schedule=schedule.sigmas(steps), + schedule=float_schedule, sigma_transform=schedule.sigma_transform, noise=torch.randn(sample.shape, generator=seed).to(dtype=sample.dtype, device=sample.device), previous=tuple(previous), diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 61884d4..d6364c2 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -227,7 +227,7 @@ def callback(x: float, n: int, t: float, s: float) -> None: label = " 
".join([s.capitalize() for s in label.split("_")]) - data = np.concatenate([composed.schedule(args.steps), [[0, 0]]], dtype=np.float64) + data = np.concatenate([composed.schedule_np(args.steps), [[0, 0]]], dtype=np.float64) timesteps = data[:, 0] / composed.base_timesteps sigmas = data[:, 1] / data[:, 1].max() diff --git a/scripts/spc.py b/scripts/spc.py index 948d390..991a9fe 100755 --- a/scripts/spc.py +++ b/scripts/spc.py @@ -11,7 +11,7 @@ import skrample.sampling.structured as sampling import skrample.scheduling as scheduling -from skrample.common import SigmaTransform, sigma_complement, sigma_polar +from skrample.common import FloatSchedule, SigmaTransform, sigma_complement, sigma_polar parser = ArgumentParser() parser.add_argument("out", type=FileType("w")) @@ -33,17 +33,17 @@ class Row: def sample_model( - sampler: sampling.StructuredSampler, schedule: NDArray[np.float64], curve: int, transform: SigmaTransform + sampler: sampling.StructuredSampler, schedule: FloatSchedule, curve: int, transform: SigmaTransform ) -> NDArray: previous: list[sampling.SKSamples] = [] sample = 1.0 sampled_values = [sample] - for step, sigma in enumerate(schedule): + for step, (timestep, sigma) in enumerate(schedule): result = sampler.sample( sample=sample, prediction=math.sin(sigma * curve), step=step, - sigma_schedule=schedule, + schedule=schedule, sigma_transform=transform, previous=tuple(previous), noise=random(), @@ -65,13 +65,13 @@ def sample_model( table: list[Row] = [] for t in [sigma_polar, sigma_complement]: for k in args.curves: - reference = sample_model(sampling.Euler(), schedule.sigmas(schedule.base_timesteps), k, t) + reference = sample_model(sampling.Euler(), schedule.schedule(schedule.base_timesteps), k, t) for h in args.steps: reference_aliased = np.interp(np.linspace(0, 1, h + 1), np.linspace(0, 1, len(reference)), reference) for pe in samplers: for ce in samplers: spc = sampling.SPC(predictor=pe, corrector=ce) - sampled = sample_model(spc, schedule.sigmas(h), k, t) + sampled = sample_model(spc, schedule.schedule(h), k, t) table.append( Row( type(pe).__name__ + (str(pe.order) if isinstance(pe, sampling.StructuredMultistep) else ""), diff --git a/skrample/common.py b/skrample/common.py index fa24811..adfcbc1 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -100,6 +100,12 @@ def predict_flow[T: Sample](sample: T, output: T, sigma: float, sigma_transform: return sample - sigma * output # type: ignore +def get_sigma_uv(step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform) -> tuple[float, float]: + """Gets sigma u/v with bounds check. + If step >= len(schedule), the sigma is assumed to be zero.""" + return sigma_transform(schedule[step][1] if step < len(schedule) else 0) + + def scaled_delta(sigma: float, sigma_next: float, sigma_transform: SigmaTransform) -> tuple[float, float]: "Returns delta (h) and scale factor to perform the euler method." 
sigma_u, sigma_v = sigma_transform(sigma) diff --git a/skrample/diffusers.py b/skrample/diffusers.py index fad58b1..30235bf 100644 --- a/skrample/diffusers.py +++ b/skrample/diffusers.py @@ -11,7 +11,15 @@ import skrample.sampling.structured as sampling from skrample import scheduling -from skrample.common import MergeStrategy, Predictor, predict_epsilon, predict_flow, predict_sample, predict_velocity +from skrample.common import ( + FloatSchedule, + MergeStrategy, + Predictor, + predict_epsilon, + predict_flow, + predict_sample, + predict_velocity, +) from skrample.pytorch.noise import ( BatchTensorNoise, Random, @@ -272,9 +280,13 @@ def from_diffusers_config[N: TensorNoiseProps | None]( # pyright fails if you u ) @property - def schedule_np(self) -> NDArray[np.float64]: + def schedule_float(self) -> FloatSchedule: return scheduling.schedule_lru(self.schedule, self._steps) + @property + def schedule_np(self) -> NDArray[np.float64]: + return scheduling.np_schedule_lru(self.schedule, self._steps) + @property def schedule_pt(self) -> Tensor: return torch.from_numpy(self.schedule_np).to(self._device) @@ -291,7 +303,7 @@ def sigmas(self) -> Tensor: @property def init_noise_sigma(self) -> float: - return self.sampler.scale_input(1, self.schedule_np[0, 1].item(), sigma_transform=self.schedule.sigma_transform) + return self.sampler.scale_input(1, self.schedule_float[0][1], sigma_transform=self.schedule.sigma_transform) @property def order(self) -> int: @@ -413,7 +425,7 @@ def step( sampled = self.sampler.sample( sample=sample_cast, prediction=prediction, - sigma_schedule=schedule[:, 1], + schedule=self.schedule_float, step=step, noise=noise, previous=tuple(self._previous), diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index e006365..6bfb0ed 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -115,7 +115,7 @@ def generate_model[T: Sample]( if initial is None and include.start is None: # Short circuit for common case sample: T = rng() else: - sigmas = scheduling.schedule_lru(self.schedule, steps)[:, 1] + sigmas = self.schedule.sigmas(steps) sample: T = self.merge_noise( 0 if initial is None else initial, # type: ignore rng(), @@ -165,7 +165,7 @@ def sample_model[T: Sample]( rng: RNG[T] | None = None, callback: SampleCallback | None = None, ) -> T: - schedule: FloatSchedule = self.schedule.schedule(steps).tolist() + schedule: FloatSchedule = self.schedule.schedule(steps) for n in list(range(steps))[include]: sample = self.step(sample, model, n, schedule, rng) @@ -355,7 +355,7 @@ def sample_model[T: Sample]( step_size: int = max(round(steps * initial), 1) epsilon: float = 1e-16 # lgtm - schedule: FloatSchedule = self.schedule.schedule(steps).tolist() + schedule: FloatSchedule = self.schedule.schedule(steps) indices: list[int] = list(range(steps))[include] step: int = indices[0] diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index 7f632a9..379f11e 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -1,6 +1,6 @@ import dataclasses -from skrample.common import RNG, Sample, SigmaTransform +from skrample.common import RNG, FloatSchedule, Sample, SigmaTransform from skrample.sampling import functional, structured @@ -21,10 +21,7 @@ def sample_model[T: Sample]( callback: functional.SampleCallback | None = None, ) -> T: previous: list[structured.SKSamples[T]] = [] - schedule_np = self.schedule.schedule(steps) - schedule: list[tuple[float, float]] = schedule_np.tolist() - sigmas = 
schedule_np[:, 1] - del schedule_np + schedule: FloatSchedule = self.schedule.schedule(steps) for n in list(range(len(schedule)))[include]: timestep, sigma = schedule[n] @@ -35,7 +32,7 @@ def sample_model[T: Sample]( sample, prediction, n, - sigmas, + schedule, self.schedule.sigma_transform, noise=rng() if rng else None, previous=tuple(previous), diff --git a/skrample/sampling/structured.py b/skrample/sampling/structured.py index 2d1b0d8..3314315 100644 --- a/skrample/sampling/structured.py +++ b/skrample/sampling/structured.py @@ -3,9 +3,9 @@ from dataclasses import dataclass, replace import numpy as np -from numpy.typing import NDArray -from skrample.common import Sample, SigmaTransform, bashforth, euler, merge_noise, safe_log, softmax, spowf +from skrample import common +from skrample.common import FloatSchedule, Sample, SigmaTransform, merge_noise, safe_log, softmax, spowf @dataclass(frozen=True) @@ -44,18 +44,13 @@ def require_previous(self) -> int: "How many prior samples the sampler needs in `previous: list[T]`" return 0 - @staticmethod - def get_sigma(step: int, sigma_schedule: NDArray) -> float: - "Just returns zero if step > len" - return sigma_schedule[step].item() if step < len(sigma_schedule) else 0 - @abstractmethod def sample[T: Sample]( self, sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), @@ -82,7 +77,7 @@ def __call__[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), @@ -91,7 +86,7 @@ def __call__[T: Sample]( sample=sample, prediction=prediction, step=step, - sigma_schedule=sigma_schedule, + schedule=schedule, sigma_transform=sigma_transform, noise=noise, previous=previous, @@ -118,7 +113,7 @@ def max_order() -> int: def require_previous(self) -> int: return max(min(self.order, self.max_order()), self.min_order()) - 1 - def effective_order(self, step: int, schedule: NDArray, previous: tuple[SKSamples, ...]) -> int: + def effective_order(self, step: int, schedule: FloatSchedule, previous: tuple[SKSamples, ...]) -> int: "The order used in calculation given a step, schedule length, and previous sample count" return max( 1, # not min_order because previous may be < min. samplers should check effective >= min @@ -151,15 +146,13 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), ) -> SKSamples[T]: - sigma = self.get_sigma(step, sigma_schedule) - sigma_next = self.get_sigma(step + 1, sigma_schedule) return SKSamples( - final=euler(sample, prediction, sigma, sigma_next, sigma_transform), + final=common.euler_step(sample, prediction, step, schedule, sigma_transform), prediction=prediction, sample=sample, ) @@ -182,16 +175,13 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] 
= (), ) -> SKSamples[T]: - sigma = self.get_sigma(step, sigma_schedule) - sigma_next = self.get_sigma(step + 1, sigma_schedule) - - sigma_u, sigma_v = sigma_transform(sigma) - sigma_u_next, sigma_v_next = sigma_transform(sigma_next) + sigma_u, sigma_v = common.get_sigma_uv(step, schedule, sigma_transform) + sigma_u_next, sigma_v_next = common.get_sigma_uv(step + 1, schedule, sigma_transform) lambda_ = safe_log(sigma_v) - safe_log(sigma_u) lambda_next = safe_log(sigma_v_next) - safe_log(sigma_u_next) @@ -213,11 +203,8 @@ def sample[T: Sample]( # 1st order final -= (sigma_v_next * exp2) * prediction - effective_order = self.effective_order(step, sigma_schedule, previous) - - if effective_order >= 2: - sigma_prev = self.get_sigma(step - 1, sigma_schedule) - sigma_u_prev, sigma_v_prev = sigma_transform(sigma_prev) + if (effective_order := self.effective_order(step, schedule, previous)) >= 2: + sigma_u_prev, sigma_v_prev = common.get_sigma_uv(step - 1, schedule, sigma_transform) lambda_prev = safe_log(sigma_v_prev) - safe_log(sigma_u_prev) h_prev = lambda_ - lambda_prev @@ -228,8 +215,7 @@ def sample[T: Sample]( D1_0 = (1.0 / r) * (prediction - prediction_prev) if effective_order >= 3: - sigma_prev2 = self.get_sigma(step - 2, sigma_schedule) - sigma_u_prev2, sigma_v_prev2 = sigma_transform(sigma_prev2) + sigma_u_prev2, sigma_v_prev2 = common.get_sigma_uv(step - 2, schedule, sigma_transform) lambda_prev2 = safe_log(sigma_v_prev2) - safe_log(sigma_u_prev2) h_prev2 = lambda_prev - lambda_prev2 r_prev2 = h_prev2 / h @@ -268,21 +254,21 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), ) -> SKSamples[T]: - effective_order = self.effective_order(step, sigma_schedule, previous) + effective_order = self.effective_order(step, schedule, previous) predictions = [prediction, *reversed([p.prediction for p in previous[-effective_order + 1 :]])] weighted_prediction: T = math.sumprod( predictions[:effective_order], # type: ignore - bashforth(effective_order), + common.bashforth(effective_order), ) return replace( - super().sample(sample, weighted_prediction, step, sigma_schedule, sigma_transform, noise, previous), + super().sample(sample, weighted_prediction, step, schedule, sigma_transform, noise, previous), prediction=prediction, ) @@ -305,23 +291,18 @@ def unisolve[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] 
= (), prediction_next: Sample | None = None, ) -> T: "Passing `prediction_next` is equivalent to UniC, otherwise behaves as UniP" - sigma = self.get_sigma(step, sigma_schedule) - sigma = self.get_sigma(step, sigma_schedule) - sigma_u, sigma_v = sigma_transform(sigma) - lambda_ = safe_log(sigma_v) - safe_log(sigma_u) - - effective_order = self.effective_order(step, sigma_schedule, previous) + sigma_u, sigma_v = common.get_sigma_uv(step, schedule, sigma_transform) + sigma_u_next, sigma_v_next = common.get_sigma_uv(step + 1, schedule, sigma_transform) - sigma_next = self.get_sigma(step + 1, sigma_schedule) - sigma_u_next, sigma_v_next = sigma_transform(sigma_next) + lambda_ = safe_log(sigma_v) - safe_log(sigma_u) lambda_next = safe_log(sigma_v_next) - safe_log(sigma_u_next) h = abs(lambda_next - lambda_) @@ -336,10 +317,11 @@ def unisolve[T: Sample]( rks: list[float] = [] D1s: list[Sample] = [] + effective_order = self.effective_order(step, schedule, previous) for n in range(1, effective_order): step_prev_N = step - n prediction_prev_N = previous[-n].prediction - sigma_u_prev_N, sigma_v_prev_N = sigma_transform(self.get_sigma(step_prev_N, sigma_schedule)) + sigma_u_prev_N, sigma_v_prev_N = common.get_sigma_uv(step_prev_N, schedule, sigma_transform) lambda_pO = safe_log(sigma_v_prev_N) - safe_log(sigma_u_prev_N) rk = (lambda_pO - lambda_) / h if math.isfinite(rk): # for subnormal @@ -392,13 +374,13 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), ) -> SKSamples[T]: return SKSamples( # type: ignore - final=self.unisolve(sample, prediction, step, sigma_schedule, sigma_transform, noise, previous), + final=self.unisolve(sample, prediction, step, schedule, sigma_transform, noise, previous), prediction=prediction, sample=sample, ) @@ -433,7 +415,7 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] = (), @@ -443,7 +425,7 @@ def sample[T: Sample]( previous[-1].sample, previous[-1].prediction, step - 1, - sigma_schedule, + schedule, sigma_transform, noise, previous[:-1], @@ -454,7 +436,7 @@ def sample[T: Sample]( sample, prediction, step, - sigma_schedule, + schedule, sigma_transform, noise, previous, @@ -493,7 +475,7 @@ def sample[T: Sample]( sample: T, prediction: T, step: int, - sigma_schedule: NDArray, + schedule: FloatSchedule, sigma_transform: SigmaTransform, noise: T | None = None, previous: tuple[SKSamples[T], ...] 
= (), @@ -509,14 +491,14 @@ def sample[T: Sample]( prior.sample, prior.prediction, step - 1, - sigma_schedule, + schedule, sigma_transform, prior.noise, offset_previous[:-1], ).final if self.adaptive: - p, c = sigma_transform(self.get_sigma(step, sigma_schedule)) + p, c = common.get_sigma_uv(step, schedule, sigma_transform) else: p, c = 0, 0 @@ -534,6 +516,6 @@ def sample[T: Sample]( sample = sample * p + corrected * c # type: ignore return replace( - self.predictor.sample(sample, prediction, step, sigma_schedule, sigma_transform, noise, previous), + self.predictor.sample(sample, prediction, step, schedule, sigma_transform, noise, previous), noise=noise, # the corrector may or may not need noise so we always store ) diff --git a/skrample/scheduling.py b/skrample/scheduling.py index 3672b6e..4a07b98 100644 --- a/skrample/scheduling.py +++ b/skrample/scheduling.py @@ -6,40 +6,52 @@ import numpy as np from numpy.typing import NDArray -from skrample.common import SigmaTransform, normalize, regularize, sigma_complement, sigma_polar, sigmoid +from skrample.common import FloatSchedule, SigmaTransform, normalize, regularize, sigma_complement, sigma_polar, sigmoid @lru_cache -def schedule_lru(schedule: "SkrampleSchedule", steps: int) -> NDArray[np.float64]: +def np_schedule_lru(schedule: "SkrampleSchedule", steps: int) -> NDArray[np.float64]: """Globally cached function for SkrampleSchedule.schedule(steps). Prefer moving SkrampleScheudle.schedule() outside of any loops if possible.""" - return schedule.schedule(steps) + return schedule.schedule_np(steps) + + +@lru_cache +def schedule_lru(schedule: "SkrampleSchedule", steps: int) -> FloatSchedule: + """Globally cached function for SkrampleSchedule.schedule(steps). + Prefer moving SkrampleScheudle.schedule() outside of any loops if possible.""" + return tuple(map(tuple, np_schedule_lru(schedule, steps).tolist())) @dataclass(frozen=True) class SkrampleSchedule(ABC): "Abstract class defining the bare minimum for a noise schedule" + def schedule(self, steps: int) -> FloatSchedule: + """Return the full noise schedule, timesteps stacked on top of sigmas. + Excludes the trailing zero""" + return tuple(map(tuple, self.schedule_np(steps).tolist())) + @property @abstractmethod def sigma_transform(self) -> SigmaTransform: "SigmaTransform required for a given noise schedule" @abstractmethod - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: """Return the full noise schedule, timesteps stacked on top of sigmas. 
Excludes the trailing zero""" def timesteps(self, steps: int) -> NDArray[np.float64]: "Just the timesteps component as a 1-d array" - return self.schedule(steps)[:, 0] + return self.schedule_np(steps)[:, 0] def sigmas(self, steps: int) -> NDArray[np.float64]: "Just the sigmas component as a 1-d array" - return self.schedule(steps)[:, 1] + return self.schedule_np(steps)[:, 1] def __call__(self, steps: int) -> NDArray[np.float64]: - return self.schedule(steps) + return self.schedule_np(steps) @dataclass(frozen=True) @@ -126,7 +138,7 @@ def alphas_cumprod(self, betas: NDArray[np.float64]) -> NDArray[np.float64]: def scaled_sigmas(self, alphas_cumprod: NDArray[np.float64]) -> NDArray[np.float64]: return ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: sigmas = self.scaled_sigmas(self.alphas_cumprod(self.betas())) timesteps = self.timesteps(steps) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -191,7 +203,7 @@ def sigmas_to_timesteps(self, sigmas: NDArray[np.float64]) -> NDArray[np.float64 def sigmas(self, steps: int) -> NDArray[np.float64]: return np.linspace(self.sigma_start, 0, steps, endpoint=False, dtype=np.float64) - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: sigmas = self.sigmas(steps) timesteps = self.sigmas_to_timesteps(sigmas) @@ -303,8 +315,8 @@ def find_split[T: "ScheduleModifier"]( class NoMod(ScheduleModifier): "Does nothing. For generic programming against ScheduleModifier" - def schedule(self, steps: int) -> NDArray[np.float64]: - return self.base.schedule(steps) + def schedule_np(self, steps: int) -> NDArray[np.float64]: + return self.base.schedule_np(steps) @dataclass(frozen=True) @@ -312,7 +324,7 @@ class FlowShift(ScheduleModifier): shift: float = 3.0 """Amount to shift noise schedule by.""" - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: sigmas = self.base.sigmas(steps) start = sigmas.max().item() @@ -330,7 +342,7 @@ class Karras(ScheduleModifier): rho: float = 7.0 "Ramp power" - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: sigmas = self.base.sigmas(steps) sigma_min = sigmas[-1].item() @@ -351,7 +363,7 @@ class Exponential(ScheduleModifier): rho: float = 1.0 "Ramp power" - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: sigmas = self.base.sigmas(steps) sigma_min = sigmas[-1].item() sigma_max = sigmas[0].item() @@ -371,7 +383,7 @@ class Beta(ScheduleModifier): alpha: float = 0.6 beta: float = 0.6 - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: import scipy sigmas = self.base.sigmas(steps) @@ -397,9 +409,9 @@ class Hyper(ScheduleModifier): tail: bool = True "Include the trailing end to make an S curve" - def schedule(self, steps: int) -> NDArray[np.float64]: + def schedule_np(self, steps: int) -> NDArray[np.float64]: if abs(self.scale) <= 1e-8: - return self.base.schedule(steps) + return self.base.schedule_np(steps) sigmas = self.base.sigmas(steps) start = sigmas[0].item() diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index ed0a62e..1231673 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -65,7 +65,7 @@ def dual_sample( a_output = predictor( 
a_sample, fake_model(a.scale_input(a_sample, sigma.item(), sigma_transform)), sigma.item(), sigma_transform ) - sampled = a.sample(a_sample, a_output, step, schedule[:, 1].numpy(), sigma_transform, noise, prior_steps) + sampled = a.sample(a_sample, a_output, step, schedule.numpy().tolist(), sigma_transform, noise, prior_steps) a_sample = sampled.final prior_steps.append(sampled) diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index f5e51a1..b7e628d 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -61,14 +61,14 @@ def test_sampler_generics() -> None: i, o = random.random(), random.random() prev = [SKSamples(random.random(), random.random(), random.random()) for _ in range(9)] - scalar = sampler.sample(i, o, 4, schedule.sigmas(10), schedule.sigma_transform, previous=prev).final + scalar = sampler.sample(i, o, 4, schedule.schedule(10), schedule.sigma_transform, previous=prev).final # Enforce FP64 as that should be equivalent to python scalar ndarr = sampler.sample( np.array([i], dtype=np.float64), np.array([o], dtype=np.float64), 4, - schedule.sigmas(10), + schedule.schedule(10), schedule.sigma_transform, previous=prev, # type: ignore ).final.item() @@ -77,7 +77,7 @@ def test_sampler_generics() -> None: torch.tensor([i], dtype=torch.float64), torch.tensor([o], dtype=torch.float64), 4, - schedule.sigmas(10), + schedule.schedule(10), schedule.sigma_transform, previous=prev, # type: ignore ).final.item() @@ -117,7 +117,7 @@ def test_require_previous() -> None: sample, prediction, 31, - Linear().sigmas(100), + Linear().schedule(100), sigma_complement, None, previous, @@ -126,7 +126,7 @@ def test_require_previous() -> None: sample, prediction, 31, - Linear().sigmas(100), + Linear().schedule(100), sigma_complement, None, previous[len(previous) - sampler.require_previous :], @@ -158,7 +158,7 @@ def test_require_noise() -> None: sample, prediction, 31, - Linear().sigmas(100), + Linear().schedule(100), sigma_complement, noise, previous, @@ -167,7 +167,7 @@ def test_require_noise() -> None: sample, prediction, 31, - Linear().sigmas(100), + Linear().schedule(100), sigma_complement, noise if sampler.require_noise else None, previous, @@ -195,15 +195,15 @@ def fake_model(x: float, _: float, s: float) -> float: sample_f = adapter.sample_model(sample, fake_model, steps, rng=lambda: next(rng)) rng = iter(noise) - schedule_np = schedule.schedule(steps) + float_schedule = schedule.schedule(steps) sample_s = sample previous: list[SKSamples[float]] = [] - for n, (t, s) in enumerate(schedule_np.tolist()): + for n, (t, s) in enumerate(float_schedule): results = sampler.sample( sample_s, fake_model(sample_s, t, s), n, - schedule_np[:, 1], + float_schedule, schedule.sigma_transform, next(rng), tuple(previous), From 2b6c0d4ba0dca9b76fd5903e067ec3fc72242f24 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 16 Oct 2025 20:46:19 -0700 Subject: [PATCH 34/59] Add common.mean --- skrample/common.py | 8 ++++++++ skrample/sampling/functional.py | 5 +---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/skrample/common.py b/skrample/common.py index adfcbc1..a75eca9 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -181,6 +181,14 @@ def spowf[T: Sample](x: T, f: float) -> T: return abs(x) ** f * (-1 * (x < 0) | 1) # type: ignore +def mean(x: Sample) -> float: + "For an array this returns mean().item(). 
For a float this returns x" + if isinstance(x, float | int): + return x + else: + return x.mean().item() + + @lru_cache def bashforth(order: int) -> tuple[float, ...]: # tuple return so lru isnt mutable "Bashforth coefficients for a given order" diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 6bfb0ed..cbf701f 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -183,10 +183,7 @@ class FunctionalAdaptive(FunctionalSampler): @staticmethod def mse[T: Sample](a: T, b: T) -> float: error: T = abs(a - b) ** 2 # type: ignore - if isinstance(error, float | int): - return error - else: - return error.mean().item() + return common.mean(error) evaluator: Evaluator = mse "Function used to measure error of two samples" From 00a4902ac66a1630691bdbb20c54fb6e3affa310 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 16 Oct 2025 21:05:09 -0700 Subject: [PATCH 35/59] Change signature of FunctionalSampler.merge_noise to use internal schedule --- skrample/sampling/functional.py | 13 +++++++------ skrample/sampling/interface.py | 9 ++++++--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index cbf701f..3aa5587 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -83,8 +83,10 @@ def step_tableau[T: Sample]( class FunctionalSampler(ABC): schedule: scheduling.SkrampleSchedule - def merge_noise[T: Sample](self, sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: - return common.merge_noise(sample, noise, sigma, sigma_transform) + def merge_noise[T: Sample](self, sample: T, noise: T, steps: int, start: int) -> T: + schedule = scheduling.schedule_lru(self.schedule, steps) + sigma = schedule[start][1] if start < len(schedule) else 0 + return common.merge_noise(sample, noise, sigma, self.schedule.sigma_transform) @abstractmethod def sample_model[T: Sample]( @@ -115,13 +117,12 @@ def generate_model[T: Sample]( if initial is None and include.start is None: # Short circuit for common case sample: T = rng() else: - sigmas = self.schedule.sigmas(steps) sample: T = self.merge_noise( 0 if initial is None else initial, # type: ignore rng(), - sigmas[include.start or 0].item(), - self.schedule.sigma_transform, - ) / self.merge_noise(0.0, 1.0, sigmas[0].item(), self.schedule.sigma_transform) + steps, + include.start or 0, + ) / self.merge_noise(0.0, 1.0, steps, 0) # Rescale sample by initial sigma. 
Mostly just to handle quirks with Scaled return self.sample_model(sample, model, steps, include, rng, callback) diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index 379f11e..31a13bc 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -1,6 +1,7 @@ import dataclasses -from skrample.common import RNG, FloatSchedule, Sample, SigmaTransform +from skrample import scheduling +from skrample.common import RNG, FloatSchedule, Sample from skrample.sampling import functional, structured @@ -8,8 +9,10 @@ class StructuredFunctionalAdapter(functional.FunctionalSampler): sampler: structured.StructuredSampler - def merge_noise[T: Sample](self, sample: T, noise: T, sigma: float, sigma_transform: SigmaTransform) -> T: - return self.sampler.merge_noise(sample, noise, sigma, sigma_transform) + def merge_noise[T: Sample](self, sample: T, noise: T, steps: int, start: int) -> T: + schedule = scheduling.schedule_lru(self.schedule, steps) + sigma = schedule[start][1] if start < len(schedule) else 0 + return self.sampler.merge_noise(sample, noise, sigma, self.schedule.sigma_transform) def sample_model[T: Sample]( self, From 895870cc2433c6499c72282a52aebba70e23cca3 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 16 Oct 2025 21:06:42 -0700 Subject: [PATCH 36/59] StructuredFunctionalAdapter only call rng() if sampler requires --- skrample/sampling/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index 31a13bc..8362f68 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -37,7 +37,7 @@ def sample_model[T: Sample]( n, schedule, self.schedule.sigma_transform, - noise=rng() if rng else None, + noise=rng() if rng and self.sampler.require_noise else None, previous=tuple(previous), ) From 6358c1bbc5d8385a7ce49d558224cd89f47dab71 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 16 Oct 2025 21:44:45 -0700 Subject: [PATCH 37/59] Add float versions of SkrampleSchedule methods, FunctionalSampler.model_with_predictor --- examples/diffusers/functional.py | 9 ++----- skrample/diffusers.py | 2 +- skrample/sampling/functional.py | 16 +++++++++--- skrample/sampling/interface.py | 5 ++-- skrample/scheduling.py | 43 +++++++++++++++++++------------- tests/diffusers_schedules.py | 4 +-- tests/miscellaneous.py | 4 +-- 7 files changed, 48 insertions(+), 35 deletions(-) diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py index e8d56d7..350736f 100755 --- a/examples/diffusers/functional.py +++ b/examples/diffusers/functional.py @@ -59,12 +59,7 @@ def call_model(sample: torch.Tensor, timestep: float, sigma: float) -> torch.Ten i=i, t=sample.new_tensor([timestep] * len(sample)), ) - return wrapper.predictor( - sample, - block_state["noise_pred"], # type: ignore - sigma, - schedule.sigma_transform, - ) + return block_state["noise_pred"] # type: ignore def sample_callback(x: torch.Tensor, n: int, t: float, s: float) -> None: nonlocal i @@ -73,7 +68,7 @@ def sample_callback(x: torch.Tensor, n: int, t: float, s: float) -> None: block_state["latents"] = sampler.sample_model( sample=block_state["latents"], - model=call_model, + model=sampler.model_with_predictor(call_model, wrapper.predictor), steps=block_state["num_inference_steps"], callback=sample_callback, ) diff --git a/skrample/diffusers.py b/skrample/diffusers.py index 30235bf..5be6190 100644 --- a/skrample/diffusers.py +++ b/skrample/diffusers.py @@ -153,7 +153,7 @@ def 
parse_diffusers_config( scaled_keys = [f.name for f in dataclasses.fields(scheduling.Scaled)] # non-uniform misses a whole timestep scaled = scheduling.Scaled(**{k: v for k, v in remapped.items() if k in scaled_keys} | {"uniform": True}) - sigma_start: float = scaled.sigmas(1).item() + sigma_start: float = scaled.sigmas(1)[0] remapped["sigma_start"] = math.sqrt(sigma_start) schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] = [] diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 3aa5587..39148ea 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -2,13 +2,14 @@ import math from abc import ABC, abstractmethod from collections.abc import Callable +from functools import wraps from types import MappingProxyType from typing import Any import numpy as np from skrample import common, scheduling -from skrample.common import RNG, DictOrProxy, FloatSchedule, Sample, SigmaTransform +from skrample.common import RNG, DictOrProxy, FloatSchedule, Predictor, Sample, SigmaTransform from . import tableaux @@ -83,9 +84,18 @@ def step_tableau[T: Sample]( class FunctionalSampler(ABC): schedule: scheduling.SkrampleSchedule + def model_with_predictor(self, model: SampleableModel, predictor: Predictor) -> SampleableModel: + "Wraps the output of `model` with `predictor` using schedule.sigma_transform" + + @wraps(model) + def model_with_predictor[T: Sample](x: T, t: float, s: float) -> T: + return predictor(x, model(x, t, s), s, self.schedule.sigma_transform) + + return model_with_predictor + def merge_noise[T: Sample](self, sample: T, noise: T, steps: int, start: int) -> T: - schedule = scheduling.schedule_lru(self.schedule, steps) - sigma = schedule[start][1] if start < len(schedule) else 0 + sigmas = self.schedule.sigmas(steps) + sigma = sigmas[start] if start < len(sigmas) else 0 return common.merge_noise(sample, noise, sigma, self.schedule.sigma_transform) @abstractmethod diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index 8362f68..cfb39d3 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -1,6 +1,5 @@ import dataclasses -from skrample import scheduling from skrample.common import RNG, FloatSchedule, Sample from skrample.sampling import functional, structured @@ -10,8 +9,8 @@ class StructuredFunctionalAdapter(functional.FunctionalSampler): sampler: structured.StructuredSampler def merge_noise[T: Sample](self, sample: T, noise: T, steps: int, start: int) -> T: - schedule = scheduling.schedule_lru(self.schedule, steps) - sigma = schedule[start][1] if start < len(schedule) else 0 + sigmas = self.schedule.sigmas(steps) + sigma = sigmas[start] if start < len(sigmas) else 0 return self.sampler.merge_noise(sample, noise, sigma, self.schedule.sigma_transform) def sample_model[T: Sample]( diff --git a/skrample/scheduling.py b/skrample/scheduling.py index 4a07b98..515f5aa 100644 --- a/skrample/scheduling.py +++ b/skrample/scheduling.py @@ -1,5 +1,6 @@ import math from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass, replace from functools import lru_cache @@ -27,11 +28,6 @@ def schedule_lru(schedule: "SkrampleSchedule", steps: int) -> FloatSchedule: class SkrampleSchedule(ABC): "Abstract class defining the bare minimum for a noise schedule" - def schedule(self, steps: int) -> FloatSchedule: - """Return the full noise schedule, timesteps stacked on top of sigmas. 
- Excludes the trailing zero""" - return tuple(map(tuple, self.schedule_np(steps).tolist())) - @property @abstractmethod def sigma_transform(self) -> SigmaTransform: @@ -42,14 +38,27 @@ def schedule_np(self, steps: int) -> NDArray[np.float64]: """Return the full noise schedule, timesteps stacked on top of sigmas. Excludes the trailing zero""" - def timesteps(self, steps: int) -> NDArray[np.float64]: + def timesteps_np(self, steps: int) -> NDArray[np.float64]: "Just the timesteps component as a 1-d array" return self.schedule_np(steps)[:, 0] - def sigmas(self, steps: int) -> NDArray[np.float64]: + def sigmas_np(self, steps: int) -> NDArray[np.float64]: "Just the sigmas component as a 1-d array" return self.schedule_np(steps)[:, 1] + def schedule(self, steps: int) -> FloatSchedule: + """Return the full noise schedule, [(timestep, sigma), ...) + Excludes the trailing zero""" + return tuple(map(tuple, self.schedule_np(steps).tolist())) + + def timesteps(self, steps: int) -> Sequence[float]: + "Just the timesteps component" + return self.timesteps_np(steps).tolist() + + def sigmas(self, steps: int) -> Sequence[float]: + "Just the sigmas component" + return self.sigmas_np(steps).tolist() + def __call__(self, steps: int) -> NDArray[np.float64]: return self.schedule_np(steps) @@ -113,7 +122,7 @@ def sigmas_to_timesteps(self, sigmas: NDArray[np.float64]) -> NDArray[np.float64 t = (1 - w) * low_idx + w * high_idx return t - def timesteps(self, steps: int) -> NDArray[np.float64]: + def timesteps_np(self, steps: int) -> NDArray[np.float64]: # # https://arxiv.org/abs/2305.08891 Table 2 if self.uniform: return np.linspace(self.base_timesteps - 1, 0, steps, endpoint=False, dtype=np.float64).round() @@ -140,7 +149,7 @@ def scaled_sigmas(self, alphas_cumprod: NDArray[np.float64]) -> NDArray[np.float def schedule_np(self, steps: int) -> NDArray[np.float64]: sigmas = self.scaled_sigmas(self.alphas_cumprod(self.betas())) - timesteps = self.timesteps(steps) + timesteps = self.timesteps_np(steps) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) return np.stack([timesteps, sigmas], axis=1) @@ -200,11 +209,11 @@ def sigma_transform(self) -> SigmaTransform: def sigmas_to_timesteps(self, sigmas: NDArray[np.float64]) -> NDArray[np.float64]: return normalize(sigmas, self.sigma_start) * self.base_timesteps - def sigmas(self, steps: int) -> NDArray[np.float64]: + def sigmas_np(self, steps: int) -> NDArray[np.float64]: return np.linspace(self.sigma_start, 0, steps, endpoint=False, dtype=np.float64) def schedule_np(self, steps: int) -> NDArray[np.float64]: - sigmas = self.sigmas(steps) + sigmas = self.sigmas_np(steps) timesteps = self.sigmas_to_timesteps(sigmas) return np.stack([timesteps, sigmas], axis=1) @@ -219,7 +228,7 @@ class SigmoidCDF(Linear): cdf_scale: float = 3 "Multiply the inverse CDF output before the sigmoid function is applied" - def sigmas(self, steps: int) -> NDArray[np.float64]: + def sigmas_np(self, steps: int) -> NDArray[np.float64]: from scipy.stats import norm step_peak = 1 / (steps * math.pi / 2) @@ -325,7 +334,7 @@ class FlowShift(ScheduleModifier): """Amount to shift noise schedule by.""" def schedule_np(self, steps: int) -> NDArray[np.float64]: - sigmas = self.base.sigmas(steps) + sigmas = self.base.sigmas_np(steps) start = sigmas.max().item() sigmas = self.shift / (self.shift + (start / sigmas - 1)) * start @@ -343,7 +352,7 @@ class Karras(ScheduleModifier): "Ramp power" def schedule_np(self, steps: int) -> NDArray[np.float64]: - sigmas = self.base.sigmas(steps) + sigmas = 
self.base.sigmas_np(steps) sigma_min = sigmas[-1].item() sigma_max = sigmas[0].item() @@ -364,7 +373,7 @@ class Exponential(ScheduleModifier): "Ramp power" def schedule_np(self, steps: int) -> NDArray[np.float64]: - sigmas = self.base.sigmas(steps) + sigmas = self.base.sigmas_np(steps) sigma_min = sigmas[-1].item() sigma_max = sigmas[0].item() @@ -386,7 +395,7 @@ class Beta(ScheduleModifier): def schedule_np(self, steps: int) -> NDArray[np.float64]: import scipy - sigmas = self.base.sigmas(steps) + sigmas = self.base.sigmas_np(steps) sigma_min = sigmas[-1].item() sigma_max = sigmas[0].item() @@ -413,7 +422,7 @@ def schedule_np(self, steps: int) -> NDArray[np.float64]: if abs(self.scale) <= 1e-8: return self.base.schedule_np(steps) - sigmas = self.base.sigmas(steps) + sigmas = self.base.sigmas_np(steps) start = sigmas[0].item() sigmas = normalize(sigmas, start) # Base -> 1..0 diff --git a/tests/diffusers_schedules.py b/tests/diffusers_schedules.py index a3220c2..41c0a23 100644 --- a/tests/diffusers_schedules.py +++ b/tests/diffusers_schedules.py @@ -24,13 +24,13 @@ def compare_schedules( b.set_timesteps(num_inference_steps=steps) compare_tensors( - torch.from_numpy(a.timesteps(steps)), + torch.from_numpy(a.timesteps_np(steps)), b.timesteps, f"TIMESTEPS @ {steps}", margin=ts_margin, ) compare_tensors( - torch.from_numpy(a.sigmas(steps)), + torch.from_numpy(a.sigmas_np(steps)), b.sigmas[:-1], f"SIGMAS @ {steps}", margin=sig_margin, diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index b7e628d..6b02360 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -46,8 +46,8 @@ def test_sigmas_to_timesteps() -> None: for schedule in [*(cls() for cls in ALL_SCHEDULES), Scaled(beta_scale=1)]: # base schedules - timesteps = schedule.timesteps(123) - timesteps_inv = schedule.sigmas_to_timesteps(schedule.sigmas(123)) + timesteps = schedule.timesteps_np(123) + timesteps_inv = schedule.sigmas_to_timesteps(schedule.sigmas_np(123)) compare_tensors(torch.tensor(timesteps), torch.tensor(timesteps_inv), margin=0) # shocked this rounds good From d2514fdaaa7382d2debdd497a257dc25e7c0dc44 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Thu, 16 Oct 2025 23:31:28 -0700 Subject: [PATCH 38/59] Add RKMoire.discard --- skrample/sampling/functional.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 39148ea..cb4eb50 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -316,6 +316,8 @@ class RKMoire(FunctionalAdaptive, FunctionalHigher): "Percent of schedule to take as a maximum step." adaption: float = 0.3 "How fast to adjust step size in relation to error" + discard: float = float("inf") + "If the final adjustment down is more than this, the entire previous step is discarded." rescale_init: bool = True "Scale initial by a tableau's model evals." 
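The discard threshold added above plugs into RKMoire's existing adaptive controller: each accepted step compares a higher- and lower-order result, converts the measured error into an adjustment factor through threshold and adaption, and the next hunk only throws a step away when retrying would actually decrease the step size and the downward adjustment exceeds discard. The standalone sketch below illustrates the shape of that control flow; the names and defaults (adaptive_steps, error_estimate, discard=2.0) are illustrative assumptions rather than skrample's API, and the real sampler derives its error by comparing the higher- and lower-order tableau solutions through its evaluator instead of a caller-supplied function.

from collections.abc import Callable


def adaptive_steps(
    total: int,
    error_estimate: Callable[[int, int], float],  # hypothetical stand-in for the high/low-order comparison
    threshold: float = 1e-2,
    adaption: float = 0.3,
    discard: float = 2.0,
    epsilon: float = 1e-8,
) -> list[tuple[int, int]]:
    "Return the (step, step_size) pairs that were actually accepted."
    taken: list[tuple[int, int]] = []
    step, step_size = 0, 1
    while step < total:
        step_size = min(step_size, total - step)
        error = error_estimate(step, step_size)
        adjustment = (threshold / max(error, epsilon)) ** adaption
        next_size = max(round(step_size * adjustment), 1)
        # Discard only when retrying would genuinely decrease the step size
        # and the downward adjustment is larger than the discard ratio.
        if next_size < step_size and 1.0 / max(adjustment, epsilon) > discard:
            step_size = next_size
            continue
        taken.append((step, step_size))
        step += step_size
        step_size = next_size
    return taken


# Constant per-unit error: steps grow until the estimate meets the threshold, then plateau.
# The discard branch only fires when the error estimate jumps upward between steps.
trace = adaptive_steps(50, lambda step, size: 2e-3 * size)
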
@@ -388,6 +390,10 @@ def sample_model[T: Sample]( adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / (delta_next / delta) step_size = max(round(min(step_size * adjustment, steps * maximum)), 1) + # Only discard if it will actually decrease step size + if step_next - step > step_size and 1 / max(adjustment, epsilon) > self.discard: + continue + else: # Save the extra euler call since the 2nd weight isn't used sample_high = step_tableau( tab[:2], sample, model, step, schedule, self.schedule.sigma_transform, step_size From 4752995fb843dd7b7bfdddacfefd1867b14ca327 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 1 Nov 2025 20:19:35 -0700 Subject: [PATCH 39/59] Further improve plot_skrample.py fake model --- scripts/plot_skrample.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index d6364c2..e51c5a9 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -58,9 +58,9 @@ def colors(hue_steps: int) -> Generator[list[float]]: yield oklch_to_srgb(np.array([lighness_actual, chroma_actual, hue], dtype=np.float64)) -TRANSFORMS: dict[str, SigmaTransform] = { - "polar": sigma_polar, - "complement": sigma_complement, +TRANSFORMS: dict[str, tuple[float, SigmaTransform]] = { + "polar": (14.6, sigma_polar), + "complement": (1.0, sigma_complement), } SAMPLERS: dict[str, structured.StructuredSampler | functional.FunctionalSampler] = { "euler": structured.Euler(), @@ -159,7 +159,15 @@ def colors(hue_steps: int) -> Generator[list[float]]: plt.ylabel("Sample") plt.title("Skrample Samplers") - schedule = scheduling.Linear(base_timesteps=10_000, custom_transform=TRANSFORMS[args.transform]) + schedule = scheduling.Hyper( + scheduling.Linear( + sigma_start=TRANSFORMS[args.transform][0], + base_timesteps=10_000, + custom_transform=TRANSFORMS[args.transform][1], + ), + -2, + False, + ) def sample_model( sampler: structured.StructuredSampler | functional.FunctionalSampler, steps: int @@ -171,12 +179,12 @@ def sample_model( sample = 1.0 sampled_values = [sample] - sigmas = [0.0] + timesteps = [0.0] def callback(x: float, n: int, t: float, s: float) -> None: - nonlocal sampled_values, sigmas + nonlocal sampled_values, timesteps sampled_values.append(x) - sigmas.insert(-1, s) + timesteps.insert(-1, t / schedule.base_timesteps) if isinstance(sampler, functional.RKMoire) and args.adjust: adjusted = schedule.base_timesteps @@ -187,13 +195,13 @@ def callback(x: float, n: int, t: float, s: float) -> None: sampler.sample_model( sample=sample, - model=lambda sample, timestep, sigma: sample + math.sin(sigma * args.curve), + model=lambda x, t, s: x + math.sin(t / schedule.base_timesteps * args.curve) * (s + 1), steps=adjusted, rng=random, callback=callback, ) - return sigmas, sampled_values + return timesteps, sampled_values plt.plot(*sample_model(structured.Euler(), schedule.base_timesteps), label="Reference", color=next(COLORS)) From 9f38650316a52641e04ae45cd4c3cbe2f87431ca Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Wed, 5 Nov 2025 02:14:01 -0800 Subject: [PATCH 40/59] Add Heun tests, rough impl of derivative-based tableau solver Relies on branching for transforms currently, but results are substantially better than the raw prediction based solver --- skrample/sampling/functional.py | 44 ++++++++++++++ tests/diffusers_samplers.py | 100 +++++++++++++++++++++++++++++++- 2 files changed, 142 insertions(+), 2 deletions(-) diff --git a/skrample/sampling/functional.py 
b/skrample/sampling/functional.py index cb4eb50..8ba9719 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -38,6 +38,41 @@ def fractional_step( return result +def step_tableau_derive[T: Sample]( + tableau: tableaux.Tableau | tableaux.ExtendedTableau, + sample: T, + model: SampleableModel[T], + step: int, + schedule: FloatSchedule, + step_size: int = 1, + epsilon: float = 1e-8, +) -> tuple[T, ...]: + nodes, weights = tableau[0], tableau[1:] + + derivatives: list[T] = [] + S0 = schedule[step][1] + S1 = schedule[step + step_size][1] if step + step_size < len(schedule) else 0 + H = S1 - S0 + + fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in nodes)) + + for frac_sc, icoeffs in zip(fractions, (t[1] for t in nodes), strict=True): + Sn = frac_sc[1] + if icoeffs: + X: T = sample + math.sumprod(derivatives, icoeffs) / math.fsum(icoeffs) * (Sn - S0) # type: ignore + else: + X = sample + + # Do not call model on timestep = 0 or sigma = 0 + if any(abs(v) < epsilon for v in frac_sc): + derivatives.append((sample - X) / S0) # type: ignore + else: + P: T = model(X, *frac_sc) if not any(abs(v) < epsilon for v in frac_sc) else X + derivatives.append((X - P) / Sn) # type: ignore + + return tuple(sample + math.sumprod(derivatives, w) * H for w in weights) # type: ignore + + def step_tableau[T: Sample]( tableau: tableaux.Tableau | tableaux.ExtendedTableau, sample: T, @@ -48,6 +83,15 @@ def step_tableau[T: Sample]( step_size: int = 1, epsilon: float = 1e-8, ) -> tuple[T, ...]: + if transform is common.sigma_complement: + return step_tableau_derive(tableau, sample, model, step, schedule, step_size, epsilon) + elif transform is common.sigma_polar: + if step == 0: # TODO (beinsezii): can't have this + sample = sample * schedule[step][1] # type: ignore + return step_tableau_derive( + tableau, sample, lambda x, t, s: model(x / (s**2 + 1) ** 0.5, t, s), step, schedule, step_size, epsilon + ) + nodes, weights = tableau[0], tableau[1:] k_terms: list[T] = [] fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in nodes)) diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 1231673..d05f0f7 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -1,24 +1,48 @@ +import dataclasses from inspect import signature +import numpy as np import torch from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler +from diffusers.schedulers.scheduling_heun_discrete import HeunDiscreteScheduler from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler from testing_common import FLOW_CONFIG, SCALED_CONFIG, compare_tensors -from skrample.common import Predictor, sigma_complement, sigma_polar +from skrample.common import FloatSchedule, Predictor, SigmaTransform, sigma_complement, sigma_polar from skrample.common import predict_epsilon as EPSILON from skrample.common import predict_flow as FLOW from skrample.common import predict_velocity as VELOCITY +from skrample.sampling.functional import RKUltra from skrample.sampling.structured import DPM, Euler, SKSamples, 
StructuredSampler, UniPC +from skrample.sampling.tableaux import RK2 +from skrample.scheduling import SkrampleSchedule DiffusersScheduler = ( - EulerDiscreteScheduler | DPMSolverMultistepScheduler | FlowMatchEulerDiscreteScheduler | UniPCMultistepScheduler + EulerDiscreteScheduler + | DPMSolverMultistepScheduler + | FlowMatchEulerDiscreteScheduler + | FlowMatchHeunDiscreteScheduler + | UniPCMultistepScheduler ) +@dataclasses.dataclass(frozen=True) +class FixedSchedule(SkrampleSchedule): + fixed_schedule: FloatSchedule + transform: SigmaTransform + + def schedule_np(self, steps: int) -> np.typing.NDArray[np.float64]: + return np.array(self.fixed_schedule, dtype=np.float64) + + @property + def sigma_transform(self) -> SigmaTransform: + return self.transform + + def fake_model(t: torch.Tensor) -> torch.Tensor: t @= torch.randn(t.shape, generator=torch.Generator(t.device).manual_seed(-1), dtype=t.dtype) return t / t.std() # keep values in sane range @@ -168,3 +192,75 @@ def test_unipc() -> None: predictor[0], message=f"{predictor[0].__name__} o{order}", ) + + +def test_heun_scaled() -> None: + margin = 1e-8 + predictor: Predictor = EPSILON + sigma_transform = sigma_polar + for steps in 30, 31, 200, 201: + df: HeunDiscreteScheduler = HeunDiscreteScheduler.from_config(SCALED_CONFIG) # type: ignore + + df.set_timesteps(steps) + + fixed: list[tuple[float, float]] = [] + for t in zip(df.timesteps.tolist(), df.sigmas.tolist()): + if t not in fixed: + fixed.append(t) + + sk = RKUltra(FixedSchedule(fixed, sigma_transform), order=2, providers=RKUltra.providers | {2: RK2.Heun}) + + sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) + seed = torch.manual_seed(0) + + df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) + df_sample = df.add_noise(sk_sample.clone(), df_noise, df.timesteps[0:1]) + for t in df.timesteps: + df_sample: torch.Tensor = df.step( + fake_model(df.scale_model_input(df_sample, timestep=t)), + sample=df_sample, + timestep=t, + )[0] + + sk_sample = sk.generate_model( + sk.model_with_predictor(lambda x, t, s: fake_model(x), predictor), + lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), + steps, + initial=sk_sample, + ) + + compare_tensors(df_sample, sk_sample, message=f"{steps}", margin=margin) + + +def test_heun_flow() -> None: + margin = 1e-8 + predictor: Predictor = FLOW + sigma_transform = sigma_complement + for steps in 30, 31, 200, 201: + df: FlowMatchHeunDiscreteScheduler = FlowMatchHeunDiscreteScheduler.from_config(FLOW_CONFIG) # type: ignore + + df.set_timesteps(steps) + + fixed: list[tuple[float, float]] = [] + for t in zip(df.timesteps.tolist(), df.sigmas.tolist()): + if t not in fixed: + fixed.append(t) + + sk = RKUltra(FixedSchedule(fixed, sigma_transform), order=2, providers=RKUltra.providers | {2: RK2.Heun}) + + sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) + seed = torch.manual_seed(0) + + df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) + df_sample = df.scale_noise(sk_sample.clone(), df.timesteps[0], df_noise) # type: ignore + for t in df.timesteps: + df_sample: torch.Tensor = df.step(fake_model(df_sample), sample=df_sample, timestep=t)[0] # type: ignore + + sk_sample = sk.generate_model( + sk.model_with_predictor(lambda x, t, s: fake_model(x), predictor), + lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), + steps, + initial=sk_sample, + ) + + compare_tensors(df_sample, sk_sample, message=f"{steps}", 
margin=margin) From e4935cb8ec3ada61e20df859f0a49581e1193a58 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 10 Nov 2025 03:14:20 -0800 Subject: [PATCH 41/59] Rework step_tableau_derive to better work with polar models The polar to/from derive functions are *so close* to working with flow match but it's just not quite there. So I basically have two options: 1. fix the derives so they work for any transform 2. accept defeat and replace SigmaTransform with some generic ModelTransforms unit struct that just has static methods for the sigma and derivative transforms --- skrample/sampling/functional.py | 86 ++++++++++++++++++++++++++++----- tests/diffusers_samplers.py | 3 +- 2 files changed, 77 insertions(+), 12 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 8ba9719..177636c 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -38,39 +38,87 @@ def fractional_step( return result +def to_derivative_polar[T: Sample](sample: T, prediction: T, sigma: float, transform: SigmaTransform) -> T: + sigma_u, sigma_v = transform(sigma) + return (sample - (sigma_v * prediction)) / sigma_u # pyright: ignore [reportReturnType] + + +def from_derivative_polar[T: Sample](sample: T, derivative: T, sigma: float, transform: SigmaTransform) -> T: + sigma_u, sigma_v = transform(sigma) + return (sample - derivative * sigma_u) / sigma_v # pyright: ignore [reportReturnType] + + +def to_derivative_complement[T: Sample](sample: T, prediction: T, sigma: float, transform: SigmaTransform) -> T: + return (sample - prediction) / sigma # pyright: ignore [reportReturnType] + + +def from_derivative_complement[T: Sample](sample: T, derivative: T, sigma: float, transform: SigmaTransform) -> T: + return sample - derivative * sigma # pyright: ignore [reportReturnType] + + +type DerivativeTransform[T: Sample] = Callable[[T, T, float, SigmaTransform], T] + + def step_tableau_derive[T: Sample]( tableau: tableaux.Tableau | tableaux.ExtendedTableau, sample: T, model: SampleableModel[T], step: int, schedule: FloatSchedule, + transform: SigmaTransform, + derivative_io: tuple[DerivativeTransform[T], DerivativeTransform[T]], step_size: int = 1, epsilon: float = 1e-8, ) -> tuple[T, ...]: + to_d, from_d = derivative_io + nodes, weights = tableau[0], tableau[1:] derivatives: list[T] = [] S0 = schedule[step][1] S1 = schedule[step + step_size][1] if step + step_size < len(schedule) else 0 - H = S1 - S0 fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in nodes)) for frac_sc, icoeffs in zip(fractions, (t[1] for t in nodes), strict=True): - Sn = frac_sc[1] + sigma_i = frac_sc[1] if icoeffs: - X: T = sample + math.sumprod(derivatives, icoeffs) / math.fsum(icoeffs) * (Sn - S0) # type: ignore + X: T = common.euler( # pyright: ignore [reportAssignmentType] + sample, + from_d( + sample, + math.sumprod(derivatives, icoeffs) / math.fsum(icoeffs), # pyright: ignore [reportArgumentType] + S0, + transform, + ), + S0, + sigma_i, + transform, + ) else: X = sample # Do not call model on timestep = 0 or sigma = 0 if any(abs(v) < epsilon for v in frac_sc): - derivatives.append((sample - X) / S0) # type: ignore + derivatives.append(to_d(sample, X, S0, transform)) else: - P: T = model(X, *frac_sc) if not any(abs(v) < epsilon for v in frac_sc) else X - derivatives.append((X - P) / Sn) # type: ignore + derivatives.append(to_d(X, model(X, *frac_sc), sigma_i, transform)) - return tuple(sample + math.sumprod(derivatives, w) * H for w in weights) # type: ignore + 
return tuple( # pyright: ignore [reportReturnType] + common.euler( + sample, + from_d( + sample, + math.sumprod(derivatives, w), # pyright: ignore [reportArgumentType] + S0, + transform, + ), + S0, + S1, + transform, + ) + for w in weights + ) def step_tableau[T: Sample]( @@ -84,12 +132,28 @@ def step_tableau[T: Sample]( epsilon: float = 1e-8, ) -> tuple[T, ...]: if transform is common.sigma_complement: - return step_tableau_derive(tableau, sample, model, step, schedule, step_size, epsilon) + return step_tableau_derive( + tableau, + sample, + model, + step, + schedule, + transform, + (to_derivative_complement, from_derivative_complement), + step_size, + epsilon, + ) elif transform is common.sigma_polar: - if step == 0: # TODO (beinsezii): can't have this - sample = sample * schedule[step][1] # type: ignore return step_tableau_derive( - tableau, sample, lambda x, t, s: model(x / (s**2 + 1) ** 0.5, t, s), step, schedule, step_size, epsilon + tableau, + sample, + model, + step, + schedule, + transform, + (to_derivative_polar, from_derivative_polar), + step_size, + epsilon, ) nodes, weights = tableau[0], tableau[1:] diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index d05f0f7..742b549 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -214,7 +214,8 @@ def test_heun_scaled() -> None: seed = torch.manual_seed(0) df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) - df_sample = df.add_noise(sk_sample.clone(), df_noise, df.timesteps[0:1]) + # df_sample = df.add_noise(sk_sample.clone(), df_noise, df.timesteps[0:1]) + df_sample = df_noise * df.init_noise_sigma for t in df.timesteps: df_sample: torch.Tensor = df.step( fake_model(df.scale_model_input(df_sample, timestep=t)), From 2a440bf183efbb0d1ca7363dab96404c4caf05f8 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 10 Nov 2025 03:34:58 -0800 Subject: [PATCH 42/59] Deduplicate step_tableau with step_tableau_derive --- skrample/sampling/functional.py | 39 +++++++++------------------------ 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 177636c..ecb1343 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -156,35 +156,16 @@ def step_tableau[T: Sample]( epsilon, ) - nodes, weights = tableau[0], tableau[1:] - k_terms: list[T] = [] - fractions = fractional_step(schedule, step, tuple(f[0] * step_size for f in nodes)) - - for frac_sc, icoeffs in zip(fractions, (t[1] for t in nodes), strict=True): - if icoeffs: - combined: T = common.euler( - sample, - math.sumprod(k_terms, icoeffs) / math.fsum(icoeffs), # type: ignore - schedule[step][1], - frac_sc[1], - transform, - ) - else: - combined = sample - - # Do not call model on timestep = 0 or sigma = 0 - k_terms.append(model(combined, *frac_sc) if not any(abs(v) < epsilon for v in frac_sc) else combined) - - return tuple( - common.euler_step( - sample, - math.sumprod(k_terms, w), # type: ignore - step, - schedule, - transform, - step_size, - ) - for w in weights + return step_tableau_derive( + tableau, + sample, + model, + step, + schedule, + transform, + ((lambda x, p, s, t: p), (lambda x, p, s, t: p)), + step_size, + epsilon, ) From 3139fc2f09759729ba6feb7c23690efe692f1b36 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 10 Nov 2025 03:40:11 -0800 Subject: [PATCH 43/59] Add low step test cases for heun --- tests/diffusers_samplers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff 
--git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 742b549..0deace5 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -198,7 +198,7 @@ def test_heun_scaled() -> None: margin = 1e-8 predictor: Predictor = EPSILON sigma_transform = sigma_polar - for steps in 30, 31, 200, 201: + for steps in 2, 3, 30, 31, 200, 201: df: HeunDiscreteScheduler = HeunDiscreteScheduler.from_config(SCALED_CONFIG) # type: ignore df.set_timesteps(steps) @@ -237,7 +237,7 @@ def test_heun_flow() -> None: margin = 1e-8 predictor: Predictor = FLOW sigma_transform = sigma_complement - for steps in 30, 31, 200, 201: + for steps in 2, 3, 30, 31, 200, 201: df: FlowMatchHeunDiscreteScheduler = FlowMatchHeunDiscreteScheduler.from_config(FLOW_CONFIG) # type: ignore df.set_timesteps(steps) From a92acf085a0830a0aba7d4c088068fa4b4eed309 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 10 Nov 2025 03:44:08 -0800 Subject: [PATCH 44/59] from_d complement is equivalent to predict_flow --- skrample/sampling/functional.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index ecb1343..988bbe2 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -52,10 +52,6 @@ def to_derivative_complement[T: Sample](sample: T, prediction: T, sigma: float, return (sample - prediction) / sigma # pyright: ignore [reportReturnType] -def from_derivative_complement[T: Sample](sample: T, derivative: T, sigma: float, transform: SigmaTransform) -> T: - return sample - derivative * sigma # pyright: ignore [reportReturnType] - - type DerivativeTransform[T: Sample] = Callable[[T, T, float, SigmaTransform], T] @@ -139,7 +135,7 @@ def step_tableau[T: Sample]( step, schedule, transform, - (to_derivative_complement, from_derivative_complement), + (to_derivative_complement, common.predict_flow), step_size, epsilon, ) From 53a230cc24c0d128b2576661c6e33661407ca223 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 10 Nov 2025 03:58:23 -0800 Subject: [PATCH 45/59] from_d polar is equivalent to predict_epsilon, add velocity heun test --- skrample/sampling/functional.py | 7 +--- tests/diffusers_samplers.py | 65 +++++++++++++++++---------------- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 988bbe2..4abb882 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -43,11 +43,6 @@ def to_derivative_polar[T: Sample](sample: T, prediction: T, sigma: float, trans return (sample - (sigma_v * prediction)) / sigma_u # pyright: ignore [reportReturnType] -def from_derivative_polar[T: Sample](sample: T, derivative: T, sigma: float, transform: SigmaTransform) -> T: - sigma_u, sigma_v = transform(sigma) - return (sample - derivative * sigma_u) / sigma_v # pyright: ignore [reportReturnType] - - def to_derivative_complement[T: Sample](sample: T, prediction: T, sigma: float, transform: SigmaTransform) -> T: return (sample - prediction) / sigma # pyright: ignore [reportReturnType] @@ -147,7 +142,7 @@ def step_tableau[T: Sample]( step, schedule, transform, - (to_derivative_polar, from_derivative_polar), + (to_derivative_polar, common.predict_epsilon), step_size, epsilon, ) diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 0deace5..4d322a3 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -196,41 +196,42 @@ def test_unipc() -> None: def 
test_heun_scaled() -> None: margin = 1e-8 - predictor: Predictor = EPSILON sigma_transform = sigma_polar - for steps in 2, 3, 30, 31, 200, 201: - df: HeunDiscreteScheduler = HeunDiscreteScheduler.from_config(SCALED_CONFIG) # type: ignore - - df.set_timesteps(steps) - - fixed: list[tuple[float, float]] = [] - for t in zip(df.timesteps.tolist(), df.sigmas.tolist()): - if t not in fixed: - fixed.append(t) - - sk = RKUltra(FixedSchedule(fixed, sigma_transform), order=2, providers=RKUltra.providers | {2: RK2.Heun}) - - sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) - seed = torch.manual_seed(0) - df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) - # df_sample = df.add_noise(sk_sample.clone(), df_noise, df.timesteps[0:1]) - df_sample = df_noise * df.init_noise_sigma - for t in df.timesteps: - df_sample: torch.Tensor = df.step( - fake_model(df.scale_model_input(df_sample, timestep=t)), - sample=df_sample, - timestep=t, - )[0] - - sk_sample = sk.generate_model( - sk.model_with_predictor(lambda x, t, s: fake_model(x), predictor), - lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), - steps, - initial=sk_sample, - ) + for predictor in [(EPSILON, "epsilon"), (VELOCITY, "v_prediction")]: + for steps in 2, 3, 30, 31, 200, 201: + df: HeunDiscreteScheduler = HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type=predictor[1]) # type: ignore + + df.set_timesteps(steps) + + fixed: list[tuple[float, float]] = [] + for t in zip(df.timesteps.tolist(), df.sigmas.tolist()): + if t not in fixed: + fixed.append(t) + + sk = RKUltra(FixedSchedule(fixed, sigma_transform), order=2, providers=RKUltra.providers | {2: RK2.Heun}) + + sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) + seed = torch.manual_seed(0) + + df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) + # df_sample = df.add_noise(sk_sample.clone(), df_noise, df.timesteps[0:1]) + df_sample = df_noise * df.init_noise_sigma + for t in df.timesteps: + df_sample: torch.Tensor = df.step( + fake_model(df.scale_model_input(df_sample, timestep=t)), + sample=df_sample, + timestep=t, + )[0] + + sk_sample = sk.generate_model( + sk.model_with_predictor(lambda x, t, s: fake_model(x), predictor[0]), + lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), + steps, + initial=sk_sample, + ) - compare_tensors(df_sample, sk_sample, message=f"{steps}", margin=margin) + compare_tensors(df_sample, sk_sample, message=f"{steps}", margin=margin) def test_heun_flow() -> None: From 3e1bedbaf78de2fed442bf3064bc2e05b5a928b5 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 17 Nov 2025 01:14:11 -0800 Subject: [PATCH 46/59] Add unified set of Diffusion model transforms MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit "Once I add velocity this can trivially go into the diffusers wrapper" yeah pretty sure it's actually impossible to write v-prediction as z̃ₛ = z̃ₜ + output · (ηₛ - ηₜ) so I rewrote it again into a custom parameterization space that actually for real works with everything and am wiping out all of the commits trying to make DiffusionFlow work Squashed commit of the following: commit 6b83bf297e87b888a65733e1ed0c95b30f21bbd0 Author: Beinsezii Date: Mon Nov 17 01:11:55 2025 -0800 Rewrite ModelTransform to use custom parameterization commit 49137440fb3a2dd4a30cde7f1ab4dd73d3f55970 Author: Beinsezii Date: Sun Nov 16 16:45:17 2025 -0800 Velocity yes depend on X commit 
4373aa1bf1de80b12080a5e9a4c1b4deb256bdc9 Author: Beinsezii Date: Fri Nov 14 22:56:11 2025 -0800 Velocity don't depend on Epsilon commit 0d0a4c055fd6c027818a674f105293a4acdb25bc Author: Beinsezii Date: Fri Nov 14 22:22:55 2025 -0800 Fix FlowModel to_x/from_x with VariancePreserving schedule commit d6b5971e3d932bb77d8793c4e2e1e14ccd0bb8f3 Author: Beinsezii Date: Fri Nov 14 03:49:50 2025 -0800 Make DiffusionModel the X impl commit 73a73d6f1fb916dfdf6a86727f706210725ed82d Author: Beinsezii Date: Fri Nov 14 03:22:07 2025 -0800 models: Add XModel, VelocityModel, .to_x(), .from_x(), .to_h() commit 26b012a13eeb6d7dde58bfafcb0d69482ad71243 Author: Beinsezii Date: Tue Nov 11 01:38:37 2025 -0800 Add unified set of Diffusion model transforms Currently just for sampling.functional Once I add velocity this can trivially go into the diffusers wrapper instead of a raw prediction function, and then moved through structured Possibly I could associate the schedule and therefore also the sigma transform `type base_schedule = Linear` or whatever. Will have to play around see if FlowModel can be made to handle a Variance-Preserving schedule and vice-versa. Just so fucking glad I don't have to switch on sigma_transform anymore for the tableau solver. --- examples/diffusers/functional.py | 5 +- examples/functional.py | 7 +- examples/predictions.py | 126 ++++++++++++++++++++ scripts/plot_skrample.py | 11 +- skrample/common.py | 10 ++ skrample/sampling/functional.py | 190 +++++++++++++------------------ skrample/sampling/interface.py | 7 +- skrample/sampling/models.py | 174 ++++++++++++++++++++++++++++ tests/diffusers_samplers.py | 29 +++-- tests/miscellaneous.py | 42 ++++++- 10 files changed, 469 insertions(+), 132 deletions(-) create mode 100755 examples/predictions.py create mode 100644 skrample/sampling/models.py diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py index 350736f..f5880d8 100755 --- a/examples/diffusers/functional.py +++ b/examples/diffusers/functional.py @@ -13,7 +13,7 @@ import skrample.scheduling as scheduling from skrample.common import predict_flow from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling import functional, structured +from skrample.sampling import functional, models, structured from skrample.sampling.interface import StructuredFunctionalAdapter model_id = "black-forest-labs/FLUX.1-dev" @@ -68,7 +68,8 @@ def sample_callback(x: torch.Tensor, n: int, t: float, s: float) -> None: block_state["latents"] = sampler.sample_model( sample=block_state["latents"], - model=sampler.model_with_predictor(call_model, wrapper.predictor), + model=call_model, + model_transform=models.FlowModel, steps=block_state["num_inference_steps"], callback=sample_callback, ) diff --git a/examples/functional.py b/examples/functional.py index 3bb7152..f4559b3 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -7,10 +7,9 @@ from tqdm import tqdm from transformers.models.clip import CLIPTextModel, CLIPTokenizer -import skrample.common import skrample.pytorch.noise as noise import skrample.scheduling as scheduling -from skrample.sampling import functional, structured +from skrample.sampling import functional, models, structured from skrample.sampling.interface import StructuredFunctionalAdapter with torch.inference_mode(): @@ -57,8 +56,7 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: t, torch.cat([text_embeds, torch.zeros_like(text_embeds)]), ).sample.chunk(2) - p = conditioned + (cfg - 1) * (conditioned - 
unconditioned) - return skrample.common.predict_epsilon(x, p, s, schedule.sigma_transform) + return conditioned + (cfg - 1) * (conditioned - unconditioned) if isinstance(sampler, functional.FunctionalHigher): steps = sampler.adjust_steps(steps) @@ -67,6 +65,7 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: bar = tqdm(total=steps) sample = sampler.generate_model( model=call_model, + model_transform=models.EpsilonModel, steps=steps, rng=lambda: rng.generate().to(dtype=dtype, device=device), callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), diff --git a/examples/predictions.py b/examples/predictions.py new file mode 100755 index 0000000..c0e2f44 --- /dev/null +++ b/examples/predictions.py @@ -0,0 +1,126 @@ +#! /usr/bin/env python + +import json + +import huggingface_hub as hf +import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from PIL import Image +from safetensors.torch import load_file +from tqdm import tqdm +from transformers.models.clip import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +import skrample.pytorch.noise as noise +import skrample.scheduling as scheduling +from skrample.common import Predictor, predict_epsilon, predict_sample, predict_velocity +from skrample.sampling import functional, models, structured +from skrample.sampling.interface import StructuredFunctionalAdapter + +with torch.inference_mode(): + device: torch.device = torch.device("cuda") + dtype: torch.dtype = torch.float16 + steps: int = 15 + cfg: float = 8 + seed = torch.Generator("cpu").manual_seed(0) + prompts = ["dreamy analog photograph of a kitten in a stained glass church", "blurry, noisy, cropped"] + + schedule = scheduling.Scaled() + + sampler_snr = StructuredFunctionalAdapter(schedule, structured.DPM(order=1)) + sampler_df = functional.RKUltra(schedule, order=1) + + base = "stabilityai/stable-diffusion-xl-base-1.0" + + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer") + tokenizer_2: CLIPTokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2") + text_encoder: CLIPTextModelWithProjection = CLIPTextModelWithProjection.from_pretrained( + base, subfolder="text_encoder", device_map=device, torch_dtype=dtype + ) + text_encoder_2: CLIPTextModel = CLIPTextModel.from_pretrained( + base, subfolder="text_encoder_2", device_map=device, torch_dtype=dtype + ) + image_encoder: AutoencoderKL = AutoencoderKL.from_pretrained( # type: ignore + base, subfolder="vae", device_map=device, torch_dtype=torch.float32 + ) + + text_embeds: torch.Tensor = text_encoder( + tokenizer(prompts, padding="max_length", return_tensors="pt").input_ids.to(device=device), + output_hidden_states=True, + ).hidden_states[-2] + te2_out = text_encoder_2( + tokenizer_2(prompts, padding="max_length", return_tensors="pt").input_ids.to(device=device), + output_hidden_states=True, + ) + text_embeds = torch.cat([text_embeds, te2_out.hidden_states[-2]], dim=-1) + pooled_embeds: torch.Tensor = te2_out.pooler_output + + time_embeds = text_embeds.new([[4096, 4096, 0, 0, 4096, 4096]]).repeat(2, 1) + + configs: tuple[tuple[models.ModelTransform, Predictor, str, str], ...] 
= ( + (models.EpsilonModel, predict_epsilon, base, ""), + (models.VelocityModel, predict_velocity, "terminusresearch/terminus-xl-velocity-v2", ""), + (models.XModel, predict_sample, "ByteDance/SDXL-Lightning", "sdxl_lightning_1step_unet_x0.safetensors"), + ) + + for transform, predictor, url, weights in configs: + model_steps = 1 if transform is models.XModel else steps + model_cfg = 1 if transform is models.XModel else cfg + + if weights: + model: UNet2DConditionModel = UNet2DConditionModel.from_config( # type: ignore + json.load(open(hf.hf_hub_download(base, "config.json", subfolder="unet"))), + device_map=device, + torch_dtype=dtype, + ) + model.load_state_dict(load_file(hf.hf_hub_download(url, weights))) + model = model.to(device=device, dtype=dtype) # pyright: ignore [reportCallIssue] + else: + model: UNet2DConditionModel = UNet2DConditionModel.from_pretrained( # type: ignore + url, subfolder="unet", device_map=device, torch_dtype=dtype + ) + + def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: + conditioned, unconditioned = model( + x.expand([x.shape[0] * 2, *x.shape[1:]]), + t, + text_embeds, + added_cond_kwargs={"text_embeds": pooled_embeds, "time_ids": time_embeds}, + ).sample.chunk(2) + return conditioned + (model_cfg - 1) * (conditioned - unconditioned) + + rng = noise.Random.from_inputs((1, 4, 128, 128), seed.clone_state()) + bar = tqdm(total=model_steps) + sample = sampler_snr.generate_model( + model=call_model, + model_transform=transform, + steps=model_steps, + rng=lambda: rng.generate().to(dtype=dtype, device=device), + callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), + ) + + image: torch.Tensor = image_encoder.decode( + sample.to(dtype=image_encoder.dtype) / image_encoder.config.scaling_factor # type: ignore + ).sample[0] # type: ignore + Image.fromarray( + ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() + ).save(f"{predictor.__name__}.png") + + rng = noise.Random.from_inputs((1, 4, 128, 128), seed.clone_state()) + bar = tqdm(total=sampler_df.adjust_steps(model_steps)) + sample = sampler_df.generate_model( + model=call_model, + model_transform=transform, + steps=sampler_df.adjust_steps(model_steps), + rng=lambda: rng.generate().to(dtype=dtype, device=device), + callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), + ) + + image: torch.Tensor = image_encoder.decode( + sample.to(dtype=image_encoder.dtype) / image_encoder.config.scaling_factor # type: ignore + ).sample[0] # type: ignore + Image.fromarray( + ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() + ).save(f"{transform.__name__}.png") + + model = model.to(device="meta") # pyright: ignore [reportCallIssue] diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index e51c5a9..97c4b2b 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -14,7 +14,7 @@ import skrample.scheduling as scheduling from skrample.common import SigmaTransform, sigma_complement, sigma_polar, spowf -from skrample.sampling import functional, structured +from skrample.sampling import functional, models, structured from skrample.sampling.interface import StructuredFunctionalAdapter OKLAB_XYZ_M1 = np.array( @@ -58,9 +58,9 @@ def colors(hue_steps: int) -> Generator[list[float]]: yield oklch_to_srgb(np.array([lighness_actual, chroma_actual, hue], dtype=np.float64)) -TRANSFORMS: dict[str, tuple[float, SigmaTransform]] = { - "polar": (14.6, sigma_polar), - "complement": (1.0, sigma_complement), 
+TRANSFORMS: dict[str, tuple[float, SigmaTransform, models.ModelTransform]] = { + "polar": (1.0, sigma_polar, models.EpsilonModel), + "complement": (1.0, sigma_complement, models.FlowModel), } SAMPLERS: dict[str, structured.StructuredSampler | functional.FunctionalSampler] = { "euler": structured.Euler(), @@ -195,7 +195,8 @@ def callback(x: float, n: int, t: float, s: float) -> None: sampler.sample_model( sample=sample, - model=lambda x, t, s: x + math.sin(t / schedule.base_timesteps * args.curve) * (s + 1), + model=lambda x, t, s: x - math.sin(t / schedule.base_timesteps * args.curve), + model_transform=TRANSFORMS[args.transform][2], steps=adjusted, rng=random, callback=callback, diff --git a/skrample/common.py b/skrample/common.py index a75eca9..f8984c7 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -143,6 +143,16 @@ def merge_noise[T: Sample](sample: T, noise: T, sigma: float, sigma_transform: S return sample * sigma_v + noise * sigma_u # type: ignore +def divf(lhs: float, rhs: float) -> float: + "Float division with infinity" + if rhs != 0: + return lhs / rhs + elif lhs == 0: + raise ZeroDivisionError + else: + return math.copysign(math.inf, lhs) + + def safe_log(x: float) -> float: "Returns inf rather than throw an err" try: diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 4abb882..824287a 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -11,7 +11,8 @@ from skrample import common, scheduling from skrample.common import RNG, DictOrProxy, FloatSchedule, Predictor, Sample, SigmaTransform -from . import tableaux +from . import models, tableaux +from .models import ModelTransform type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any] "Return is ignored" @@ -38,33 +39,24 @@ def fractional_step( return result -def to_derivative_polar[T: Sample](sample: T, prediction: T, sigma: float, transform: SigmaTransform) -> T: - sigma_u, sigma_v = transform(sigma) - return (sample - (sigma_v * prediction)) / sigma_u # pyright: ignore [reportReturnType] - - -def to_derivative_complement[T: Sample](sample: T, prediction: T, sigma: float, transform: SigmaTransform) -> T: - return (sample - prediction) / sigma # pyright: ignore [reportReturnType] - - -type DerivativeTransform[T: Sample] = Callable[[T, T, float, SigmaTransform], T] - - -def step_tableau_derive[T: Sample]( +def step_tableau[T: Sample]( tableau: tableaux.Tableau | tableaux.ExtendedTableau, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, step: int, schedule: FloatSchedule, - transform: SigmaTransform, - derivative_io: tuple[DerivativeTransform[T], DerivativeTransform[T]], + sigma_transform: SigmaTransform, + derivative_transform: models.ModelTransform | None = models.DiffusionModel, step_size: int = 1, epsilon: float = 1e-8, ) -> tuple[T, ...]: - to_d, from_d = derivative_io - nodes, weights = tableau[0], tableau[1:] + if len(nodes) > 1 and derivative_transform: + model = models.ModelConvert(model_transform, derivative_transform).wrap_model_call(model, sigma_transform) + model_transform = derivative_transform + derivatives: list[T] = [] S0 = schedule[step][1] S1 = schedule[step + step_size][1] if step + step_size < len(schedule) else 0 @@ -74,92 +66,34 @@ def step_tableau_derive[T: Sample]( for frac_sc, icoeffs in zip(fractions, (t[1] for t in nodes), strict=True): sigma_i = frac_sc[1] if icoeffs: - X: T = common.euler( # pyright: ignore [reportAssignmentType] + X: T = model_transform.forward( # pyright: ignore 
[reportAssignmentType] sample, - from_d( - sample, - math.sumprod(derivatives, icoeffs) / math.fsum(icoeffs), # pyright: ignore [reportArgumentType] - S0, - transform, - ), + math.sumprod(derivatives, icoeffs) / math.fsum(icoeffs), # pyright: ignore [reportArgumentType] S0, sigma_i, - transform, + sigma_transform, ) else: X = sample # Do not call model on timestep = 0 or sigma = 0 if any(abs(v) < epsilon for v in frac_sc): - derivatives.append(to_d(sample, X, S0, transform)) + derivatives.append(model_transform.backward(sample, X, S0, S1, sigma_transform)) else: - derivatives.append(to_d(X, model(X, *frac_sc), sigma_i, transform)) + derivatives.append(model(X, *frac_sc)) return tuple( # pyright: ignore [reportReturnType] - common.euler( + model_transform.forward( sample, - from_d( - sample, - math.sumprod(derivatives, w), # pyright: ignore [reportArgumentType] - S0, - transform, - ), + math.sumprod(derivatives, w), # pyright: ignore [reportArgumentType] S0, S1, - transform, + sigma_transform, ) for w in weights ) -def step_tableau[T: Sample]( - tableau: tableaux.Tableau | tableaux.ExtendedTableau, - sample: T, - model: SampleableModel[T], - step: int, - schedule: FloatSchedule, - transform: SigmaTransform, - step_size: int = 1, - epsilon: float = 1e-8, -) -> tuple[T, ...]: - if transform is common.sigma_complement: - return step_tableau_derive( - tableau, - sample, - model, - step, - schedule, - transform, - (to_derivative_complement, common.predict_flow), - step_size, - epsilon, - ) - elif transform is common.sigma_polar: - return step_tableau_derive( - tableau, - sample, - model, - step, - schedule, - transform, - (to_derivative_polar, common.predict_epsilon), - step_size, - epsilon, - ) - - return step_tableau_derive( - tableau, - sample, - model, - step, - schedule, - transform, - ((lambda x, p, s, t: p), (lambda x, p, s, t: p)), - step_size, - epsilon, - ) - - @dataclasses.dataclass(frozen=True) class FunctionalSampler(ABC): schedule: scheduling.SkrampleSchedule @@ -183,6 +117,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -195,6 +130,7 @@ def sample_model[T: Sample]( def generate_model[T: Sample]( self, model: SampleableModel[T], + model_transform: ModelTransform, rng: RNG[T], steps: int, include: slice = slice(None), @@ -215,7 +151,7 @@ def generate_model[T: Sample]( ) / self.merge_noise(0.0, 1.0, steps, 0) # Rescale sample by initial sigma. Mostly just to handle quirks with Scaled - return self.sample_model(sample, model, steps, include, rng, callback) + return self.sample_model(sample, model, model_transform, steps, include, rng, callback) @dataclasses.dataclass(frozen=True) @@ -235,6 +171,12 @@ def adjust_steps(self, steps: int) -> int: return round(steps / self.order) +@dataclasses.dataclass(frozen=True) +class FunctionalDerivative(FunctionalHigher): + derivative_transform: models.ModelTransform | None = models.DiffusionModel + "Transform model to this space when computing higher order samples." 
+ + @dataclasses.dataclass(frozen=True) class FunctionalSinglestep(FunctionalSampler): @abstractmethod @@ -242,6 +184,7 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, @@ -251,6 +194,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -259,7 +203,7 @@ def sample_model[T: Sample]( schedule: FloatSchedule = self.schedule.schedule(steps) for n in list(range(steps))[include]: - sample = self.step(sample, model, n, schedule, rng) + sample = self.step(sample, model, model_transform, n, schedule, rng) if callback: callback(sample, n, *schedule[n] if n < len(schedule) else (0, 0)) @@ -283,7 +227,7 @@ def mse[T: Sample](a: T, b: T) -> float: @dataclasses.dataclass(frozen=True) -class RKUltra(FunctionalHigher, FunctionalSinglestep): +class RKUltra(FunctionalDerivative, FunctionalSinglestep): "Implements almost every single method from https://en.wikipedia.org/wiki/List_of_Runge–Kutta_methods" # noqa: RUF002 order: int = 2 @@ -291,7 +235,7 @@ class RKUltra(FunctionalHigher, FunctionalSinglestep): providers: DictOrProxy[int, tableaux.TableauProvider[tableaux.Tableau | tableaux.ExtendedTableau]] = ( MappingProxyType( { - 2: tableaux.RK2.Ralston, + 2: tableaux.RK2.Heun, 3: tableaux.RK3.Ralston, 4: tableaux.RK4.Ralston, 5: tableaux.RK5.Nystrom, @@ -327,11 +271,21 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, ) -> T: - return step_tableau(self.tableau(), sample, model, step, schedule, self.schedule.sigma_transform)[0] + return step_tableau( + self.tableau(), + sample, + model, + model_transform, + step, + schedule, + self.schedule.sigma_transform, + derivative_transform=self.derivative_transform, + )[0] @dataclasses.dataclass(frozen=True) @@ -355,28 +309,28 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, ) -> T: - dt, scale = common.scaled_delta_step(step, schedule, self.schedule.sigma_transform) - k1 = model(sample, *schedule[step]) - result: T = sample * scale + k1 * dt # type: ignore + sigma_from = schedule[step][1] + sigma_to = schedule[step + 1][1] if step + 1 < len(schedule) else 0 + result: T = model_transform.forward(sample, k1, sigma_from, sigma_to, self.schedule.sigma_transform) # Multiplying by step size here is kind of an asspull, but so is this whole solver so... 
- if ( - step + 1 < len(schedule) - and self.evaluator(sample, result) / max(self.evaluator(0, result), 1e-16) > self.threshold * dt - ): - k2 = model(result, *schedule[step + 1]) - result: T = sample * scale + (k1 + k2) / 2 * dt # type: ignore + if step + 1 < len(schedule) and self.evaluator(sample, result) / max( + self.evaluator(0, result), 1e-16 + ) > self.threshold * abs(sigma_from - sigma_to): + k2: T = (k1 + model(result, *schedule[step + 1])) / 2 # pyright: ignore [reportAssignmentType] + result: T = model_transform.forward(sample, k2, sigma_from, sigma_to, self.schedule.sigma_transform) return result @dataclasses.dataclass(frozen=True) -class RKMoire(FunctionalAdaptive, FunctionalHigher): +class RKMoire(FunctionalAdaptive, FunctionalDerivative): order: int = 2 providers: DictOrProxy[int, tableaux.TableauProvider[tableaux.ExtendedTableau]] = MappingProxyType( @@ -388,7 +342,7 @@ class RKMoire(FunctionalAdaptive, FunctionalHigher): """Providers for a given order, starting from 2. Falls back to RKE2.Heun""" - threshold: float = 1e-3 + threshold: float = 1e-4 initial: float = 1 / 50 "Percent of schedule to take as an initial step." @@ -428,6 +382,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], + model_transform: ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -455,19 +410,28 @@ def sample_model[T: Sample]( if step_next < len(schedule): sample_high, sample_low = step_tableau( - tab, sample, model, step, schedule, self.schedule.sigma_transform, step_size + tab, + sample, + model, + model_transform, + step, + schedule, + self.schedule.sigma_transform, + self.derivative_transform, + step_size, ) - delta = common.scaled_delta_step(step, schedule, self.schedule.sigma_transform, step_size)[0] - delta_next = common.scaled_delta_step(step_next, schedule, self.schedule.sigma_transform, step_size)[0] - - # Normalize against pure error - error = self.evaluator(sample_low, sample_high) / max(self.evaluator(0, sample_high), epsilon) + sigma0 = schedule[step][1] + sigma1 = schedule[step_next][1] + sigma2 = schedule[step_next + step_size][1] if step_next + step_size < len(schedule) else 0 # Offset adjustment by dt2 / dt to account for non-linearity # Basically if we want a 50% larger step but the next dt will already be 25% larger, # we should only set a 20% larger step ie 1.5 / 1.25 - # Really this could be iterated to contrast dt2/dt and thresh/error until they're 100% matched but eh - adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / (delta_next / delta) + slope = abs(sigma0 - sigma1) / abs(sigma1 - sigma2) + + # Normalize against pure error + error = self.evaluator(sample_low, sample_high) / max(self.evaluator(0, sample_high), epsilon) + adjustment: float = (self.threshold / max(error, epsilon)) ** self.adaption / slope step_size = max(round(min(step_size * adjustment, steps * maximum)), 1) # Only discard if it will actually decrease step size @@ -476,7 +440,15 @@ def sample_model[T: Sample]( else: # Save the extra euler call since the 2nd weight isn't used sample_high = step_tableau( - tab[:2], sample, model, step, schedule, self.schedule.sigma_transform, step_size + tab[:2], + sample, + model, + model_transform, + step, + schedule, + self.schedule.sigma_transform, + self.derivative_transform, + step_size, )[0] sample = sample_high diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index cfb39d3..d2f83ce 100644 --- a/skrample/sampling/interface.py +++ 
b/skrample/sampling/interface.py @@ -1,7 +1,8 @@ import dataclasses from skrample.common import RNG, FloatSchedule, Sample -from skrample.sampling import functional, structured + +from . import functional, models, structured @dataclasses.dataclass(frozen=True) @@ -17,6 +18,7 @@ def sample_model[T: Sample]( self, sample: T, model: functional.SampleableModel[T], + model_transform: models.ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -28,7 +30,8 @@ def sample_model[T: Sample]( for n in list(range(len(schedule)))[include]: timestep, sigma = schedule[n] - prediction = model(self.sampler.scale_input(sample, sigma, self.schedule.sigma_transform), timestep, sigma) + output = model(self.sampler.scale_input(sample, sigma, self.schedule.sigma_transform), timestep, sigma) + prediction = model_transform.to_x(sample, output, sigma, self.schedule.sigma_transform) sksamples = self.sampler.sample( sample, diff --git a/skrample/sampling/models.py b/skrample/sampling/models.py new file mode 100644 index 0000000..e3b438c --- /dev/null +++ b/skrample/sampling/models.py @@ -0,0 +1,174 @@ +import dataclasses +import math +from collections.abc import Callable +from functools import wraps + +from skrample.common import Sample, SigmaTransform + +type ModelTransform = type[DiffusionModel] + + +class DiffusionModel(type): + """Common framework for diffusion model sampling. + Intermediate format is X̂ or sample prediction""" + + @classmethod + def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + "output -> X̂" + return output + + @classmethod + def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + "X̂ -> output" + return x + + @classmethod + def gamma(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + "σₜ, σₛ -> Γ" + sigma_t, _alpha_t = sigma_transform(sigma_from) + sigma_s, _alpha_s = sigma_transform(sigma_to) + return sigma_s / sigma_t + + @classmethod + def delta(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + "σₜ, σₛ -> Δ" + sigma_t, alpha_t = sigma_transform(sigma_from) + sigma_s, alpha_s = sigma_transform(sigma_to) + return alpha_s - (alpha_t * sigma_s) / sigma_t + + @classmethod + def forward[T: Sample]( + cls, sample: T, output: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform + ) -> T: + "sample * Γ + output * Δ" + gamma = cls.gamma(sigma_from, sigma_to, sigma_transform) + delta = cls.delta(sigma_from, sigma_to, sigma_transform) + return math.sumprod((sample, output), (gamma, delta)) # pyright: ignore [reportReturnType, reportArgumentType] + + @classmethod + def backward[T: Sample]( + cls, sample: T, result: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform + ) -> T: + "(output - sample * Γ) / Δ" + gamma = cls.gamma(sigma_from, sigma_to, sigma_transform) + delta = cls.delta(sigma_from, sigma_to, sigma_transform) + return (result - sample * gamma) / delta # pyright: ignore [reportReturnType] + + +class XModel(DiffusionModel): + "Equivalent to DiffusionModel, for type checking" + + +class EpsilonModel(DiffusionModel): + "Ε-Prediction" # noqa: RUF002 + + @classmethod + def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return (sample - sigma_t * output) / alpha_t # pyright: ignore [reportReturnType] + + @classmethod + def from_x[T: Sample](cls, sample: T, x: T, sigma: 
float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return (sample - alpha_t * x) / sigma_t # pyright: ignore [reportReturnType] + + @classmethod + def gamma(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + _sigma_t, alpha_t = sigma_transform(sigma_from) + _sigma_s, alpha_s = sigma_transform(sigma_to) + return alpha_s / alpha_t + + @classmethod + def delta(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + sigma_t, alpha_t = sigma_transform(sigma_from) + sigma_s, alpha_s = sigma_transform(sigma_to) + return sigma_s - (alpha_s * sigma_t) / alpha_t + + +class FlowModel(DiffusionModel): + "U-Prediction" + + @classmethod + def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return (sample - sigma_t * output) / (alpha_t + sigma_t) # pyright: ignore [reportReturnType] + + @classmethod + def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return (sample - (alpha_t + sigma_t) * x) / sigma_t # pyright: ignore [reportReturnType] + + @classmethod + def gamma(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + sigma_t, alpha_t = sigma_transform(sigma_from) + sigma_s, alpha_s = sigma_transform(sigma_to) + return (sigma_s + alpha_s) / (sigma_t + alpha_t) + + @classmethod + def delta(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + sigma_t, alpha_t = sigma_transform(sigma_from) + sigma_s, alpha_s = sigma_transform(sigma_to) + return (alpha_t * sigma_s - alpha_s * sigma_t) / (alpha_t + sigma_t) + + +class VelocityModel(DiffusionModel): + "V-Prediction" + + @classmethod + def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return alpha_t * sample - sigma_t * output # pyright: ignore [reportReturnType] + + @classmethod + def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return (alpha_t * sample - x) / sigma_t # pyright: ignore [reportReturnType] + + @classmethod + def gamma(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + sigma_t, alpha_t = sigma_transform(sigma_from) + sigma_s, alpha_s = sigma_transform(sigma_to) + return (sigma_s / sigma_t) * (1 - alpha_t * alpha_t) + alpha_s * alpha_t + + @classmethod + def delta(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + sigma_t, alpha_t = sigma_transform(sigma_from) + sigma_s, alpha_s = sigma_transform(sigma_to) + return alpha_t * sigma_s - alpha_s * sigma_t + + +@dataclasses.dataclass(frozen=True) +class ModelConvert: + transform_from: ModelTransform + transform_to: ModelTransform + + def output_to[T: Sample](self, sample: T, output_from: T, sigma: float, sigma_transform: SigmaTransform) -> T: + if self.transform_to is self.transform_from: + return output_from + else: + return self.transform_to.from_x( + sample, + self.transform_from.to_x(sample, output_from, sigma, sigma_transform), + sigma, + sigma_transform, + ) + + def output_from[T: Sample](self, sample: T, output_to: T, sigma: float, sigma_transform: SigmaTransform) -> T: + if self.transform_from is self.transform_to: + return output_to + else: + return self.transform_from.from_x( + 
sample, + self.transform_to.to_x(sample, output_to, sigma, sigma_transform), + sigma, + sigma_transform, + ) + + def wrap_model_call[T: Sample]( + self, model: Callable[[T, float, float], T], sigma_transform: SigmaTransform + ) -> Callable[[T, float, float], T]: + @wraps(model) + def converted(sample: T, timestep: float, sigma: float) -> T: + return self.output_to(sample, model(sample, timestep, sigma), sigma, sigma_transform) + + return converted diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 4d322a3..5141343 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -17,6 +17,7 @@ from skrample.common import predict_flow as FLOW from skrample.common import predict_velocity as VELOCITY from skrample.sampling.functional import RKUltra +from skrample.sampling.models import EpsilonModel, FlowModel, VelocityModel from skrample.sampling.structured import DPM, Euler, SKSamples, StructuredSampler, UniPC from skrample.sampling.tableaux import RK2 from skrample.scheduling import SkrampleSchedule @@ -198,7 +199,10 @@ def test_heun_scaled() -> None: margin = 1e-8 sigma_transform = sigma_polar - for predictor in [(EPSILON, "epsilon"), (VELOCITY, "v_prediction")]: + for predictor in [ + (EpsilonModel, "epsilon"), + (VelocityModel, "v_prediction"), + ]: for steps in 2, 3, 30, 31, 200, 201: df: HeunDiscreteScheduler = HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type=predictor[1]) # type: ignore @@ -209,7 +213,12 @@ def test_heun_scaled() -> None: if t not in fixed: fixed.append(t) - sk = RKUltra(FixedSchedule(fixed, sigma_transform), order=2, providers=RKUltra.providers | {2: RK2.Heun}) + sk = RKUltra( + FixedSchedule(fixed, sigma_transform), + order=2, + providers=RKUltra.providers | {2: RK2.Heun}, + derivative_transform=EpsilonModel, + ) sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) seed = torch.manual_seed(0) @@ -225,18 +234,18 @@ def test_heun_scaled() -> None: )[0] sk_sample = sk.generate_model( - sk.model_with_predictor(lambda x, t, s: fake_model(x), predictor[0]), + lambda x, t, s: fake_model(x), + predictor[0], lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), steps, initial=sk_sample, ) - compare_tensors(df_sample, sk_sample, message=f"{steps}", margin=margin) + compare_tensors(df_sample, sk_sample, message=f"{steps} {predictor[1]}", margin=margin) def test_heun_flow() -> None: margin = 1e-8 - predictor: Predictor = FLOW sigma_transform = sigma_complement for steps in 2, 3, 30, 31, 200, 201: df: FlowMatchHeunDiscreteScheduler = FlowMatchHeunDiscreteScheduler.from_config(FLOW_CONFIG) # type: ignore @@ -248,7 +257,12 @@ def test_heun_flow() -> None: if t not in fixed: fixed.append(t) - sk = RKUltra(FixedSchedule(fixed, sigma_transform), order=2, providers=RKUltra.providers | {2: RK2.Heun}) + sk = RKUltra( + FixedSchedule(fixed, sigma_transform), + order=2, + providers=RKUltra.providers | {2: RK2.Heun}, + derivative_transform=None, + ) sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) seed = torch.manual_seed(0) @@ -259,7 +273,8 @@ def test_heun_flow() -> None: df_sample: torch.Tensor = df.step(fake_model(df_sample), sample=df_sample, timestep=t)[0] # type: ignore sk_sample = sk.generate_model( - sk.model_with_predictor(lambda x, t, s: fake_model(x), predictor), + lambda x, t, s: fake_model(x), + FlowModel, lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), steps, initial=sk_sample, diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 
6b02360..0dda5e5 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -1,15 +1,29 @@ +import itertools import math import random from dataclasses import replace import numpy as np +import pytest import torch from testing_common import compare_tensors -from skrample.common import MergeStrategy, bashforth, sigma_complement, sigmoid, softmax, spowf +from skrample.common import ( + MergeStrategy, + SigmaTransform, + bashforth, + euler, + predict_flow, + sigma_complement, + sigma_polar, + sigmoid, + softmax, + spowf, +) from skrample.diffusers import SkrampleWrapperScheduler from skrample.sampling import tableaux from skrample.sampling.interface import StructuredFunctionalAdapter +from skrample.sampling.models import EpsilonModel, FlowModel, ModelTransform, VelocityModel, XModel from skrample.sampling.structured import ( DPM, SPC, @@ -51,6 +65,28 @@ def test_sigmas_to_timesteps() -> None: compare_tensors(torch.tensor(timesteps), torch.tensor(timesteps_inv), margin=0) # shocked this rounds good +@pytest.mark.parametrize( + ("model_transform", "sigma_transform"), + itertools.product([EpsilonModel, FlowModel, VelocityModel, XModel], [sigma_complement, sigma_polar]), +) +def test_model_transforms(model_transform: ModelTransform, sigma_transform: SigmaTransform) -> None: + sample = 0.8 + output = 0.3 + sigma = 0.2 + + x = model_transform.to_x(sample, output, sigma, sigma_transform) + o = model_transform.from_x(sample, x, sigma, sigma_transform) + assert abs(output - o) < 1e-12 + + sigma_next = 0.05 + for sigma_next in 0.05, 0: # extra 0 to validate X̂ + snr = euler( + sample, model_transform.to_x(sample, output, sigma, sigma_transform), sigma, sigma_next, sigma_transform + ) + df = model_transform.forward(sample, output, sigma, sigma_next, sigma_transform) + assert abs(snr - df) < 1e-12 + + def test_sampler_generics() -> None: eps = 1e-12 for sampler in [ @@ -192,7 +228,7 @@ def fake_model(x: float, _: float, s: float) -> float: noise = [random.random() for _ in range(steps)] rng = iter(noise) - sample_f = adapter.sample_model(sample, fake_model, steps, rng=lambda: next(rng)) + sample_f = adapter.sample_model(sample, fake_model, FlowModel, steps, rng=lambda: next(rng)) rng = iter(noise) float_schedule = schedule.schedule(steps) @@ -201,7 +237,7 @@ def fake_model(x: float, _: float, s: float) -> float: for n, (t, s) in enumerate(float_schedule): results = sampler.sample( sample_s, - fake_model(sample_s, t, s), + predict_flow(sample_s, fake_model(sample_s, t, s), s, schedule.sigma_transform), n, float_schedule, schedule.sigma_transform, From d2ababa3e0c911018d793c8d6b08e851b06a8c50 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 22 Nov 2025 18:39:56 -0800 Subject: [PATCH 47/59] Change ModelTransfrom from Type -> Dataclass, add ScaleX transform MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Still not 100% certain on how to best parameterize the derivative transform. It needs to be per-step for plans on making RKUltra its own diffusers wrapper, so ig adding yet another parameter™ is hot. 
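For reference, a minimal sketch of what the Type -> Dataclass change looks like at a call site (names as in skrample.sampling.models; the surrounding sampler call is elided):

    from skrample.sampling import models

    # Previously the transforms were bare classes used via classmethods, so call
    # sites passed the type itself, e.g. model_transform=models.EpsilonModel.
    # After this patch they are frozen dataclasses, so call sites pass an instance,
    # and a parameterized transform such as the new ScaleX carries its own state:
    transform: models.ModelTransform = models.EpsilonModel()
    biased = models.ScaleX(bias=3)  # biased X-prediction sampling, per its docstring below
    # e.g. sampler.generate_model(model=call_model, model_transform=transform, ...)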
--- examples/diffusers/functional.py | 2 +- examples/functional.py | 2 +- examples/predictions.py | 17 ++-- scripts/plot_skrample.py | 4 +- skrample/sampling/functional.py | 21 ++-- skrample/sampling/models.py | 166 ++++++++++++++++++++----------- tests/diffusers_samplers.py | 162 +++++++++++++++--------------- tests/miscellaneous.py | 40 +++++--- 8 files changed, 238 insertions(+), 176 deletions(-) diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py index f5880d8..7b5276c 100755 --- a/examples/diffusers/functional.py +++ b/examples/diffusers/functional.py @@ -69,7 +69,7 @@ def sample_callback(x: torch.Tensor, n: int, t: float, s: float) -> None: block_state["latents"] = sampler.sample_model( sample=block_state["latents"], model=call_model, - model_transform=models.FlowModel, + model_transform=models.FlowModel(), steps=block_state["num_inference_steps"], callback=sample_callback, ) diff --git a/examples/functional.py b/examples/functional.py index f4559b3..5a67fb9 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -65,7 +65,7 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: bar = tqdm(total=steps) sample = sampler.generate_model( model=call_model, - model_transform=models.EpsilonModel, + model_transform=models.EpsilonModel(), steps=steps, rng=lambda: rng.generate().to(dtype=dtype, device=device), callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), diff --git a/examples/predictions.py b/examples/predictions.py index c0e2f44..ad12c9b 100755 --- a/examples/predictions.py +++ b/examples/predictions.py @@ -58,14 +58,19 @@ time_embeds = text_embeds.new([[4096, 4096, 0, 0, 4096, 4096]]).repeat(2, 1) configs: tuple[tuple[models.ModelTransform, Predictor, str, str], ...] = ( - (models.EpsilonModel, predict_epsilon, base, ""), - (models.VelocityModel, predict_velocity, "terminusresearch/terminus-xl-velocity-v2", ""), - (models.XModel, predict_sample, "ByteDance/SDXL-Lightning", "sdxl_lightning_1step_unet_x0.safetensors"), + (models.EpsilonModel(), predict_epsilon, base, ""), + (models.VelocityModel(), predict_velocity, "terminusresearch/terminus-xl-velocity-v2", ""), + ( + models.DiffusionModel(), + predict_sample, + "ByteDance/SDXL-Lightning", + "sdxl_lightning_1step_unet_x0.safetensors", + ), ) for transform, predictor, url, weights in configs: - model_steps = 1 if transform is models.XModel else steps - model_cfg = 1 if transform is models.XModel else cfg + model_steps = 1 if isinstance(transform, models.DiffusionModel) else steps + model_cfg = 1 if isinstance(transform, models.DiffusionModel) else cfg if weights: model: UNet2DConditionModel = UNet2DConditionModel.from_config( # type: ignore @@ -121,6 +126,6 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: ).sample[0] # type: ignore Image.fromarray( ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() - ).save(f"{transform.__name__}.png") + ).save(f"{type(transform).__name__}.png") model = model.to(device="meta") # pyright: ignore [reportCallIssue] diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 97c4b2b..0a73248 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -59,8 +59,8 @@ def colors(hue_steps: int) -> Generator[list[float]]: TRANSFORMS: dict[str, tuple[float, SigmaTransform, models.ModelTransform]] = { - "polar": (1.0, sigma_polar, models.EpsilonModel), - "complement": (1.0, sigma_complement, models.FlowModel), + "polar": (1.0, sigma_polar, 
models.EpsilonModel()), + "complement": (1.0, sigma_complement, models.FlowModel()), } SAMPLERS: dict[str, structured.StructuredSampler | functional.FunctionalSampler] = { "euler": structured.Euler(), diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 824287a..89fa6a2 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -12,7 +12,6 @@ from skrample.common import RNG, DictOrProxy, FloatSchedule, Predictor, Sample, SigmaTransform from . import models, tableaux -from .models import ModelTransform type SampleCallback[T: Sample] = Callable[[T, int, float, float], Any] "Return is ignored" @@ -43,11 +42,11 @@ def step_tableau[T: Sample]( tableau: tableaux.Tableau | tableaux.ExtendedTableau, sample: T, model: SampleableModel[T], - model_transform: ModelTransform, + model_transform: models.ModelTransform, step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform, - derivative_transform: models.ModelTransform | None = models.DiffusionModel, + derivative_transform: models.ModelTransform | None = None, step_size: int = 1, epsilon: float = 1e-8, ) -> tuple[T, ...]: @@ -117,7 +116,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: ModelTransform, + model_transform: models.ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -130,7 +129,7 @@ def sample_model[T: Sample]( def generate_model[T: Sample]( self, model: SampleableModel[T], - model_transform: ModelTransform, + model_transform: models.ModelTransform, rng: RNG[T], steps: int, include: slice = slice(None), @@ -173,7 +172,7 @@ def adjust_steps(self, steps: int) -> int: @dataclasses.dataclass(frozen=True) class FunctionalDerivative(FunctionalHigher): - derivative_transform: models.ModelTransform | None = models.DiffusionModel + derivative_transform: models.ModelTransform | None = models.DiffusionModel() # noqa: RUF009 # is immutable "Transform model to this space when computing higher order samples." 
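Concretely, the derivative transform only changes the space in which the tableau stages are accumulated: the model call is wrapped so its output is converted from its native prediction space into the derivative space before the weights are applied. A minimal sketch of that wrapping, using the same ModelConvert helper step_tableau relies on (call_model here is a stand-in for a real network):

    from skrample.common import sigma_polar
    from skrample.sampling import models

    def call_model(x: float, t: float, s: float) -> float:
        return 0.0  # stand-in for a real v-prediction network

    # Accumulate every tableau stage in epsilon space; the default
    # derivative_transform above is DiffusionModel(), i.e. X-hat space.
    convert = models.ModelConvert(models.VelocityModel(), models.EpsilonModel())
    wrapped = convert.wrap_model_call(call_model, sigma_polar)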
@@ -184,7 +183,7 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: ModelTransform, + model_transform: models.ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, @@ -194,7 +193,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: ModelTransform, + model_transform: models.ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -271,7 +270,7 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: ModelTransform, + model_transform: models.ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, @@ -309,7 +308,7 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: ModelTransform, + model_transform: models.ModelTransform, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, @@ -382,7 +381,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: ModelTransform, + model_transform: models.ModelTransform, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, diff --git a/skrample/sampling/models.py b/skrample/sampling/models.py index e3b438c..98ed882 100644 --- a/skrample/sampling/models.py +++ b/skrample/sampling/models.py @@ -1,3 +1,4 @@ +import abc import dataclasses import math from collections.abc import Callable @@ -5,138 +6,185 @@ from skrample.common import Sample, SigmaTransform -type ModelTransform = type[DiffusionModel] +@dataclasses.dataclass(frozen=True) +class ModelTransform(abc.ABC): + """Common framework for diffusion model sampling.""" -class DiffusionModel(type): - """Common framework for diffusion model sampling. - Intermediate format is X̂ or sample prediction""" - - @classmethod - def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + @abc.abstractmethod + def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: "output -> X̂" - return output - @classmethod - def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + @abc.abstractmethod + def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: "X̂ -> output" - return x - @classmethod - def gamma(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + @abc.abstractmethod + def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: "σₜ, σₛ -> Γ" - sigma_t, _alpha_t = sigma_transform(sigma_from) - sigma_s, _alpha_s = sigma_transform(sigma_to) - return sigma_s / sigma_t - @classmethod - def delta(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + @abc.abstractmethod + def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: "σₜ, σₛ -> Δ" - sigma_t, alpha_t = sigma_transform(sigma_from) - sigma_s, alpha_s = sigma_transform(sigma_to) - return alpha_s - (alpha_t * sigma_s) / sigma_t - @classmethod def forward[T: Sample]( - cls, sample: T, output: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform + self, sample: T, output: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform ) -> T: "sample * Γ + output * Δ" - gamma = cls.gamma(sigma_from, sigma_to, sigma_transform) - delta = cls.delta(sigma_from, sigma_to, sigma_transform) + gamma = self.gamma(sigma_from, sigma_to, sigma_transform) + delta = 
self.delta(sigma_from, sigma_to, sigma_transform) return math.sumprod((sample, output), (gamma, delta)) # pyright: ignore [reportReturnType, reportArgumentType] - @classmethod def backward[T: Sample]( - cls, sample: T, result: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform + self, sample: T, result: T, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform ) -> T: "(output - sample * Γ) / Δ" - gamma = cls.gamma(sigma_from, sigma_to, sigma_transform) - delta = cls.delta(sigma_from, sigma_to, sigma_transform) + gamma = self.gamma(sigma_from, sigma_to, sigma_transform) + delta = self.delta(sigma_from, sigma_to, sigma_transform) return (result - sample * gamma) / delta # pyright: ignore [reportReturnType] -class XModel(DiffusionModel): - "Equivalent to DiffusionModel, for type checking" +@dataclasses.dataclass(frozen=True) +class DiffusionModel(ModelTransform): + """X-Prediction + Predicts the clean image""" + def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + "output -> X̂" + return output -class EpsilonModel(DiffusionModel): - "Ε-Prediction" # noqa: RUF002 + def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + "X̂ -> output" + return x + + def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + "σₜ, σₛ -> Γ" + sigma_t, _alpha_t = sigma_transform(sigma_from) + sigma_s, _alpha_s = sigma_transform(sigma_to) + return sigma_s / sigma_t + + def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + "σₜ, σₛ -> Δ" + sigma_t, alpha_t = sigma_transform(sigma_from) + sigma_s, alpha_s = sigma_transform(sigma_to) + return alpha_s - (alpha_t * sigma_s) / sigma_t - @classmethod - def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + +@dataclasses.dataclass(frozen=True) +class EpsilonModel(ModelTransform): + """Ε-Prediction + Predicts the added noise""" # noqa: RUF002 + + def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) return (sample - sigma_t * output) / alpha_t # pyright: ignore [reportReturnType] - @classmethod - def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) return (sample - alpha_t * x) / sigma_t # pyright: ignore [reportReturnType] - @classmethod - def gamma(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: _sigma_t, alpha_t = sigma_transform(sigma_from) _sigma_s, alpha_s = sigma_transform(sigma_to) return alpha_s / alpha_t - @classmethod - def delta(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: sigma_t, alpha_t = sigma_transform(sigma_from) sigma_s, alpha_s = sigma_transform(sigma_to) return sigma_s - (alpha_s * sigma_t) / alpha_t -class FlowModel(DiffusionModel): +@dataclasses.dataclass(frozen=True) +class FlowModel(ModelTransform): "U-Prediction" - @classmethod - def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + def to_x[T: Sample](self, sample: 
T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) return (sample - sigma_t * output) / (alpha_t + sigma_t) # pyright: ignore [reportReturnType] - @classmethod - def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) return (sample - (alpha_t + sigma_t) * x) / sigma_t # pyright: ignore [reportReturnType] - @classmethod - def gamma(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: sigma_t, alpha_t = sigma_transform(sigma_from) sigma_s, alpha_s = sigma_transform(sigma_to) return (sigma_s + alpha_s) / (sigma_t + alpha_t) - @classmethod - def delta(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: sigma_t, alpha_t = sigma_transform(sigma_from) sigma_s, alpha_s = sigma_transform(sigma_to) return (alpha_t * sigma_s - alpha_s * sigma_t) / (alpha_t + sigma_t) -class VelocityModel(DiffusionModel): +@dataclasses.dataclass(frozen=True) +class VelocityModel(ModelTransform): "V-Prediction" - @classmethod - def to_x[T: Sample](cls, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) return alpha_t * sample - sigma_t * output # pyright: ignore [reportReturnType] - @classmethod - def from_x[T: Sample](cls, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) return (alpha_t * sample - x) / sigma_t # pyright: ignore [reportReturnType] - @classmethod - def gamma(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: sigma_t, alpha_t = sigma_transform(sigma_from) sigma_s, alpha_s = sigma_transform(sigma_to) return (sigma_s / sigma_t) * (1 - alpha_t * alpha_t) + alpha_s * alpha_t - @classmethod - def delta(cls, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: sigma_t, alpha_t = sigma_transform(sigma_from) sigma_s, alpha_s = sigma_transform(sigma_to) return alpha_t * sigma_s - alpha_s * sigma_t +@dataclasses.dataclass(frozen=True) +class FakeModel(ModelTransform): + "Marker for transforms that are only used for alternative sampling of other models." + + +@dataclasses.dataclass(frozen=True) +class ScaleX(FakeModel): + "X / Sample prediction with sampling bias" + + bias: float = 3 + """Bias for sample prediction. 
+ Higher values create a stronger image.""" + + def x_scale(self, sigma_t: float, alpha_t: float) -> float: + # Remap -∞ → 0 → ∞ » 0 → 1 → log(∞) + if self.bias < 0: + # -∞ → 0⁻ » 0⁺ → 1⁻ + factor = 1 / math.log(math.e - self.bias) + else: + # 0 → ∞ » 1 → log(∞) + factor = math.log(math.e + self.bias) + + # Rescale sigma_t to average bias scale on VP and NV schedules + sigma_mean = sigma_t / (sigma_t + alpha_t) + return factor**sigma_mean + + def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return output * self.x_scale(sigma_t, alpha_t) # pyright: ignore [reportReturnType] + + def from_x[T: Sample](self, sample: T, x: T, sigma: float, sigma_transform: SigmaTransform) -> T: + sigma_t, alpha_t = sigma_transform(sigma) + return x / self.x_scale(sigma_t, alpha_t) # pyright: ignore [reportReturnType] + + def gamma(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + sigma_t, _alpha_t = sigma_transform(sigma_from) + sigma_s, _alpha_s = sigma_transform(sigma_to) + return sigma_s / sigma_t + + def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransform) -> float: + sigma_t, alpha_t = sigma_transform(sigma_from) + sigma_s, alpha_s = sigma_transform(sigma_to) + return (alpha_s - alpha_t * sigma_s / sigma_t) * self.x_scale(sigma_t, alpha_t) + + @dataclasses.dataclass(frozen=True) class ModelConvert: transform_from: ModelTransform diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 5141343..ecfe9b5 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -1,7 +1,9 @@ import dataclasses +import itertools from inspect import signature import numpy as np +import pytest import torch from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler @@ -17,7 +19,7 @@ from skrample.common import predict_flow as FLOW from skrample.common import predict_velocity as VELOCITY from skrample.sampling.functional import RKUltra -from skrample.sampling.models import EpsilonModel, FlowModel, VelocityModel +from skrample.sampling.models import EpsilonModel, FlowModel, ModelTransform, VelocityModel from skrample.sampling.structured import DPM, Euler, SKSamples, StructuredSampler, UniPC from skrample.sampling.tableaux import RK2 from skrample.scheduling import SkrampleSchedule @@ -195,89 +197,83 @@ def test_unipc() -> None: ) -def test_heun_scaled() -> None: - margin = 1e-8 - sigma_transform = sigma_polar - - for predictor in [ - (EpsilonModel, "epsilon"), - (VelocityModel, "v_prediction"), - ]: - for steps in 2, 3, 30, 31, 200, 201: - df: HeunDiscreteScheduler = HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type=predictor[1]) # type: ignore - - df.set_timesteps(steps) - - fixed: list[tuple[float, float]] = [] - for t in zip(df.timesteps.tolist(), df.sigmas.tolist()): - if t not in fixed: - fixed.append(t) - - sk = RKUltra( - FixedSchedule(fixed, sigma_transform), - order=2, - providers=RKUltra.providers | {2: RK2.Heun}, - derivative_transform=EpsilonModel, - ) - - sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) - seed = torch.manual_seed(0) - - df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) - # df_sample = df.add_noise(sk_sample.clone(), df_noise, df.timesteps[0:1]) - df_sample = df_noise * df.init_noise_sigma - for t in 
df.timesteps: - df_sample: torch.Tensor = df.step( - fake_model(df.scale_model_input(df_sample, timestep=t)), - sample=df_sample, - timestep=t, - )[0] - - sk_sample = sk.generate_model( - lambda x, t, s: fake_model(x), - predictor[0], - lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), - steps, - initial=sk_sample, - ) - - compare_tensors(df_sample, sk_sample, message=f"{steps} {predictor[1]}", margin=margin) - - -def test_heun_flow() -> None: - margin = 1e-8 - sigma_transform = sigma_complement - for steps in 2, 3, 30, 31, 200, 201: - df: FlowMatchHeunDiscreteScheduler = FlowMatchHeunDiscreteScheduler.from_config(FLOW_CONFIG) # type: ignore - - df.set_timesteps(steps) - - fixed: list[tuple[float, float]] = [] - for t in zip(df.timesteps.tolist(), df.sigmas.tolist()): - if t not in fixed: - fixed.append(t) - - sk = RKUltra( - FixedSchedule(fixed, sigma_transform), - order=2, - providers=RKUltra.providers | {2: RK2.Heun}, - derivative_transform=None, +@pytest.mark.parametrize( + ("model_transform", "derivative_transform", "sigma_transform", "diffusers_scheduler", "steps"), + ( + (mt, dt, st, ds, s) + for (mt, dt, st, ds), s in itertools.product( + ( + ( + EpsilonModel(), + EpsilonModel(), + sigma_polar, + HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type="epsilon"), + ), + ( + VelocityModel(), + EpsilonModel(), + sigma_polar, + HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type="v_prediction"), + ), + ( + FlowModel(), + FlowModel(), + sigma_complement, + FlowMatchHeunDiscreteScheduler.from_config(FLOW_CONFIG), + ), + ), + (2, 3, 30, 31, 200, 201), ) + ), +) +def test_heun( + model_transform: ModelTransform, + derivative_transform: ModelTransform, + sigma_transform: SigmaTransform, + diffusers_scheduler: HeunDiscreteScheduler | FlowMatchHeunDiscreteScheduler, + steps: int, +) -> None: + diffusers_scheduler.set_timesteps(steps) + + fixed: list[tuple[float, float]] = [] + for t in zip(diffusers_scheduler.timesteps.tolist(), diffusers_scheduler.sigmas.tolist()): + if t not in fixed: + fixed.append(t) + + skrample_sampler = RKUltra( + FixedSchedule(fixed, sigma_transform), + order=2, + providers=RKUltra.providers | {2: RK2.Heun}, + derivative_transform=derivative_transform, + ) - sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) - seed = torch.manual_seed(0) - - df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) - df_sample = df.scale_noise(sk_sample.clone(), df.timesteps[0], df_noise) # type: ignore - for t in df.timesteps: - df_sample: torch.Tensor = df.step(fake_model(df_sample), sample=df_sample, timestep=t)[0] # type: ignore + sk_sample = torch.zeros([1, 4, 128, 128], dtype=torch.float32) + seed = torch.manual_seed(0) - sk_sample = sk.generate_model( - lambda x, t, s: fake_model(x), - FlowModel, - lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), - steps, - initial=sk_sample, - ) + df_noise = torch.randn(sk_sample.shape, generator=seed.clone_state(), dtype=sk_sample.dtype) + # df_sample = df.add_noise(sk_sample.clone(), df_noise, df.timesteps[0:1]) + + df_sample = df_noise.clone() + if isinstance(diffusers_scheduler, HeunDiscreteScheduler): + df_sample *= diffusers_scheduler.init_noise_sigma + + for t in diffusers_scheduler.timesteps: + model_input: torch.Tensor = df_sample + if isinstance(diffusers_scheduler, HeunDiscreteScheduler): + model_input = diffusers_scheduler.scale_model_input(df_sample, timestep=t) + + df_sample: torch.Tensor = 
diffusers_scheduler.step( + fake_model(model_input), # pyright: ignore [reportArgumentType] + sample=df_sample, # pyright: ignore [reportArgumentType] + timestep=t, # pyright: ignore [reportArgumentType] + )[0] + + sk_sample = skrample_sampler.generate_model( + lambda x, t, s: fake_model(x), + model_transform, + lambda: torch.randn(sk_sample.shape, generator=seed, dtype=sk_sample.dtype), + steps, + initial=sk_sample, + ) - compare_tensors(df_sample, sk_sample, message=f"{steps}", margin=margin) + compare_tensors(df_sample, sk_sample, margin=1e-8) diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 0dda5e5..f1508ee 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -1,6 +1,7 @@ import itertools import math import random +from collections.abc import Sequence from dataclasses import replace import numpy as np @@ -23,7 +24,7 @@ from skrample.diffusers import SkrampleWrapperScheduler from skrample.sampling import tableaux from skrample.sampling.interface import StructuredFunctionalAdapter -from skrample.sampling.models import EpsilonModel, FlowModel, ModelTransform, VelocityModel, XModel +from skrample.sampling.models import DiffusionModel, EpsilonModel, FlowModel, ModelTransform, VelocityModel from skrample.sampling.structured import ( DPM, SPC, @@ -35,9 +36,9 @@ StructuredStochastic, UniPC, ) -from skrample.scheduling import Beta, FlowShift, Karras, Linear, Scaled, SigmoidCDF +from skrample.scheduling import Beta, FlowShift, Karras, Linear, Scaled, ScheduleCommon, ScheduleModifier, SigmoidCDF -ALL_SAMPLERS = [ +ALL_STRUCTURED: Sequence[type[StructuredSampler]] = [ Adams, DPM, Euler, @@ -45,18 +46,30 @@ UniPC, ] -ALL_SCHEDULES = [ +ALL_SCHEDULES: Sequence[type[ScheduleCommon]] = [ Linear, Scaled, SigmoidCDF, ] -ALL_MODIFIERS = [ +ALL_MODIFIERS: Sequence[type[ScheduleModifier]] = [ Beta, FlowShift, Karras, ] +ALL_MODELS: Sequence[type[ModelTransform]] = [ + DiffusionModel, + EpsilonModel, + FlowModel, + VelocityModel, +] + +ALL_TRANSFROMS: Sequence[SigmaTransform] = [ + sigma_complement, + sigma_polar, +] + def test_sigmas_to_timesteps() -> None: for schedule in [*(cls() for cls in ALL_SCHEDULES), Scaled(beta_scale=1)]: # base schedules @@ -66,10 +79,11 @@ def test_sigmas_to_timesteps() -> None: @pytest.mark.parametrize( - ("model_transform", "sigma_transform"), - itertools.product([EpsilonModel, FlowModel, VelocityModel, XModel], [sigma_complement, sigma_polar]), + ("model_type", "sigma_transform"), + itertools.product(ALL_MODELS, ALL_TRANSFROMS), ) -def test_model_transforms(model_transform: ModelTransform, sigma_transform: SigmaTransform) -> None: +def test_model_transforms(model_type: type[ModelTransform], sigma_transform: SigmaTransform) -> None: + model_transform = model_type() sample = 0.8 output = 0.3 sigma = 0.2 @@ -90,8 +104,8 @@ def test_model_transforms(model_transform: ModelTransform, sigma_transform: Sigm def test_sampler_generics() -> None: eps = 1e-12 for sampler in [ - *(cls() for cls in ALL_SAMPLERS), - *(cls(order=cls.max_order()) for cls in ALL_SAMPLERS if issubclass(cls, StructuredMultistep)), + *(cls() for cls in ALL_STRUCTURED), + *(cls(order=cls.max_order()) for cls in ALL_STRUCTURED if issubclass(cls, StructuredMultistep)), ]: for schedule in Scaled(), FlowShift(Linear()): i, o = random.random(), random.random() @@ -133,7 +147,7 @@ def test_mu_set() -> None: def test_require_previous() -> None: samplers: list[StructuredSampler] = [] - for cls in ALL_SAMPLERS: + for cls in ALL_STRUCTURED: if issubclass(cls, StructuredMultistep): 
samplers.extend([cls(order=o + 1) for o in range(cls.min_order(), cls.max_order())]) else: @@ -173,7 +187,7 @@ def test_require_previous() -> None: def test_require_noise() -> None: samplers: list[StructuredSampler] = [] - for cls in ALL_SAMPLERS: + for cls in ALL_STRUCTURED: if issubclass(cls, StructuredStochastic): samplers.extend([cls(add_noise=n) for n in (False, True)]) else: @@ -228,7 +242,7 @@ def fake_model(x: float, _: float, s: float) -> float: noise = [random.random() for _ in range(steps)] rng = iter(noise) - sample_f = adapter.sample_model(sample, fake_model, FlowModel, steps, rng=lambda: next(rng)) + sample_f = adapter.sample_model(sample, fake_model, FlowModel(), steps, rng=lambda: next(rng)) rng = iter(noise) float_schedule = schedule.schedule(steps) From fe0a0baa335ee595371d0839f87c995ede70f289 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 22 Nov 2025 18:46:37 -0800 Subject: [PATCH 48/59] Replace common.safe_log() -> common.ln() Since we have common.divf() now this is more appropriate --- skrample/common.py | 12 +++++++----- skrample/sampling/structured.py | 16 ++++++++-------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/skrample/common.py b/skrample/common.py index f8984c7..4c0e069 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -153,12 +153,14 @@ def divf(lhs: float, rhs: float) -> float: return math.copysign(math.inf, lhs) -def safe_log(x: float) -> float: - "Returns inf rather than throw an err" - try: +def ln(x: float) -> float: + "Natural logarithm with infinity" + if x > 0: return math.log(x) - except ValueError: - return math.inf + elif x < 0: + raise ValueError + else: + return -math.inf def normalize(regular_array: NDArray[np.float64], start: float, end: float = 0) -> NDArray[np.float64]: diff --git a/skrample/sampling/structured.py b/skrample/sampling/structured.py index 3314315..6fc60db 100644 --- a/skrample/sampling/structured.py +++ b/skrample/sampling/structured.py @@ -5,7 +5,7 @@ import numpy as np from skrample import common -from skrample.common import FloatSchedule, Sample, SigmaTransform, merge_noise, safe_log, softmax, spowf +from skrample.common import FloatSchedule, Sample, SigmaTransform, divf, ln, merge_noise, softmax, spowf @dataclass(frozen=True) @@ -183,8 +183,8 @@ def sample[T: Sample]( sigma_u, sigma_v = common.get_sigma_uv(step, schedule, sigma_transform) sigma_u_next, sigma_v_next = common.get_sigma_uv(step + 1, schedule, sigma_transform) - lambda_ = safe_log(sigma_v) - safe_log(sigma_u) - lambda_next = safe_log(sigma_v_next) - safe_log(sigma_u_next) + lambda_ = ln(divf(sigma_v, sigma_u)) + lambda_next = ln(divf(sigma_v_next, sigma_u_next)) h = abs(lambda_next - lambda_) if noise is not None and self.add_noise: @@ -206,7 +206,7 @@ def sample[T: Sample]( if (effective_order := self.effective_order(step, schedule, previous)) >= 2: sigma_u_prev, sigma_v_prev = common.get_sigma_uv(step - 1, schedule, sigma_transform) - lambda_prev = safe_log(sigma_v_prev) - safe_log(sigma_u_prev) + lambda_prev = ln(divf(sigma_v_prev, sigma_u_prev)) h_prev = lambda_ - lambda_prev r = h_prev / h # math people and their var names... 
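When both sigmas are positive the two spellings agree up to float rounding, since log(v / u) = log(v) - log(u); at the endpoints the behaviour now follows the sign of the ratio rather than safe_log's +inf fallback. A quick sanity check, assuming divf and ln are importable from skrample.common as defined in the hunks above:

    import math

    from skrample.common import divf, ln

    sigma_u, sigma_v = 0.25, 0.75
    assert math.isclose(ln(divf(sigma_v, sigma_u)), math.log(sigma_v) - math.log(sigma_u))

    # At the sigma_u == 0 endpoint the ratio form yields +inf, matching log(v/u) as u -> 0+
    assert ln(divf(sigma_v, 0.0)) == math.inf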
@@ -216,7 +216,7 @@ def sample[T: Sample]( if effective_order >= 3: sigma_u_prev2, sigma_v_prev2 = common.get_sigma_uv(step - 2, schedule, sigma_transform) - lambda_prev2 = safe_log(sigma_v_prev2) - safe_log(sigma_u_prev2) + lambda_prev2 = ln(divf(sigma_v_prev2, sigma_u_prev2)) h_prev2 = lambda_prev - lambda_prev2 r_prev2 = h_prev2 / h @@ -302,8 +302,8 @@ def unisolve[T: Sample]( sigma_u, sigma_v = common.get_sigma_uv(step, schedule, sigma_transform) sigma_u_next, sigma_v_next = common.get_sigma_uv(step + 1, schedule, sigma_transform) - lambda_ = safe_log(sigma_v) - safe_log(sigma_u) - lambda_next = safe_log(sigma_v_next) - safe_log(sigma_u_next) + lambda_ = ln(divf(sigma_v, sigma_u)) + lambda_next = ln(divf(sigma_v_next, sigma_u_next)) h = abs(lambda_next - lambda_) # hh = -h if self.predict_x0 else h @@ -322,7 +322,7 @@ def unisolve[T: Sample]( step_prev_N = step - n prediction_prev_N = previous[-n].prediction sigma_u_prev_N, sigma_v_prev_N = common.get_sigma_uv(step_prev_N, schedule, sigma_transform) - lambda_pO = safe_log(sigma_v_prev_N) - safe_log(sigma_u_prev_N) + lambda_pO = ln(divf(sigma_v_prev_N, sigma_u_prev_N)) rk = (lambda_pO - lambda_) / h if math.isfinite(rk): # for subnormal rks.append(rk) From 96169dc5459baa6ee5673835dc346e30fa637870 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 22 Nov 2025 18:49:41 -0800 Subject: [PATCH 49/59] Fix test_functional_adapter --- tests/miscellaneous.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index f1508ee..c18ec9a 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -14,7 +14,6 @@ SigmaTransform, bashforth, euler, - predict_flow, sigma_complement, sigma_polar, sigmoid, @@ -242,7 +241,8 @@ def fake_model(x: float, _: float, s: float) -> float: noise = [random.random() for _ in range(steps)] rng = iter(noise) - sample_f = adapter.sample_model(sample, fake_model, FlowModel(), steps, rng=lambda: next(rng)) + model_transform = FlowModel() + sample_f = adapter.sample_model(sample, fake_model, model_transform, steps, rng=lambda: next(rng)) rng = iter(noise) float_schedule = schedule.schedule(steps) @@ -251,7 +251,7 @@ def fake_model(x: float, _: float, s: float) -> float: for n, (t, s) in enumerate(float_schedule): results = sampler.sample( sample_s, - predict_flow(sample_s, fake_model(sample_s, t, s), s, schedule.sigma_transform), + model_transform.to_x(sample_s, fake_model(sample_s, t, s), s, schedule.sigma_transform), n, float_schedule, schedule.sigma_transform, From d4ca1ea4c7d5aee7e7aec41fcbbac5245b94f9a6 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 22 Nov 2025 19:30:24 -0800 Subject: [PATCH 50/59] Remove common.Predictor && functions --- examples/diffusers/functional.py | 3 +- examples/diffusers/wrapper.py | 4 +- examples/predictions.py | 131 ------------------------------- examples/structured.py | 6 +- skrample/common.py | 26 ------ skrample/diffusers.py | 53 ++++++------- skrample/sampling/functional.py | 12 +-- skrample/sampling/models.py | 12 ++- tests/diffusers_map.py | 8 +- tests/diffusers_samplers.py | 31 ++++---- 10 files changed, 62 insertions(+), 224 deletions(-) delete mode 100755 examples/predictions.py diff --git a/examples/diffusers/functional.py b/examples/diffusers/functional.py index 7b5276c..5820d81 100755 --- a/examples/diffusers/functional.py +++ b/examples/diffusers/functional.py @@ -11,7 +11,6 @@ from tqdm import tqdm import skrample.scheduling as scheduling -from skrample.common import predict_flow from 
skrample.diffusers import SkrampleWrapperScheduler from skrample.sampling import functional, models, structured from skrample.sampling.interface import StructuredFunctionalAdapter @@ -22,7 +21,7 @@ schedule = scheduling.FlowShift(scheduling.Linear(), shift=2) wrapper = SkrampleWrapperScheduler( - sampler=structured.Euler(), schedule=schedule, predictor=predict_flow, allow_dynamic=False + sampler=structured.Euler(), schedule=schedule, model=models.FlowModel(), allow_dynamic=False ) # Equivalent to structured example diff --git a/examples/diffusers/wrapper.py b/examples/diffusers/wrapper.py index 6af5184..1d064e7 100755 --- a/examples/diffusers/wrapper.py +++ b/examples/diffusers/wrapper.py @@ -6,8 +6,8 @@ import skrample.pytorch.noise as sknoise import skrample.sampling.structured as sampling import skrample.scheduling as scheduling -from skrample.common import predict_flow from skrample.diffusers import SkrampleWrapperScheduler +from skrample.sampling.models import FlowModel pipe: FluxPipeline = FluxPipeline.from_pretrained( # type: ignore "black-forest-labs/FLUX.1-dev", @@ -17,7 +17,7 @@ pipe.scheduler = scheduler = SkrampleWrapperScheduler( sampler=sampling.DPM(order=2, add_noise=True), schedule=scheduling.FlowShift(scheduling.Linear(), shift=2.0), - predictor=predict_flow, + model=FlowModel(), noise_type=sknoise.Brownian, allow_dynamic=False, ) diff --git a/examples/predictions.py b/examples/predictions.py deleted file mode 100755 index ad12c9b..0000000 --- a/examples/predictions.py +++ /dev/null @@ -1,131 +0,0 @@ -#! /usr/bin/env python - -import json - -import huggingface_hub as hf -import torch -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel -from PIL import Image -from safetensors.torch import load_file -from tqdm import tqdm -from transformers.models.clip import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer - -import skrample.pytorch.noise as noise -import skrample.scheduling as scheduling -from skrample.common import Predictor, predict_epsilon, predict_sample, predict_velocity -from skrample.sampling import functional, models, structured -from skrample.sampling.interface import StructuredFunctionalAdapter - -with torch.inference_mode(): - device: torch.device = torch.device("cuda") - dtype: torch.dtype = torch.float16 - steps: int = 15 - cfg: float = 8 - seed = torch.Generator("cpu").manual_seed(0) - prompts = ["dreamy analog photograph of a kitten in a stained glass church", "blurry, noisy, cropped"] - - schedule = scheduling.Scaled() - - sampler_snr = StructuredFunctionalAdapter(schedule, structured.DPM(order=1)) - sampler_df = functional.RKUltra(schedule, order=1) - - base = "stabilityai/stable-diffusion-xl-base-1.0" - - tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer") - tokenizer_2: CLIPTokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2") - text_encoder: CLIPTextModelWithProjection = CLIPTextModelWithProjection.from_pretrained( - base, subfolder="text_encoder", device_map=device, torch_dtype=dtype - ) - text_encoder_2: CLIPTextModel = CLIPTextModel.from_pretrained( - base, subfolder="text_encoder_2", device_map=device, torch_dtype=dtype - ) - image_encoder: AutoencoderKL = AutoencoderKL.from_pretrained( # type: ignore - base, subfolder="vae", device_map=device, torch_dtype=torch.float32 - ) - - text_embeds: torch.Tensor = text_encoder( - tokenizer(prompts, padding="max_length", 
return_tensors="pt").input_ids.to(device=device), - output_hidden_states=True, - ).hidden_states[-2] - te2_out = text_encoder_2( - tokenizer_2(prompts, padding="max_length", return_tensors="pt").input_ids.to(device=device), - output_hidden_states=True, - ) - text_embeds = torch.cat([text_embeds, te2_out.hidden_states[-2]], dim=-1) - pooled_embeds: torch.Tensor = te2_out.pooler_output - - time_embeds = text_embeds.new([[4096, 4096, 0, 0, 4096, 4096]]).repeat(2, 1) - - configs: tuple[tuple[models.ModelTransform, Predictor, str, str], ...] = ( - (models.EpsilonModel(), predict_epsilon, base, ""), - (models.VelocityModel(), predict_velocity, "terminusresearch/terminus-xl-velocity-v2", ""), - ( - models.DiffusionModel(), - predict_sample, - "ByteDance/SDXL-Lightning", - "sdxl_lightning_1step_unet_x0.safetensors", - ), - ) - - for transform, predictor, url, weights in configs: - model_steps = 1 if isinstance(transform, models.DiffusionModel) else steps - model_cfg = 1 if isinstance(transform, models.DiffusionModel) else cfg - - if weights: - model: UNet2DConditionModel = UNet2DConditionModel.from_config( # type: ignore - json.load(open(hf.hf_hub_download(base, "config.json", subfolder="unet"))), - device_map=device, - torch_dtype=dtype, - ) - model.load_state_dict(load_file(hf.hf_hub_download(url, weights))) - model = model.to(device=device, dtype=dtype) # pyright: ignore [reportCallIssue] - else: - model: UNet2DConditionModel = UNet2DConditionModel.from_pretrained( # type: ignore - url, subfolder="unet", device_map=device, torch_dtype=dtype - ) - - def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: - conditioned, unconditioned = model( - x.expand([x.shape[0] * 2, *x.shape[1:]]), - t, - text_embeds, - added_cond_kwargs={"text_embeds": pooled_embeds, "time_ids": time_embeds}, - ).sample.chunk(2) - return conditioned + (model_cfg - 1) * (conditioned - unconditioned) - - rng = noise.Random.from_inputs((1, 4, 128, 128), seed.clone_state()) - bar = tqdm(total=model_steps) - sample = sampler_snr.generate_model( - model=call_model, - model_transform=transform, - steps=model_steps, - rng=lambda: rng.generate().to(dtype=dtype, device=device), - callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), - ) - - image: torch.Tensor = image_encoder.decode( - sample.to(dtype=image_encoder.dtype) / image_encoder.config.scaling_factor # type: ignore - ).sample[0] # type: ignore - Image.fromarray( - ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() - ).save(f"{predictor.__name__}.png") - - rng = noise.Random.from_inputs((1, 4, 128, 128), seed.clone_state()) - bar = tqdm(total=sampler_df.adjust_steps(model_steps)) - sample = sampler_df.generate_model( - model=call_model, - model_transform=transform, - steps=sampler_df.adjust_steps(model_steps), - rng=lambda: rng.generate().to(dtype=dtype, device=device), - callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), - ) - - image: torch.Tensor = image_encoder.decode( - sample.to(dtype=image_encoder.dtype) / image_encoder.config.scaling_factor # type: ignore - ).sample[0] # type: ignore - Image.fromarray( - ((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() - ).save(f"{type(transform).__name__}.png") - - model = model.to(device="meta") # pyright: ignore [reportCallIssue] diff --git a/examples/structured.py b/examples/structured.py index 09f2d9f..d560433 100755 --- a/examples/structured.py +++ b/examples/structured.py @@ -7,9 +7,9 @@ from tqdm import tqdm 
from transformers.models.clip import CLIPTextModel, CLIPTokenizer -import skrample.common import skrample.sampling.structured as structured import skrample.scheduling as scheduling +from skrample.sampling import models with torch.inference_mode(): device: torch.device = torch.device("cuda") @@ -21,7 +21,7 @@ schedule: scheduling.SkrampleSchedule = scheduling.Karras(scheduling.Scaled()) sampler: structured.StructuredSampler = structured.DPM(order=2, add_noise=True) - predictor: skrample.common.Predictor = skrample.common.predict_epsilon + transform: models.ModelTransform = models.EpsilonModel() tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer") text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained( @@ -54,7 +54,7 @@ ).sample.chunk(2) model_output: torch.Tensor = conditioned + (cfg - 1) * (conditioned - unconditioned) - prediction = predictor(sample, model_output, sigma, schedule.sigma_transform) + prediction = transform.to_x(sample, model_output, sigma, schedule.sigma_transform) sampler_output = sampler.sample( sample=sample, diff --git a/skrample/common.py b/skrample/common.py index 4c0e069..137ec25 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -21,8 +21,6 @@ type SigmaTransform = Callable[[float], tuple[float, float]] "Transforms a single noise sigma into a pair" -type Predictor[S: Sample] = Callable[[S, S, float, SigmaTransform], S] -"sample, output, sigma, sigma_transform" type DictOrProxy[T, U] = MappingProxyType[T, U] | dict[T, U] # Mapping does not implement __or__ "Simple union type for a possibly immutable dictionary" @@ -76,30 +74,6 @@ def sigma_polar(sigma: float) -> tuple[float, float]: return math.sin(theta), math.cos(theta) -def predict_epsilon[T: Sample](sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: - "If a model does not specify, this is usually what it needs." - sigma_u, sigma_v = sigma_transform(sigma) - return (sample - sigma_u * output) / sigma_v # type: ignore - - -def predict_sample[T: Sample](sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: - "No prediction. Only for single step afaik." - return output - - -def predict_velocity[T: Sample](sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: - "Rare, models will usually explicitly say they require velocity/vpred/zero terminal SNR" - sigma_u, sigma_v = sigma_transform(sigma) - return sigma_v * sample - sigma_u * output # type: ignore - - -def predict_flow[T: Sample](sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: - "Flow matching models use this, notably FLUX.1 and SD3" - # TODO(beinsezii): this might need to be u * output. Don't trust diffusers - # Our tests will fail if we do so, leaving here for now. - return sample - sigma * output # type: ignore - - def get_sigma_uv(step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform) -> tuple[float, float]: """Gets sigma u/v with bounds check. 
If step >= len(schedule), the sigma is assumed to be zero.""" diff --git a/skrample/diffusers.py b/skrample/diffusers.py index 5be6190..c6c89eb 100644 --- a/skrample/diffusers.py +++ b/skrample/diffusers.py @@ -11,15 +11,7 @@ import skrample.sampling.structured as sampling from skrample import scheduling -from skrample.common import ( - FloatSchedule, - MergeStrategy, - Predictor, - predict_epsilon, - predict_flow, - predict_sample, - predict_velocity, -) +from skrample.common import FloatSchedule, MergeStrategy from skrample.pytorch.noise import ( BatchTensorNoise, Random, @@ -27,6 +19,7 @@ TensorNoiseProps, schedule_to_ramp, ) +from skrample.sampling.models import DiffusionModel, EpsilonModel, FlowModel, ModelTransform, VelocityModel from skrample.sampling.structured import SKSamples, StructuredSampler from skrample.scheduling import ScheduleCommon, ScheduleModifier, SkrampleSchedule @@ -74,10 +67,10 @@ ("algorithm_type", "sde-dpmsolver"): ("add_noise", True), ("algorithm_type", "sde-dpmsolver++"): ("add_noise", True), # Complex types - ("prediction_type", "epsilon"): ("skrample_predictor", predict_epsilon), - ("prediction_type", "flow"): ("skrample_predictor", predict_flow), - ("prediction_type", "sample"): ("skrample_predictor", predict_sample), - ("prediction_type", "v_prediction"): ("skrample_predictor", predict_velocity), + ("prediction_type", "epsilon"): ("skrample_predictor", EpsilonModel()), + ("prediction_type", "flow"): ("skrample_predictor", FlowModel()), + ("prediction_type", "sample"): ("skrample_predictor", DiffusionModel()), + ("prediction_type", "v_prediction"): ("skrample_predictor", VelocityModel()), ("use_beta_sigmas", True): ("skrample_modifier", scheduling.Beta), ("use_exponential_sigmas", True): ("skrample_modifier", scheduling.Exponential), ("use_karras_sigmas", True): ("skrample_modifier", scheduling.Karras), @@ -107,7 +100,7 @@ class ParsedDiffusersConfig: schedule: type[SkrampleSchedule] schedule_props: dict[str, Any] schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] - predictor: Predictor + model: ModelTransform def parse_diffusers_config( @@ -129,11 +122,11 @@ def parse_diffusers_config( } if "skrample_predictor" in remapped: - predictor: Predictor = remapped.pop("skrample_predictor") + model: ModelTransform = remapped.pop("skrample_predictor") elif "shift" in remapped: # should only be flow - predictor = predict_flow + model = FlowModel() else: - predictor = predict_epsilon + model = EpsilonModel() if not sampler: sampler, sampler_props = DIFFUSERS_CLASS_MAP.get(diffusers_class, (sampling.DPM, {})) @@ -141,7 +134,7 @@ def parse_diffusers_config( sampler_props = {} if not schedule: - if predictor is predict_flow: + if isinstance(model, FlowModel): schedule = scheduling.Linear elif remapped.get("rescale_betas_zero_snr", False): schedule = scheduling.ZSNR @@ -149,7 +142,7 @@ def parse_diffusers_config( schedule = scheduling.Scaled # Adjust sigma_start to match scaled beta for sd1/xl - if "sigma_start" not in remapped and predictor is not predict_flow and issubclass(schedule, scheduling.Linear): + if "sigma_start" not in remapped and not isinstance(model, FlowModel) and issubclass(schedule, scheduling.Linear): scaled_keys = [f.name for f in dataclasses.fields(scheduling.Scaled)] # non-uniform misses a whole timestep scaled = scheduling.Scaled(**{k: v for k, v in remapped.items() if k in scaled_keys} | {"uniform": True}) @@ -158,7 +151,7 @@ def parse_diffusers_config( schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] = 
[] - if predictor is predict_flow: + if isinstance(model, FlowModel): flow_keys = [f.name for f in dataclasses.fields(scheduling.FlowShift)] schedule_modifiers.append((scheduling.FlowShift, {k: v for k, v in remapped.items() if k in flow_keys})) @@ -177,7 +170,7 @@ def parse_diffusers_config( schedule=schedule, schedule_props={k: v for k, v in remapped.items() if k in schedule_keys}, schedule_modifiers=schedule_modifiers, - predictor=predictor, + model=model, ) @@ -188,10 +181,14 @@ def attr_dict[T: Any](**kwargs: T) -> OrderedDict[str, T]: return od -def as_diffusers_config(sampler: StructuredSampler, schedule: SkrampleSchedule, predictor: Predictor) -> dict[str, Any]: +def as_diffusers_config( + sampler: StructuredSampler, + schedule: SkrampleSchedule, + model: ModelTransform, +) -> dict[str, Any]: "Converts skrample classes back into a diffusers-readable config. Not comprehensive" skrample_config = dataclasses.asdict(sampler) - skrample_config["skrample_predictor"] = predictor + skrample_config["skrample_predictor"] = model if isinstance(schedule, ScheduleModifier): for modifier in schedule.all: @@ -219,7 +216,7 @@ class SkrampleWrapperScheduler[T: TensorNoiseProps | None]: sampler: StructuredSampler schedule: SkrampleSchedule - predictor: Predictor[Tensor] = predict_epsilon + model: ModelTransform = EpsilonModel() # noqa: RUF009 # is immutable noise_type: type[TensorNoiseCommon[T]] = Random # type: ignore # Unsure why? noise_props: T | None = None compute_scale: torch.dtype | None = torch.float32 @@ -245,7 +242,7 @@ def from_diffusers_config[N: TensorNoiseProps | None]( # pyright fails if you u sampler: type[StructuredSampler] | None = None, schedule: type[SkrampleSchedule] | None = None, schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] = [], - predictor: Predictor[Tensor] | None = None, + model: ModelTransform | None = None, noise_type: type[TensorNoiseCommon[N]] = Random, compute_scale: torch.dtype | None = torch.float32, sampler_props: dict[str, Any] = {}, @@ -271,7 +268,7 @@ def from_diffusers_config[N: TensorNoiseProps | None]( # pyright fails if you u return cls( built_sampler, built_schedule, - predictor or parsed.predictor, + model or parsed.model, noise_type=noise_type, # type: ignore # think these are weird because of the defaults? 
noise_props=noise_props, # type: ignore compute_scale=compute_scale, @@ -312,7 +309,7 @@ def order(self) -> int: @property def config(self) -> OrderedDict[str, Any]: # Diffusers expects the frozen shift value - return attr_dict(**(self.fake_config | as_diffusers_config(self.sampler, self._schedule, self.predictor))) + return attr_dict(**(self.fake_config | as_diffusers_config(self.sampler, self._schedule, self.model))) def time_shift(self, mu: float, sigma: float, t: Tensor) -> Tensor: return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) @@ -416,7 +413,7 @@ def step( noise = None sample_cast = sample.to(dtype=self.compute_scale) - prediction = self.predictor( + prediction = self.model.to_x( sample_cast, model_output.to(dtype=self.compute_scale), schedule[step, 1].item(), diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 89fa6a2..1ed1761 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -2,14 +2,13 @@ import math from abc import ABC, abstractmethod from collections.abc import Callable -from functools import wraps from types import MappingProxyType from typing import Any import numpy as np from skrample import common, scheduling -from skrample.common import RNG, DictOrProxy, FloatSchedule, Predictor, Sample, SigmaTransform +from skrample.common import RNG, DictOrProxy, FloatSchedule, Sample, SigmaTransform from . import models, tableaux @@ -97,15 +96,6 @@ def step_tableau[T: Sample]( class FunctionalSampler(ABC): schedule: scheduling.SkrampleSchedule - def model_with_predictor(self, model: SampleableModel, predictor: Predictor) -> SampleableModel: - "Wraps the output of `model` with `predictor` using schedule.sigma_transform" - - @wraps(model) - def model_with_predictor[T: Sample](x: T, t: float, s: float) -> T: - return predictor(x, model(x, t, s), s, self.schedule.sigma_transform) - - return model_with_predictor - def merge_noise[T: Sample](self, sample: T, noise: T, steps: int, start: int) -> T: sigmas = self.schedule.sigmas(steps) sigma = sigmas[start] if start < len(sigmas) else 0 diff --git a/skrample/sampling/models.py b/skrample/sampling/models.py index 98ed882..733032f 100644 --- a/skrample/sampling/models.py +++ b/skrample/sampling/models.py @@ -47,7 +47,8 @@ def backward[T: Sample]( @dataclasses.dataclass(frozen=True) class DiffusionModel(ModelTransform): """X-Prediction - Predicts the clean image""" + Predicts the clean image. + Usually for single step models.""" def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: "output -> X̂" @@ -73,7 +74,8 @@ def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransf @dataclasses.dataclass(frozen=True) class EpsilonModel(ModelTransform): """Ε-Prediction - Predicts the added noise""" # noqa: RUF002 + Predicts the added noise. + If a model does not specify, this is usually what it needs.""" # noqa: RUF002 def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) @@ -96,7 +98,8 @@ def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransf @dataclasses.dataclass(frozen=True) class FlowModel(ModelTransform): - "U-Prediction" + """U-Prediction. 
+ Flow matching models use this, notably FLUX.1 and SD3""" def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) @@ -119,7 +122,8 @@ def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransf @dataclasses.dataclass(frozen=True) class VelocityModel(ModelTransform): - "V-Prediction" + """V-Prediction. + Rare, models will usually explicitly say they require velocity/vpred/zero terminal SNR""" def to_x[T: Sample](self, sample: T, output: T, sigma: float, sigma_transform: SigmaTransform) -> T: sigma_t, alpha_t = sigma_transform(sigma) diff --git a/tests/diffusers_map.py b/tests/diffusers_map.py index 07ca2b8..9ad4824 100644 --- a/tests/diffusers_map.py +++ b/tests/diffusers_map.py @@ -10,13 +10,15 @@ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler from testing_common import FLOW_CONFIG, SCALED_CONFIG -from skrample.common import predict_epsilon as EPSILON -from skrample.common import predict_flow as FLOW -from skrample.common import predict_velocity as VELOCITY from skrample.diffusers import SkrampleWrapperScheduler +from skrample.sampling.models import EpsilonModel, FlowModel, VelocityModel from skrample.sampling.structured import DPM, Adams, Euler, UniPC from skrample.scheduling import Beta, Exponential, FlowShift, Karras, Linear, Scaled +EPSILON = EpsilonModel() +FLOW = FlowModel() +VELOCITY = VelocityModel() + def check_wrapper(wrapper: SkrampleWrapperScheduler, scheduler: ConfigMixin, params: list[str] = []) -> None: a, b = wrapper, SkrampleWrapperScheduler.from_diffusers_config(scheduler) diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index ecfe9b5..466b3a5 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -14,10 +14,7 @@ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler from testing_common import FLOW_CONFIG, SCALED_CONFIG, compare_tensors -from skrample.common import FloatSchedule, Predictor, SigmaTransform, sigma_complement, sigma_polar -from skrample.common import predict_epsilon as EPSILON -from skrample.common import predict_flow as FLOW -from skrample.common import predict_velocity as VELOCITY +from skrample.common import FloatSchedule, SigmaTransform, sigma_complement, sigma_polar from skrample.sampling.functional import RKUltra from skrample.sampling.models import EpsilonModel, FlowModel, ModelTransform, VelocityModel from skrample.sampling.structured import DPM, Euler, SKSamples, StructuredSampler, UniPC @@ -32,6 +29,10 @@ | UniPCMultistepScheduler ) +EPSILON = EpsilonModel() +FLOW = FlowModel() +VELOCITY = VelocityModel() + @dataclasses.dataclass(frozen=True) class FixedSchedule(SkrampleSchedule): @@ -54,7 +55,7 @@ def fake_model(t: torch.Tensor) -> torch.Tensor: def dual_sample( a: StructuredSampler, b: DiffusersScheduler, - predictor: Predictor, + model_transform: ModelTransform, steps: range, mu: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -75,10 +76,10 @@ def dual_sample( if isinstance(b, FlowMatchEulerDiscreteScheduler): b_sample = b.scale_noise(sample=b_sample, timestep=timestep.unsqueeze(0), noise=initial_noise) - sigma_transform = sigma_complement else: b_sample = b.add_noise(original_samples=b_sample, noise=initial_noise, timesteps=timestep.unsqueeze(0)) - sigma_transform = sigma_polar + + sigma_transform = sigma_complement if isinstance(model_transform, FlowModel) else sigma_polar a_sample = a.merge_noise(a_sample, 
initial_noise, sigma.item(), sigma_transform) @@ -89,7 +90,7 @@ def dual_sample( timestep, sigma = schedule[step] - a_output = predictor( + a_output = model_transform.to_x( a_sample, fake_model(a.scale_input(a_sample, sigma.item(), sigma_transform)), sigma.item(), sigma_transform ) sampled = a.sample(a_sample, a_output, step, schedule.numpy().tolist(), sigma_transform, noise, prior_steps) @@ -112,14 +113,14 @@ def dual_sample( def compare_samplers( a: StructuredSampler, b: DiffusersScheduler, - p: Predictor = EPSILON, + t: ModelTransform = EPSILON, mu: float | None = None, margin: float = 1e-8, message: str = "", ) -> None: for step_range in [range(0, 2), range(0, 11), range(0, 201), range(3, 6), range(2, 23), range(31, 200)]: compare_tensors( - *dual_sample(a, b, p, step_range, mu), + *dual_sample(a, b, t, step_range, mu), message=str(step_range) + (" | " + message if message else ""), margin=margin, ) @@ -134,7 +135,7 @@ def test_euler() -> None: prediction_type=predictor[1], ), predictor[0], - message=predictor[0].__name__, + message=type(predictor[0]).__name__, ) @@ -147,7 +148,7 @@ def test_euler_ancestral() -> None: prediction_type=predictor[1], ), predictor[0], - message=predictor[0].__name__, + message=type(predictor[0]).__name__, ) @@ -172,9 +173,10 @@ def test_dpm() -> None: final_sigmas_type="zero", solver_order=order, prediction_type=predictor[1], + use_flow_sigmas=predictor[0] == FLOW, ), predictor[0], - message=f"{predictor[0].__name__} o{order} s{stochastic}", + message=f"{type(predictor[0]).__name__} o{order} s{stochastic}", ) @@ -191,9 +193,10 @@ def test_unipc() -> None: final_sigmas_type="zero", solver_order=order, prediction_type=predictor[1], + use_flow_sigmas=predictor[0] == FLOW, ), predictor[0], - message=f"{predictor[0].__name__} o{order}", + message=f"{type(predictor[0]).__name__} o{order}", ) From 285c1325c897272b4a49990d95eac73191feef29 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 22 Nov 2025 19:37:51 -0800 Subject: [PATCH 51/59] test_model_transforms also test backward() --- tests/miscellaneous.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index c18ec9a..bf02077 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -99,6 +99,9 @@ def test_model_transforms(model_type: type[ModelTransform], sigma_transform: Sig df = model_transform.forward(sample, output, sigma, sigma_next, sigma_transform) assert abs(snr - df) < 1e-12 + ob = model_transform.backward(sample, df, sigma, sigma_next, sigma_transform) + assert abs(o - ob) < 1e-12 + def test_sampler_generics() -> None: eps = 1e-12 From db0d48a21402bd57523f4925ebf59ded9c93ba43 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 22 Nov 2025 19:50:12 -0800 Subject: [PATCH 52/59] add test_model_convert --- tests/miscellaneous.py | 45 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index bf02077..876ff62 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -23,7 +23,14 @@ from skrample.diffusers import SkrampleWrapperScheduler from skrample.sampling import tableaux from skrample.sampling.interface import StructuredFunctionalAdapter -from skrample.sampling.models import DiffusionModel, EpsilonModel, FlowModel, ModelTransform, VelocityModel +from skrample.sampling.models import ( + DiffusionModel, + EpsilonModel, + FlowModel, + ModelConvert, + ModelTransform, + VelocityModel, +) from skrample.sampling.structured import ( DPM, SPC, 
@@ -103,6 +110,42 @@ def test_model_transforms(model_type: type[ModelTransform], sigma_transform: Sig assert abs(o - ob) < 1e-12 +@pytest.mark.parametrize( + ("model_from", "model_to", "sigma_transform", "sigma_to"), + itertools.product(ALL_MODELS, ALL_MODELS, ALL_TRANSFROMS, (0.05, 0.0)), +) +def test_model_convert( + model_from: type[ModelTransform], + model_to: type[ModelTransform], + sigma_transform: SigmaTransform, + sigma_to: float, +) -> None: + convert = ModelConvert(model_from(), model_to()) + sample = 0.8 + output = 0.3 + sigma_from = 0.2 + + def model(x: float, t: float, s: float) -> float: + return output + + x_from = convert.transform_from.forward( + sample, + model(sample, sigma_from, sigma_from), + sigma_from, + sigma_to, + sigma_transform, + ) + x_to = convert.transform_to.forward( + sample, + convert.wrap_model_call(model, sigma_transform)(sample, sigma_from, sigma_from), + sigma_from, + sigma_to, + sigma_transform, + ) + + assert abs(x_from - x_to) < 1e-12 + + def test_sampler_generics() -> None: eps = 1e-12 for sampler in [ From aefe0e7a934034f54d12a41dafbf97805b00604e Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 22 Nov 2025 19:58:12 -0800 Subject: [PATCH 53/59] Add ScaleX to test_model_convert --- tests/miscellaneous.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 876ff62..dcf6f5b 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -29,6 +29,7 @@ FlowModel, ModelConvert, ModelTransform, + ScaleX, VelocityModel, ) from skrample.sampling.structured import ( @@ -71,6 +72,10 @@ VelocityModel, ] +ALL_FAKE_MODELS: Sequence[type[ModelTransform]] = [ + ScaleX, +] + ALL_TRANSFROMS: Sequence[SigmaTransform] = [ sigma_complement, sigma_polar, @@ -112,7 +117,7 @@ def test_model_transforms(model_type: type[ModelTransform], sigma_transform: Sig @pytest.mark.parametrize( ("model_from", "model_to", "sigma_transform", "sigma_to"), - itertools.product(ALL_MODELS, ALL_MODELS, ALL_TRANSFROMS, (0.05, 0.0)), + itertools.product(ALL_MODELS, ALL_MODELS + ALL_FAKE_MODELS, ALL_TRANSFROMS, (0.05, 0.0)), ) def test_model_convert( model_from: type[ModelTransform], From da56008ba9d7b92b1a3d199043839176456711ce Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 23 Nov 2025 14:22:40 -0800 Subject: [PATCH 54/59] Rename ModelTransform,DiffusionModel,EpsilonModel -> DiffusionModel,Datamodel,NoiseModel --- examples/functional.py | 2 +- examples/structured.py | 2 +- scripts/plot_skrample.py | 4 ++-- skrample/diffusers.py | 18 +++++++++--------- skrample/sampling/functional.py | 20 ++++++++++---------- skrample/sampling/interface.py | 2 +- skrample/sampling/models.py | 16 ++++++++-------- tests/diffusers_map.py | 4 ++-- tests/diffusers_samplers.py | 18 +++++++++--------- tests/miscellaneous.py | 18 +++++++++--------- 10 files changed, 52 insertions(+), 52 deletions(-) diff --git a/examples/functional.py b/examples/functional.py index 5a67fb9..63cc269 100755 --- a/examples/functional.py +++ b/examples/functional.py @@ -65,7 +65,7 @@ def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor: bar = tqdm(total=steps) sample = sampler.generate_model( model=call_model, - model_transform=models.EpsilonModel(), + model_transform=models.NoiseModel(), steps=steps, rng=lambda: rng.generate().to(dtype=dtype, device=device), callback=lambda x, n, t, s: bar.update(n + 1 - bar.n), diff --git a/examples/structured.py b/examples/structured.py index d560433..dffb601 100755 --- a/examples/structured.py +++ 
b/examples/structured.py @@ -21,7 +21,7 @@ schedule: scheduling.SkrampleSchedule = scheduling.Karras(scheduling.Scaled()) sampler: structured.StructuredSampler = structured.DPM(order=2, add_noise=True) - transform: models.ModelTransform = models.EpsilonModel() + transform: models.DiffusionModel = models.NoiseModel() tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer") text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained( diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 0a73248..9159f4a 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -58,8 +58,8 @@ def colors(hue_steps: int) -> Generator[list[float]]: yield oklch_to_srgb(np.array([lighness_actual, chroma_actual, hue], dtype=np.float64)) -TRANSFORMS: dict[str, tuple[float, SigmaTransform, models.ModelTransform]] = { - "polar": (1.0, sigma_polar, models.EpsilonModel()), +TRANSFORMS: dict[str, tuple[float, SigmaTransform, models.DiffusionModel]] = { + "polar": (1.0, sigma_polar, models.NoiseModel()), "complement": (1.0, sigma_complement, models.FlowModel()), } SAMPLERS: dict[str, structured.StructuredSampler | functional.FunctionalSampler] = { diff --git a/skrample/diffusers.py b/skrample/diffusers.py index c6c89eb..9b58908 100644 --- a/skrample/diffusers.py +++ b/skrample/diffusers.py @@ -19,7 +19,7 @@ TensorNoiseProps, schedule_to_ramp, ) -from skrample.sampling.models import DiffusionModel, EpsilonModel, FlowModel, ModelTransform, VelocityModel +from skrample.sampling.models import DataModel, DiffusionModel, FlowModel, NoiseModel, VelocityModel from skrample.sampling.structured import SKSamples, StructuredSampler from skrample.scheduling import ScheduleCommon, ScheduleModifier, SkrampleSchedule @@ -67,9 +67,9 @@ ("algorithm_type", "sde-dpmsolver"): ("add_noise", True), ("algorithm_type", "sde-dpmsolver++"): ("add_noise", True), # Complex types - ("prediction_type", "epsilon"): ("skrample_predictor", EpsilonModel()), + ("prediction_type", "epsilon"): ("skrample_predictor", NoiseModel()), ("prediction_type", "flow"): ("skrample_predictor", FlowModel()), - ("prediction_type", "sample"): ("skrample_predictor", DiffusionModel()), + ("prediction_type", "sample"): ("skrample_predictor", DataModel()), ("prediction_type", "v_prediction"): ("skrample_predictor", VelocityModel()), ("use_beta_sigmas", True): ("skrample_modifier", scheduling.Beta), ("use_exponential_sigmas", True): ("skrample_modifier", scheduling.Exponential), @@ -100,7 +100,7 @@ class ParsedDiffusersConfig: schedule: type[SkrampleSchedule] schedule_props: dict[str, Any] schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] - model: ModelTransform + model: DiffusionModel def parse_diffusers_config( @@ -122,11 +122,11 @@ def parse_diffusers_config( } if "skrample_predictor" in remapped: - model: ModelTransform = remapped.pop("skrample_predictor") + model: DiffusionModel = remapped.pop("skrample_predictor") elif "shift" in remapped: # should only be flow model = FlowModel() else: - model = EpsilonModel() + model = NoiseModel() if not sampler: sampler, sampler_props = DIFFUSERS_CLASS_MAP.get(diffusers_class, (sampling.DPM, {})) @@ -184,7 +184,7 @@ def attr_dict[T: Any](**kwargs: T) -> OrderedDict[str, T]: def as_diffusers_config( sampler: StructuredSampler, schedule: SkrampleSchedule, - model: ModelTransform, + model: DiffusionModel, ) -> dict[str, Any]: "Converts skrample classes back into a diffusers-readable config. 
Not comprehensive" skrample_config = dataclasses.asdict(sampler) @@ -216,7 +216,7 @@ class SkrampleWrapperScheduler[T: TensorNoiseProps | None]: sampler: StructuredSampler schedule: SkrampleSchedule - model: ModelTransform = EpsilonModel() # noqa: RUF009 # is immutable + model: DiffusionModel = NoiseModel() # noqa: RUF009 # is immutable noise_type: type[TensorNoiseCommon[T]] = Random # type: ignore # Unsure why? noise_props: T | None = None compute_scale: torch.dtype | None = torch.float32 @@ -242,7 +242,7 @@ def from_diffusers_config[N: TensorNoiseProps | None]( # pyright fails if you u sampler: type[StructuredSampler] | None = None, schedule: type[SkrampleSchedule] | None = None, schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] = [], - model: ModelTransform | None = None, + model: DiffusionModel | None = None, noise_type: type[TensorNoiseCommon[N]] = Random, compute_scale: torch.dtype | None = torch.float32, sampler_props: dict[str, Any] = {}, diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 1ed1761..97312db 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -41,11 +41,11 @@ def step_tableau[T: Sample]( tableau: tableaux.Tableau | tableaux.ExtendedTableau, sample: T, model: SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, step: int, schedule: FloatSchedule, sigma_transform: SigmaTransform, - derivative_transform: models.ModelTransform | None = None, + derivative_transform: models.DiffusionModel | None = None, step_size: int = 1, epsilon: float = 1e-8, ) -> tuple[T, ...]: @@ -106,7 +106,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -119,7 +119,7 @@ def sample_model[T: Sample]( def generate_model[T: Sample]( self, model: SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, rng: RNG[T], steps: int, include: slice = slice(None), @@ -162,7 +162,7 @@ def adjust_steps(self, steps: int) -> int: @dataclasses.dataclass(frozen=True) class FunctionalDerivative(FunctionalHigher): - derivative_transform: models.ModelTransform | None = models.DiffusionModel() # noqa: RUF009 # is immutable + derivative_transform: models.DiffusionModel | None = models.DataModel() # noqa: RUF009 # is immutable "Transform model to this space when computing higher order samples." 
@@ -173,7 +173,7 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, @@ -183,7 +183,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, @@ -260,7 +260,7 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, @@ -298,7 +298,7 @@ def step[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, step: int, schedule: FloatSchedule, rng: RNG[T] | None = None, @@ -371,7 +371,7 @@ def sample_model[T: Sample]( self, sample: T, model: SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, diff --git a/skrample/sampling/interface.py b/skrample/sampling/interface.py index d2f83ce..f91683c 100644 --- a/skrample/sampling/interface.py +++ b/skrample/sampling/interface.py @@ -18,7 +18,7 @@ def sample_model[T: Sample]( self, sample: T, model: functional.SampleableModel[T], - model_transform: models.ModelTransform, + model_transform: models.DiffusionModel, steps: int, include: slice = slice(None), rng: RNG[T] | None = None, diff --git a/skrample/sampling/models.py b/skrample/sampling/models.py index 733032f..8e116d1 100644 --- a/skrample/sampling/models.py +++ b/skrample/sampling/models.py @@ -8,7 +8,7 @@ @dataclasses.dataclass(frozen=True) -class ModelTransform(abc.ABC): +class DiffusionModel(abc.ABC): """Common framework for diffusion model sampling.""" @abc.abstractmethod @@ -45,7 +45,7 @@ def backward[T: Sample]( @dataclasses.dataclass(frozen=True) -class DiffusionModel(ModelTransform): +class DataModel(DiffusionModel): """X-Prediction Predicts the clean image. Usually for single step models.""" @@ -72,7 +72,7 @@ def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransf @dataclasses.dataclass(frozen=True) -class EpsilonModel(ModelTransform): +class NoiseModel(DiffusionModel): """Ε-Prediction Predicts the added noise. If a model does not specify, this is usually what it needs.""" # noqa: RUF002 @@ -97,7 +97,7 @@ def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransf @dataclasses.dataclass(frozen=True) -class FlowModel(ModelTransform): +class FlowModel(DiffusionModel): """U-Prediction. Flow matching models use this, notably FLUX.1 and SD3""" @@ -121,7 +121,7 @@ def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransf @dataclasses.dataclass(frozen=True) -class VelocityModel(ModelTransform): +class VelocityModel(DiffusionModel): """V-Prediction. Rare, models will usually explicitly say they require velocity/vpred/zero terminal SNR""" @@ -145,7 +145,7 @@ def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransf @dataclasses.dataclass(frozen=True) -class FakeModel(ModelTransform): +class FakeModel(DiffusionModel): "Marker for transforms that are only used for alternative sampling of other models." 
@@ -191,8 +191,8 @@ def delta(self, sigma_from: float, sigma_to: float, sigma_transform: SigmaTransf @dataclasses.dataclass(frozen=True) class ModelConvert: - transform_from: ModelTransform - transform_to: ModelTransform + transform_from: DiffusionModel + transform_to: DiffusionModel def output_to[T: Sample](self, sample: T, output_from: T, sigma: float, sigma_transform: SigmaTransform) -> T: if self.transform_to is self.transform_from: diff --git a/tests/diffusers_map.py b/tests/diffusers_map.py index 9ad4824..f4eb279 100644 --- a/tests/diffusers_map.py +++ b/tests/diffusers_map.py @@ -11,11 +11,11 @@ from testing_common import FLOW_CONFIG, SCALED_CONFIG from skrample.diffusers import SkrampleWrapperScheduler -from skrample.sampling.models import EpsilonModel, FlowModel, VelocityModel +from skrample.sampling.models import FlowModel, NoiseModel, VelocityModel from skrample.sampling.structured import DPM, Adams, Euler, UniPC from skrample.scheduling import Beta, Exponential, FlowShift, Karras, Linear, Scaled -EPSILON = EpsilonModel() +EPSILON = NoiseModel() FLOW = FlowModel() VELOCITY = VelocityModel() diff --git a/tests/diffusers_samplers.py b/tests/diffusers_samplers.py index 466b3a5..0621522 100644 --- a/tests/diffusers_samplers.py +++ b/tests/diffusers_samplers.py @@ -16,7 +16,7 @@ from skrample.common import FloatSchedule, SigmaTransform, sigma_complement, sigma_polar from skrample.sampling.functional import RKUltra -from skrample.sampling.models import EpsilonModel, FlowModel, ModelTransform, VelocityModel +from skrample.sampling.models import DiffusionModel, FlowModel, NoiseModel, VelocityModel from skrample.sampling.structured import DPM, Euler, SKSamples, StructuredSampler, UniPC from skrample.sampling.tableaux import RK2 from skrample.scheduling import SkrampleSchedule @@ -29,7 +29,7 @@ | UniPCMultistepScheduler ) -EPSILON = EpsilonModel() +EPSILON = NoiseModel() FLOW = FlowModel() VELOCITY = VelocityModel() @@ -55,7 +55,7 @@ def fake_model(t: torch.Tensor) -> torch.Tensor: def dual_sample( a: StructuredSampler, b: DiffusersScheduler, - model_transform: ModelTransform, + model_transform: DiffusionModel, steps: range, mu: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -113,7 +113,7 @@ def dual_sample( def compare_samplers( a: StructuredSampler, b: DiffusersScheduler, - t: ModelTransform = EPSILON, + t: DiffusionModel = EPSILON, mu: float | None = None, margin: float = 1e-8, message: str = "", @@ -207,14 +207,14 @@ def test_unipc() -> None: for (mt, dt, st, ds), s in itertools.product( ( ( - EpsilonModel(), - EpsilonModel(), + NoiseModel(), + NoiseModel(), sigma_polar, HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type="epsilon"), ), ( VelocityModel(), - EpsilonModel(), + NoiseModel(), sigma_polar, HeunDiscreteScheduler.from_config(SCALED_CONFIG, prediction_type="v_prediction"), ), @@ -230,8 +230,8 @@ def test_unipc() -> None: ), ) def test_heun( - model_transform: ModelTransform, - derivative_transform: ModelTransform, + model_transform: DiffusionModel, + derivative_transform: DiffusionModel, sigma_transform: SigmaTransform, diffusers_scheduler: HeunDiscreteScheduler | FlowMatchHeunDiscreteScheduler, steps: int, diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index dcf6f5b..7dbda1c 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -24,11 +24,11 @@ from skrample.sampling import tableaux from skrample.sampling.interface import StructuredFunctionalAdapter from skrample.sampling.models import ( + DataModel, DiffusionModel, - 
EpsilonModel, FlowModel, ModelConvert, - ModelTransform, + NoiseModel, ScaleX, VelocityModel, ) @@ -65,14 +65,14 @@ Karras, ] -ALL_MODELS: Sequence[type[ModelTransform]] = [ - DiffusionModel, - EpsilonModel, +ALL_MODELS: Sequence[type[DiffusionModel]] = [ + DataModel, + NoiseModel, FlowModel, VelocityModel, ] -ALL_FAKE_MODELS: Sequence[type[ModelTransform]] = [ +ALL_FAKE_MODELS: Sequence[type[DiffusionModel]] = [ ScaleX, ] @@ -93,7 +93,7 @@ def test_sigmas_to_timesteps() -> None: ("model_type", "sigma_transform"), itertools.product(ALL_MODELS, ALL_TRANSFROMS), ) -def test_model_transforms(model_type: type[ModelTransform], sigma_transform: SigmaTransform) -> None: +def test_model_transforms(model_type: type[DiffusionModel], sigma_transform: SigmaTransform) -> None: model_transform = model_type() sample = 0.8 output = 0.3 @@ -120,8 +120,8 @@ def test_model_transforms(model_type: type[ModelTransform], sigma_transform: Sig itertools.product(ALL_MODELS, ALL_MODELS + ALL_FAKE_MODELS, ALL_TRANSFROMS, (0.05, 0.0)), ) def test_model_convert( - model_from: type[ModelTransform], - model_to: type[ModelTransform], + model_from: type[DiffusionModel], + model_to: type[DiffusionModel], sigma_transform: SigmaTransform, sigma_to: float, ) -> None: From 60bc0f8f5b2bdb78b88ee21707903ca1df5428f8 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 23 Nov 2025 18:39:45 -0800 Subject: [PATCH 55/59] Don't skip derivative transform on Euler --- skrample/sampling/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 97312db..2d84de0 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -51,7 +51,7 @@ def step_tableau[T: Sample]( ) -> tuple[T, ...]: nodes, weights = tableau[0], tableau[1:] - if len(nodes) > 1 and derivative_transform: + if derivative_transform: model = models.ModelConvert(model_transform, derivative_transform).wrap_model_call(model, sigma_transform) model_transform = derivative_transform From f5b0ae4b5a910830b176418a55be734257c4c95f Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 29 Nov 2025 02:05:08 -0800 Subject: [PATCH 56/59] plot_skrample: fix two bugs with modifiers --- scripts/plot_skrample.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index 9159f4a..39c572d 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -92,7 +92,6 @@ def colors(hue_steps: int) -> Generator[list[float]]: "exponential": (scheduling.Exponential, {}), "karras": (scheduling.Karras, {}), "flow": (scheduling.FlowShift, {}), - "flow_mu": (scheduling.FlowShift, {"mu": 1}), "hyper": (scheduling.Hyper, {}), "vyper": (scheduling.Hyper, {"scale": -2}), "hype": (scheduling.Hyper, {"tail": False}), @@ -231,7 +230,7 @@ def callback(x: float, n: int, t: float, s: float) -> None: for mod_label, (mod_type, mod_props) in [ # type: ignore # Destructure m for m in [(mod1, MODIFIERS[mod1]), (mod2, MODIFIERS[mod2])] if m[1] ]: - composed = mod_type(schedule, **mod_props) + composed = mod_type(composed, **mod_props) label += "_" + mod_label label = " ".join([s.capitalize() for s in label.split("_")]) From 5edf16c582b6c0676cdabf13072ed8002161d82a Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 29 Nov 2025 17:35:56 -0800 Subject: [PATCH 57/59] Rewrite RK4.Ralston and RK4.Eighth using generic RK4 solver --- skrample/sampling/tableaux.py | 121 +++++++++++++++++++--------------- 1 file changed, 69 insertions(+), 52 
deletions(-) diff --git a/skrample/sampling/tableaux.py b/skrample/sampling/tableaux.py index 9ac0211..daac724 100644 --- a/skrample/sampling/tableaux.py +++ b/skrample/sampling/tableaux.py @@ -32,39 +32,75 @@ def validate_tableau(tab: Tableau | ExtendedTableau, tolerance: float = 1e-15) - return ValueError(f"{tolerance=}, {weight_err=}, {weight=}") -def rk2_tableau(alpha: float) -> Tableau: - "Create a generic 2nd order Tableau from a given alpha value." +def rk2_tableau(c1: float) -> Tableau: + "Create a generic 2nd order Tableau from a given coefficient." return ( ( (0.0, ()), - (alpha, (alpha,)), + (c1, (c1,)), ), - (1 - 1 / (2 * alpha), 1 / (2 * alpha)), + (1 - 1 / (2 * c1), 1 / (2 * c1)), ) -def rk3_tableau(alpha: float, beta: float) -> Tableau: - "Create a generic 3rd order Tableau from a given alpha and beta values." +def rk3_tableau(c1: float, c2: float) -> Tableau: + "Create a generic 3rd order Tableau from given coefficients." return ( ( (0.0, ()), - (alpha, (alpha,)), - ( - beta, - ( - beta / alpha * ((beta - 3 * alpha * (1 - alpha)) / (3 * alpha - 2)), - -beta / alpha * ((beta - alpha) / (3 * alpha - 2)), - ), - ), + (c1, (c1,)), + (c2, (c2 / c1 * ((c2 - 3 * c1 * (1 - c1)) / (3 * c1 - 2)), -c2 / c1 * ((c2 - c1) / (3 * c1 - 2)))), ), ( - 1 - (3 * alpha + 3 * beta - 2) / (6 * alpha * beta), - (3 * beta - 2) / (6 * alpha * (beta - alpha)), - (2 - 3 * alpha) / (6 * beta * (beta - alpha)), + 1 - (3 * c1 + 3 * c2 - 2) / (6 * c1 * c2), + (3 * c2 - 2) / (6 * c1 * (c2 - c1)), + (2 - 3 * c1) / (6 * c2 * (c2 - c1)), ), ) +def rk4_tableau(c1: float, c2: float) -> Tableau: + """Create a generic 4th order Tableau from 3 coefficients. + 1/2, 1/2 (Classic) is a special case and cannot be computed using this function. + https://pages.hmc.edu/ruye/MachineLearning/lectures/ch5/node10.html""" + + ### Automatically transcribed from website using QwenVL 235B Thinking + + D = 6 * c1 * c2 - 4 * (c1 + c2) + 3 + + # Compute b coefficients + b2 = (2 * c2 - 1) / (12 * c1 * (c2 - c1) * (1 - c1)) + b3 = (2 * c1 - 1) / (12 * c2 * (c1 - c2) * (1 - c2)) + b4 = D / (12 * (1 - c1) * (1 - c2)) + b1 = 1 - b2 - b3 - b4 + + # Compute a31 and a32 + a32 = c2 * (c1 - c2) / (2 * c1 * (2 * c1 - 1)) + a31 = c2 - a32 + + # Compute a41, a42, a43 + num_a42 = (4 * c2**2 - 5 * c2 - c1 + 2) * (1 - c1) + denom_a42 = 2 * c1 * (c1 - c2) * D + a42 = num_a42 / denom_a42 + + num_a43 = (2 * c1 - 1) * (1 - c1) * (1 - c2) + denom_a43 = c2 * (c1 - c2) * D + a43 = num_a43 / denom_a43 + + a41 = 1 - a42 - a43 + + stages = ( + (0.0, ()), + (c1, (c1,)), # a21 = c1 + (c2, (a31, a32)), + (1.0, (a41, a42, a43)), + ) + + b_vector = (b1, b2, b3, b4) + + return (stages, b_vector) + + class TableauProvider[T: Tableau | ExtendedTableau](Protocol): @abc.abstractmethod def tableau(self) -> T: @@ -88,19 +124,28 @@ def tableau(self) -> T: @dataclasses.dataclass(frozen=True) class RK2Custom(TableauProvider): - alpha: float = 1.0 + c1: float = 1.0 def tableau(self) -> Tableau: - return rk2_tableau(self.alpha) + return rk2_tableau(self.c1) @dataclasses.dataclass(frozen=True) class RK3Custom(TableauProvider): - alpha: float = 1 / 2 - beta: float = 1.0 + c1: float = 1 / 2 + c2: float = 1.0 def tableau(self) -> Tableau: - return rk3_tableau(self.alpha, self.beta) + return rk3_tableau(self.c1, self.c2) + + +@dataclasses.dataclass(frozen=True) +class RK4Custom(TableauProvider): + c1: float = 1 / 3 + c2: float = 2 / 3 + + def tableau(self) -> Tableau: + return rk4_tableau(self.c1, self.c2) @enum.unique @@ -136,36 +181,8 @@ class RK4(enum.Enum): ), (1 / 6, 1 / 3, 
1 / 3, 1 / 6), ) - Eighth = ( - ( - (0, ()), - (1 / 3, (1 / 3,)), - (2 / 3, (-1 / 3, 1)), - (1, (1, -1, 1)), - ), - (1 / 8, 3 / 8, 3 / 8, 1 / 8), - ) - Ralston = ( - ( - (0, ()), - (2 / 5, (2 / 5,)), - ((14 - 3 * math.sqrt(5)) / 16, ((-2889 + 1428 * math.sqrt(5)) / 1024, (3785 - 1620 * math.sqrt(5)) / 1024)), - ( - 1, - ( - (-3365 + 2094 * math.sqrt(5)) / 6040, - (-975 - 3046 * math.sqrt(5)) / 2552, - (467040 + 203968 * math.sqrt(5)) / 240845, - ), - ), - ), - ( - (263 + 24 * math.sqrt(5)) / 1812, - (125 - 1000 * math.sqrt(5)) / 3828, - (3426304 + 1661952 * math.sqrt(5)) / 5924787, - (30 - 4 * math.sqrt(5)) / 123, - ), - ) + Eighth = rk4_tableau(1 / 3, 2 / 3) + Ralston = rk4_tableau(2 / 5, (14 - 3 * math.sqrt(5)) / 16) def tableau(self) -> Tableau: return self.value From 8fe893715857ee021dbbd36e62b5cd49c8c781d7 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 29 Nov 2025 17:46:24 -0800 Subject: [PATCH 58/59] Rename RK5.Nystrom -> RKZ.Nystrom5 --- skrample/sampling/functional.py | 2 +- skrample/sampling/tableaux.py | 4 ++-- tests/miscellaneous.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 2d84de0..24b1206 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -227,7 +227,7 @@ class RKUltra(FunctionalDerivative, FunctionalSinglestep): 2: tableaux.RK2.Heun, 3: tableaux.RK3.Ralston, 4: tableaux.RK4.Ralston, - 5: tableaux.RK5.Nystrom, + 5: tableaux.RKZ.Nystrom5, } ) ) diff --git a/skrample/sampling/tableaux.py b/skrample/sampling/tableaux.py index daac724..61e136a 100644 --- a/skrample/sampling/tableaux.py +++ b/skrample/sampling/tableaux.py @@ -189,8 +189,8 @@ def tableau(self) -> Tableau: @enum.unique -class RK5(enum.Enum): - Nystrom = ( +class RKZ(enum.Enum): + Nystrom5 = ( ( (0, ()), (1 / 3, (1 / 3,)), diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 7dbda1c..5b35fbd 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -323,7 +323,7 @@ def test_bashforth() -> None: def test_tableau_providers() -> None: - for provider in [tableaux.RK2, tableaux.RK3, tableaux.RK4, tableaux.RK5, tableaux.RKE2, tableaux.RKE5]: + for provider in [tableaux.RK2, tableaux.RK3, tableaux.RK4, tableaux.RKZ, tableaux.RKE2, tableaux.RKE5]: for variant in provider: if error := tableaux.validate_tableau(variant.tableau()): raise error From e463bbaa3a2574da38fb44ae7d6373474a9da2e3 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 29 Nov 2025 18:09:14 -0800 Subject: [PATCH 59/59] tableax: Add BogackiShampine, CashKarp, DormandPrince embedded tableaux --- skrample/sampling/functional.py | 3 +- skrample/sampling/tableaux.py | 63 +++++++++++++++++++++++++++++++-- tests/miscellaneous.py | 10 +++++- 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/skrample/sampling/functional.py b/skrample/sampling/functional.py index 24b1206..7687a2f 100644 --- a/skrample/sampling/functional.py +++ b/skrample/sampling/functional.py @@ -227,7 +227,7 @@ class RKUltra(FunctionalDerivative, FunctionalSinglestep): 2: tableaux.RK2.Heun, 3: tableaux.RK3.Ralston, 4: tableaux.RK4.Ralston, - 5: tableaux.RKZ.Nystrom5, + 5: tableaux.RKE5.CashKarp, } ) ) @@ -325,6 +325,7 @@ class RKMoire(FunctionalAdaptive, FunctionalDerivative): providers: DictOrProxy[int, tableaux.TableauProvider[tableaux.ExtendedTableau]] = MappingProxyType( { 2: tableaux.RKE2.Heun, + 3: tableaux.RKE3.BogackiShampine, 5: tableaux.RKE5.Fehlberg, } ) diff --git a/skrample/sampling/tableaux.py 
b/skrample/sampling/tableaux.py index 61e136a..a2ea3ca 100644 --- a/skrample/sampling/tableaux.py +++ b/skrample/sampling/tableaux.py @@ -150,6 +150,8 @@ def tableau(self) -> Tableau: @enum.unique class RK2(enum.Enum): + "2nd order, 2 calls" + Heun = rk2_tableau(1) Mid = rk2_tableau(1 / 2) Ralston = rk2_tableau(2 / 3) @@ -160,6 +162,8 @@ def tableau(self) -> Tableau: @enum.unique class RK3(enum.Enum): + "3rd order, 3 calls" + Kutta = rk3_tableau(1 / 2, 1) Heun = rk3_tableau(1 / 3, 2 / 3) Ralston = rk3_tableau(1 / 2, 3 / 4) @@ -172,6 +176,8 @@ def tableau(self) -> Tableau: @enum.unique class RK4(enum.Enum): + "4th order, 4 calls" + Classic = ( ( (0, ()), @@ -190,6 +196,9 @@ def tableau(self) -> Tableau: @enum.unique class RKZ(enum.Enum): + """Tableaux provided by this method do not have clean generic forms, and require more calls than their order. + Since these are rare, they are all categorized into one enum""" + Nystrom5 = ( ( (0, ()), @@ -216,7 +225,32 @@ class RKE2(enum.Enum): (1 / 2, 1 / 2), (1, 0), ) - # Fehlberg = enum.auto() + Fehlberg = ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (1, (1 / 256, 255 / 256)), + ), + (1 / 512, 255 / 256, 1 / 512), + (1 / 256, 255 / 256, 0), + ) + + def tableau(self) -> ExtendedTableau: + return self.value + + +@enum.unique +class RKE3(enum.Enum): + BogackiShampine = ( + ( + (0, ()), + (1 / 2, (1 / 2,)), + (3 / 4, (0, 3 / 4)), + (1, (2 / 9, 1 / 3, 4 / 9)), + ), + (2 / 9, 1 / 3, 4 / 9, 0), + (7 / 24, 1 / 4, 1 / 3, 1 / 8), + ) def tableau(self) -> ExtendedTableau: return self.value @@ -236,8 +270,31 @@ class RKE5(enum.Enum): (16 / 135, 0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55), (25 / 216, 0, 1408 / 2565, 2197 / 4104, -1 / 5, 0), ) - # CashKarp = enum.auto() - # DormandPrince = enum.auto() + CashKarp = ( + ( + (0, ()), + (1 / 5, (1 / 5,)), + (3 / 10, (3 / 40, 9 / 40)), + (3 / 5, (3 / 10, -9 / 10, 6 / 5)), + (1, (-11 / 54, 5 / 2, -70 / 27, 35 / 27)), + (7 / 8, (1631 / 55296, 175 / 512, 575 / 13824, 44275 / 110592, 253 / 4096)), + ), + (37 / 378, 0, 250 / 621, 125 / 594, 0, 512 / 1771), + (2825 / 27648, 0, 18575 / 48384, 13525 / 55296, 277 / 14336, 1 / 4), + ) + DormandPrince = ( + ( + (0, ()), + (1 / 5, (1 / 5,)), + (3 / 10, (3 / 40, 9 / 40)), + (4 / 5, (44 / 45, -56 / 15, 32 / 9)), + (8 / 9, (19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729)), + (1, (9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656)), + (1, (35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84)), + ), + (35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0), + (5179 / 57600, 0, 7571 / 16695, 393 / 640, -92097 / 339200, 187 / 2100, 1 / 40), + ) def tableau(self) -> ExtendedTableau: return self.value diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 5b35fbd..4ad3254 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -323,7 +323,15 @@ def test_bashforth() -> None: def test_tableau_providers() -> None: - for provider in [tableaux.RK2, tableaux.RK3, tableaux.RK4, tableaux.RKZ, tableaux.RKE2, tableaux.RKE5]: + for provider in [ + tableaux.RK2, + tableaux.RK3, + tableaux.RK4, + tableaux.RKZ, + tableaux.RKE2, + tableaux.RKE3, + tableaux.RKE5, + ]: for variant in provider: if error := tableaux.validate_tableau(variant.tableau()): raise error
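A minimal usage sketch (not part of the patch series above), assuming the module layout these patches introduce: it mirrors the test_tableau_providers pattern to sanity-check the embedded tableaux added in PATCH 59 and a tableau built with the generic 4th-order constructor from PATCH 57. The 0.4 / 0.7 coefficients are arbitrary illustration values, not ones used anywhere in the patches.

from skrample.sampling import tableaux

# Embedded pairs from PATCH 59: each .tableau() returns an ExtendedTableau
# (stages, weights, embedded weights). validate_tableau() returns an error
# object on failure, so the same walrus pattern as the tests applies.
for variant in (
    tableaux.RKE3.BogackiShampine,
    tableaux.RKE5.CashKarp,
    tableaux.RKE5.DormandPrince,
):
    if error := tableaux.validate_tableau(variant.tableau()):
        raise error

# Generic 4th-order tableau from PATCH 57, built from two free coefficients.
# c1 = c2 = 1/2 (the classic rule) is the special case the generic formula
# cannot produce; any other non-degenerate pair should validate.
if error := tableaux.validate_tableau(tableaux.rk4_tableau(0.4, 0.7)):
    raise error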