Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
7a9f78f
First draft for FunctionalSampler
Beinsezii Sep 21, 2025
bc15d74
Split FunctionalHigher from FunctionalSinglestep
Beinsezii Sep 21, 2025
a6a0923
Plot functional samplers, upgrade Heun -> RungeKutta
Beinsezii Sep 22, 2025
d9394ea
Rewrite RungeKutta to deduplicate code
Beinsezii Sep 22, 2025
99315bf
Rewrite RungeKutta again to use tableau format for stages
Beinsezii Sep 26, 2025
9c98a3c
plot_skrample.py: improve fake model algorithm
Beinsezii Sep 26, 2025
857f967
Use real tableau values normalized by sum
Beinsezii Sep 26, 2025
0057310
RungeKutta use linear schedule interpolation for steps
Beinsezii Sep 26, 2025
3291c73
Add every RungeKutta tableau from wikipedia
Beinsezii Sep 26, 2025
7f85797
Add RKUltra tableau test
Beinsezii Sep 26, 2025
c82e35e
RKUltra do not skip non-zero intermediates
Beinsezii Sep 26, 2025
059c874
Add RKUltra.RK5.Nystrom tableau
Beinsezii Sep 28, 2025
0a09352
Tweak RKUltra.adjust_steps() to better handle different tableaus
Beinsezii Sep 28, 2025
9f7cab9
RKUltra add `custom_tableau` field
Beinsezii Sep 28, 2025
b692567
Add examples/diffusers/functional.py
Beinsezii Sep 29, 2025
805f44a
Min diffusers == 0.35 for examples/diffusers/functional.py
Beinsezii Sep 29, 2025
355021d
Add FunctionalSampler.generate_model(), some cleanup
Beinsezii Oct 5, 2025
a2d5ebb
Basic adaptive samplers
Beinsezii Oct 6, 2025
8bcc08b
Rework AdaptiveHeun -> RKMoire
Beinsezii Oct 8, 2025
16f2760
Move tableaux definitions into separate module
Beinsezii Oct 12, 2025
3c841d0
Deduplicate tableau sampling code
Beinsezii Oct 12, 2025
8ebf7be
Fix a docstring
Beinsezii Oct 12, 2025
835dd99
Make tableau validations standalone && more comprehensive
Beinsezii Oct 12, 2025
fbc6990
Add generic RK2 and RK3 tableau generators
Beinsezii Oct 12, 2025
6d43373
Make step_tableau() epsilon customizable
Beinsezii Oct 12, 2025
28d0cd5
Streamling higher order tableau logic
Beinsezii Oct 12, 2025
3130645
Fix RKMoire tableau provider fetch
Beinsezii Oct 12, 2025
eace1ad
Simplify tableaux providers
Beinsezii Oct 12, 2025
00bd28b
clear whitespace
Beinsezii Oct 13, 2025
9dca150
Fix `providers` fields not being accessible at a class level
Beinsezii Oct 13, 2025
b4dafdb
Add RKMoire.rescale_max
Beinsezii Oct 13, 2025
530e582
Deduplicate lots of scheduling / euler code
Beinsezii Oct 13, 2025
2d85992
Unify multiple types across structured, functional, scheduling
Beinsezii Oct 17, 2025
2b6c0d4
Add common.mean
Beinsezii Oct 17, 2025
00a4902
Change signature of FunctionalSampler.merge_noise to use internal sch…
Beinsezii Oct 17, 2025
895870c
StructuredFunctionalAdapter only call rng() if sampler requires
Beinsezii Oct 17, 2025
6358c1b
Add float versions of SkrampleSchedule methods, FunctionalSampler.mod…
Beinsezii Oct 17, 2025
d2514fd
Add RKMoire.discard
Beinsezii Oct 17, 2025
4752995
Further improve plot_skrample.py fake model
Beinsezii Nov 2, 2025
9f38650
Add Heun tests, rough impl of derivative-based tableau solver
Beinsezii Nov 5, 2025
e4935cb
Rework step_tableau_derive to better work with polar models
Beinsezii Nov 10, 2025
2a440bf
Deduplicate step_tableau with step_tableau_derive
Beinsezii Nov 10, 2025
3139fc2
Add low step test cases for heun
Beinsezii Nov 10, 2025
a92acf0
from_d complement is equivalent to predict_flow
Beinsezii Nov 10, 2025
53a230c
from_d polar is equivalent to predict_epsilon, add velocity heun test
Beinsezii Nov 10, 2025
3e1bedb
Add unified set of Diffusion model transforms
Beinsezii Nov 17, 2025
d2ababa
Change ModelTransfrom from Type -> Dataclass, add ScaleX transform
Beinsezii Nov 23, 2025
fe0a0ba
Replace common.safe_log() -> common.ln()
Beinsezii Nov 23, 2025
96169dc
Fix test_functional_adapter
Beinsezii Nov 23, 2025
d4ca1ea
Remove common.Predictor && functions
Beinsezii Nov 23, 2025
285c132
test_model_transforms also test backward()
Beinsezii Nov 23, 2025
db0d48a
add test_model_convert
Beinsezii Nov 23, 2025
aefe0e7
Add ScaleX to test_model_convert
Beinsezii Nov 23, 2025
da56008
Rename ModelTransform,DiffusionModel,EpsilonModel -> DiffusionModel,D…
Beinsezii Nov 23, 2025
60bc0f8
Don't skip derivative transform on Euler
Beinsezii Nov 24, 2025
6350f7d
Merge pull request #57 from Beinsezii/beinsezii/model_transform
Beinsezii Nov 24, 2025
f5b0ae4
plot_skrample: fix two bugs with modifiers
Beinsezii Nov 29, 2025
5edf16c
Rewrite RK4.Ralston and RK4.Eighth using generic RK4 solver
Beinsezii Nov 30, 2025
8fe8937
Rename RK5.Nystrom -> RKZ.Nystrom5
Beinsezii Nov 30, 2025
e463bba
tableax: Add BogackiShampine, CashKarp, DormandPrince embedded tableaux
Beinsezii Nov 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions examples/diffusers/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#! /usr/bin/env python

from typing import ClassVar

import torch
from diffusers.modular_pipelines.components_manager import ComponentsManager
from diffusers.modular_pipelines.flux.denoise import FluxDenoiseStep, FluxLoopDenoiser
from diffusers.modular_pipelines.flux.modular_blocks import TEXT2IMAGE_BLOCKS
from diffusers.modular_pipelines.flux.modular_pipeline import FluxModularPipeline
from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks
from tqdm import tqdm

import skrample.scheduling as scheduling
from skrample.diffusers import SkrampleWrapperScheduler
from skrample.sampling import functional, models, structured
from skrample.sampling.interface import StructuredFunctionalAdapter

model_id = "black-forest-labs/FLUX.1-dev"

blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)

schedule = scheduling.FlowShift(scheduling.Linear(), shift=2)
wrapper = SkrampleWrapperScheduler(
sampler=structured.Euler(), schedule=schedule, model=models.FlowModel(), allow_dynamic=False
)

# Equivalent to structured example
sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True))
# Native functional example
sampler = functional.RKUltra(schedule, 4)
# Dynamic model calls
sampler = functional.FastHeun(schedule)
# Dynamic step sizes
sampler = functional.RKMoire(schedule)


class FunctionalDenoise(FluxDenoiseStep):
# Exclude the after_denoise block
block_classes: ClassVar[list[type[ModularPipelineBlocks]]] = [FluxLoopDenoiser]
block_names: ClassVar[list[str]] = ["denoiser"]

@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

if isinstance(sampler, functional.FunctionalHigher):
block_state["num_inference_steps"] = sampler.adjust_steps(block_state["num_inference_steps"])
progress = tqdm(total=block_state["num_inference_steps"])

i = 0

def call_model(sample: torch.Tensor, timestep: float, sigma: float) -> torch.Tensor:
nonlocal i, components, block_state, progress
block_state["latents"] = sample
components, block_state = self.loop_step(
components,
block_state, # type: ignore
i=i,
t=sample.new_tensor([timestep] * len(sample)),
)
return block_state["noise_pred"] # type: ignore

def sample_callback(x: torch.Tensor, n: int, t: float, s: float) -> None:
nonlocal i
progress.update(n + 1 - progress.n)
i = n + 1

block_state["latents"] = sampler.sample_model(
sample=block_state["latents"],
model=call_model,
model_transform=models.FlowModel(),
steps=block_state["num_inference_steps"],
callback=sample_callback,
)

self.set_block_state(state, block_state) # type: ignore
return components, state # type: ignore


blocks.sub_blocks["denoise"] = FunctionalDenoise()

cm = ComponentsManager()
cm.enable_auto_cpu_offload()
pipe = blocks.init_pipeline(components_manager=cm)
pipe.load_components(["text_encoder"], repo=model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
pipe.load_components(["tokenizer"], repo=model_id, subfolder="tokenizer")
pipe.load_components(["text_encoder_2"], repo=model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
pipe.load_components(["tokenizer_2"], repo=model_id, subfolder="tokenizer_2")
pipe.load_components(["transformer"], repo=model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.load_components(["vae"], repo=model_id, subfolder="vae", torch_dtype=torch.bfloat16)

pipe.register_components(scheduler=wrapper)


pipe( # type: ignore
prompt="sharp, high dynamic range photograph of a kitten on a beach of rainbow pebbles",
generator=torch.Generator("cpu").manual_seed(42),
width=1024,
height=1024,
num_inference_steps=25,
guidance_scale=2.5,
).get("images")[0].save("diffusers_functional.png")
6 changes: 3 additions & 3 deletions examples/diffusers/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

import skrample.pytorch.noise as sknoise
import skrample.sampling as sampling
import skrample.sampling.structured as sampling
import skrample.scheduling as scheduling
from skrample.common import predict_flow
from skrample.diffusers import SkrampleWrapperScheduler
from skrample.sampling.models import FlowModel

pipe: FluxPipeline = FluxPipeline.from_pretrained( # type: ignore
"black-forest-labs/FLUX.1-dev",
Expand All @@ -17,7 +17,7 @@
pipe.scheduler = scheduler = SkrampleWrapperScheduler(
sampler=sampling.DPM(order=2, add_noise=True),
schedule=scheduling.FlowShift(scheduling.Linear(), shift=2.0),
predictor=predict_flow,
model=FlowModel(),
noise_type=sknoise.Brownian,
allow_dynamic=False,
)
Expand Down
2 changes: 1 addition & 1 deletion examples/diffusers/wrapper_from.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

import skrample.pytorch.noise as sknoise
import skrample.sampling as sampling
import skrample.sampling.structured as sampling
from skrample.diffusers import SkrampleWrapperScheduler

pipe: FluxPipeline = FluxPipeline.from_pretrained( # type: ignore
Expand Down
77 changes: 77 additions & 0 deletions examples/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#! /usr/bin/env python

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from tqdm import tqdm
from transformers.models.clip import CLIPTextModel, CLIPTokenizer

import skrample.pytorch.noise as noise
import skrample.scheduling as scheduling
from skrample.sampling import functional, models, structured
from skrample.sampling.interface import StructuredFunctionalAdapter

with torch.inference_mode():
device: torch.device = torch.device("cuda")
dtype: torch.dtype = torch.float16
url: str = "Lykon/dreamshaper-8"
seed = torch.Generator("cpu").manual_seed(0)
steps: int = 25
cfg: float = 3

schedule = scheduling.Karras(scheduling.Scaled())

# Equivalent to structured example
sampler = StructuredFunctionalAdapter(schedule, structured.DPM(order=2, add_noise=True))
# Native functional example
sampler = functional.RKUltra(schedule, 4)
# Dynamic model calls
sampler = functional.FastHeun(schedule)
# Dynamic step sizes
sampler = functional.RKMoire(schedule)

tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer")
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
url, subfolder="text_encoder", device_map=device, torch_dtype=dtype
)
model: UNet2DConditionModel = UNet2DConditionModel.from_pretrained( # type: ignore
url, subfolder="unet", device_map=device, torch_dtype=dtype
)
image_encoder: AutoencoderKL = AutoencoderKL.from_pretrained( # type: ignore
url, subfolder="vae", device_map=device, torch_dtype=dtype
)

text_embeds: torch.Tensor = text_encoder(
tokenizer(
"bright colorful fantasy art of a kitten in a field of rainbow flowers",
padding="max_length",
return_tensors="pt",
).input_ids.to(device=device)
).last_hidden_state

def call_model(x: torch.Tensor, t: float, s: float) -> torch.Tensor:
conditioned, unconditioned = model(
x.expand([x.shape[0] * 2, *x.shape[1:]]),
t,
torch.cat([text_embeds, torch.zeros_like(text_embeds)]),
).sample.chunk(2)
return conditioned + (cfg - 1) * (conditioned - unconditioned)

if isinstance(sampler, functional.FunctionalHigher):
steps = sampler.adjust_steps(steps)

rng = noise.Random.from_inputs((1, 4, 80, 80), seed)
bar = tqdm(total=steps)
sample = sampler.generate_model(
model=call_model,
model_transform=models.NoiseModel(),
steps=steps,
rng=lambda: rng.generate().to(dtype=dtype, device=device),
callback=lambda x, n, t, s: bar.update(n + 1 - bar.n),
)

image: torch.Tensor = image_encoder.decode(sample / image_encoder.config.scaling_factor).sample[0] # type: ignore
Image.fromarray(
((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy()
).save("functional.png")
19 changes: 10 additions & 9 deletions examples/raw.py → examples/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from tqdm import tqdm
from transformers.models.clip import CLIPTextModel, CLIPTokenizer

import skrample.common
import skrample.sampling as sampling
import skrample.sampling.structured as structured
import skrample.scheduling as scheduling
from skrample.sampling import models

with torch.inference_mode():
device: torch.device = torch.device("cuda")
Expand All @@ -20,8 +20,8 @@
cfg: float = 3

schedule: scheduling.SkrampleSchedule = scheduling.Karras(scheduling.Scaled())
sampler: sampling.SkrampleSampler = sampling.DPM(order=2, add_noise=True)
predictor: skrample.common.Predictor = skrample.common.predict_epsilon
sampler: structured.StructuredSampler = structured.DPM(order=2, add_noise=True)
transform: models.DiffusionModel = models.NoiseModel()

tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(url, subfolder="tokenizer")
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
Expand All @@ -43,23 +43,24 @@
).last_hidden_state

sample: torch.Tensor = torch.randn([1, 4, 80, 80], generator=seed).to(dtype=dtype, device=device)
previous: list[sampling.SKSamples[torch.Tensor]] = []
previous: list[structured.SKSamples[torch.Tensor]] = []
float_schedule = schedule.schedule(steps)

for n, (timestep, sigma) in enumerate(tqdm(schedule.schedule(steps))):
for n, (timestep, sigma) in enumerate(tqdm(float_schedule)):
conditioned, unconditioned = model(
sample.expand([sample.shape[0] * 2, *sample.shape[1:]]),
timestep,
torch.cat([text_embeds, torch.zeros_like(text_embeds)]),
).sample.chunk(2)
model_output: torch.Tensor = conditioned + (cfg - 1) * (conditioned - unconditioned)

prediction = predictor(sample, model_output, sigma, schedule.sigma_transform)
prediction = transform.to_x(sample, model_output, sigma, schedule.sigma_transform)

sampler_output = sampler.sample(
sample=sample,
prediction=prediction,
step=n,
sigma_schedule=schedule.sigmas(steps),
schedule=float_schedule,
sigma_transform=schedule.sigma_transform,
noise=torch.randn(sample.shape, generator=seed).to(dtype=sample.dtype, device=sample.device),
previous=tuple(previous),
Expand All @@ -71,4 +72,4 @@
image: torch.Tensor = image_encoder.decode(sample / image_encoder.config.scaling_factor).sample[0] # type: ignore
Image.fromarray(
((image + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy()
).save("raw.png")
).save("structured.png")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ scripts = ["skrample[all]", "matplotlib>=3.10.1"]
test = [
"skrample[scripts]",
"accelerate>=1.3",
"diffusers>=0.32",
"diffusers>=0.35",
"protobuf>=5.29",
"pyright>=1.1.400",
"pytest-xdist>=3.6.1",
Expand Down
4 changes: 2 additions & 2 deletions scripts/overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import torch

from skrample.diffusers import SkrampleWrapperScheduler
from skrample.sampling import DPM
from skrample.sampling.structured import Euler
from skrample.scheduling import Beta, FlowShift, SigmoidCDF


def bench_wrapper() -> int:
wrapper = SkrampleWrapperScheduler(DPM(), Beta(FlowShift(SigmoidCDF())))
wrapper = SkrampleWrapperScheduler(Euler(), Beta(FlowShift(SigmoidCDF())))
wrapper.set_timesteps(1000)

clock = perf_counter_ns()
Expand Down
Loading
Loading