Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
043a941
[diffusion][CI]: Add individual component accuracy CI for diffusion m…
Ratish1 Feb 4, 2026
5f93ead
remove fallback
Ratish1 Feb 5, 2026
d18275e
Merge remote-tracking branch 'upstream/main' into feat/accuracy-test
Ratish1 Feb 9, 2026
3df80fe
upd
Ratish1 Feb 12, 2026
668e52e
Merge remote-tracking branch 'upstream/main' into feat/accuracy-test
Ratish1 Feb 23, 2026
8d4061f
test: migrate diffusion accuracy to native hook architecture
Ratish1 Feb 23, 2026
6958597
upd
Ratish1 Feb 26, 2026
61fce33
upd
Ratish1 Feb 26, 2026
d594a1c
upd
Ratish1 Mar 9, 2026
3754993
fix conflict
Ratish1 Mar 19, 2026
0dbabd7
Merge remote-tracking branch 'upstream/main' into feat/accuracy-test
Ratish1 Mar 25, 2026
61916ea
test(multimodal-gen): stabilize 2-GPU component accuracy harness
Ratish1 Mar 27, 2026
d804385
upd
Ratish1 Mar 27, 2026
5f623f1
move shard context helper into accuracy utils
Ratish1 Mar 27, 2026
d48a1ee
move accuracy runtime helpers into accuracy utils
Ratish1 Mar 27, 2026
ea39b1a
deduplicate text encoder module resolution
Ratish1 Mar 27, 2026
265f6e8
inline native forward output capture
Ratish1 Mar 27, 2026
fa816f4
deduplicate native output normalization
Ratish1 Mar 28, 2026
34c17e6
remove generic native profile registry
Ratish1 Mar 28, 2026
b7e6a68
rename accuracy hook helpers for clarity
Ratish1 Mar 28, 2026
ed16d98
make native profile and text encoder helpers more explicit
Ratish1 Mar 28, 2026
0300ad9
reduce unnecessary memory cleanup work between accuracy stages
Ratish1 Mar 28, 2026
0e5f665
remove unused accuracy helpers
Ratish1 Mar 28, 2026
9edb668
[diffusion] test: trim accuracy harness comments
Ratish1 Mar 28, 2026
4d8ccb0
Merge remote-tracking branch 'upstream/main' into feat/accuracy-test
Ratish1 Mar 28, 2026
280fd24
[diffusion] test: reduce duplicate accuracy coverage and stabilize 2-…
Ratish1 Mar 28, 2026
eb82f80
Merge remote-tracking branch 'upstream/main' into feat/accuracy-test
Ratish1 Mar 28, 2026
6534ea3
Merge branch 'main' into feat/accuracy-test
BBuf Mar 30, 2026
93307ad
fix ci names
Ratish1 Mar 30, 2026
c5f54e9
fix ci path
Ratish1 Mar 30, 2026
fe41a77
upd
Ratish1 Mar 30, 2026
9d121fc
fix OOM on 2 gpu cases
Ratish1 Mar 30, 2026
6d4245b
add CI tests to diffusion workflow
Ratish1 Mar 30, 2026
c1edd23
fix
Ratish1 Mar 30, 2026
96c5fdb
fix
Ratish1 Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
577 changes: 577 additions & 0 deletions python/sglang/multimodal_gen/test/server/accuracy_adapters.py

Large diffs are not rendered by default.

143 changes: 143 additions & 0 deletions python/sglang/multimodal_gen/test/server/accuracy_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional

from sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase


class ComponentType(str, Enum):
    """Identifiers for the components that thresholds and skips are keyed by.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (for example ``ComponentType.VAE == "vae"``).
    """

    VAE = "vae"
    TRANSFORMER = "transformer"
    TEXT_ENCODER = "text_encoder"


@dataclass(frozen=True)
class ComponentSkip:
    """Immutable record of why a component is skipped for a test case."""

    # Human-readable explanation of why the component is skipped.
    reason: str


# Baseline threshold per component type; CASE_THRESHOLDS may override a value
# for an individual case.
# NOTE(review): the skip reasons in this module quote "CosSim" figures, so
# these are presumably minimum cosine similarities -- confirm against the
# comparison code.
DEFAULT_THRESHOLDS: Dict[ComponentType, float] = {
    ComponentType.VAE: 0.999,
    ComponentType.TRANSFORMER: 0.995,
    ComponentType.TEXT_ENCODER: 0.98,
}

# Per-case threshold overrides, shaped as {case_id: {ComponentType: threshold}}.
# A component that is not listed for a case uses DEFAULT_THRESHOLDS instead.
# Add an entry when a specific model/component needs a different threshold.
CASE_THRESHOLDS: Dict[str, Dict[ComponentType, float]] = {
    "flux_2_image_t2i": {
        ComponentType.TRANSFORMER: 0.99,
    },
    "flux_2_image_t2i_layerwise_offload": {
        ComponentType.TRANSFORMER: 0.99,
    },
    "flux_2_image_t2i_2_gpus": {
        ComponentType.TRANSFORMER: 0.99,
    },
    "flux_2_ti2i": {
        ComponentType.TRANSFORMER: 0.99,
    },
    "fast_hunyuan_video": {
        ComponentType.TRANSFORMER: 0.99,
    },
}

# Skip reasons shared by several cases. Each is spelled out once so that every
# case referring to the same underlying problem reports identical text.
_ZIMAGE_TRANSFORMER_REASON = (
    "SGLang ZImage transformer diverges from Diffusers baseline "
    "(CosSim ~0.61) despite matched weights and freqs"
)
_TEXT_ENCODER_PARTIAL_LOAD_REASON = (
    "SGLang text encoder weights do not fully load from HF checkpoint"
)
_LOADER_REJECTS_PARAMS_REASON = (
    "SGLang transformer loader rejects new parameters in HF checkpoint"
)
_WAN_I2V_TRANSFORMER_REASON = (
    "SGLang WAN transformer diverges from Diffusers baseline "
    "(CosSim ~0.68-0.71) despite matched weights; "
    "optional norm_added_q params missing in SGLang model"
)

# Per-case component skips, shaped as {case_id: {ComponentType: ComponentSkip}}.
# A component without an entry for a case is not skipped.
SKIP_COMPONENTS: Dict[str, Dict[ComponentType, ComponentSkip]] = {
    # Example:
    # "some_case_id": {ComponentType.TEXT_ENCODER: ComponentSkip("Diffusers baseline differs")},
    "flux_2_klein_image_t2i": {
        ComponentType.TRANSFORMER: ComponentSkip(
            "Diffusers transformer differs from SGLang baseline"
        ),
        ComponentType.TEXT_ENCODER: ComponentSkip(
            "Flux-2 klein text encoder weights do not align with SGLang baseline"
        ),
    },
    "qwen_image_layered_i2i": {
        ComponentType.VAE: ComponentSkip(
            "Diffusers VAE config mismatches checkpoint (conv_out shape mismatch)"
        ),
    },
    "zimage_image_t2i": {
        ComponentType.TRANSFORMER: ComponentSkip(_ZIMAGE_TRANSFORMER_REASON),
        ComponentType.TEXT_ENCODER: ComponentSkip(_TEXT_ENCODER_PARTIAL_LOAD_REASON),
    },
    "zimage_image_t2i_warmup": {
        ComponentType.TRANSFORMER: ComponentSkip(_ZIMAGE_TRANSFORMER_REASON),
        ComponentType.TEXT_ENCODER: ComponentSkip(_TEXT_ENCODER_PARTIAL_LOAD_REASON),
    },
    "zimage_image_t2i_multi_lora": {
        ComponentType.TRANSFORMER: ComponentSkip(_ZIMAGE_TRANSFORMER_REASON),
        ComponentType.TEXT_ENCODER: ComponentSkip(_TEXT_ENCODER_PARTIAL_LOAD_REASON),
    },
    "zimage_image_t2i_2_gpus": {
        ComponentType.TRANSFORMER: ComponentSkip(_ZIMAGE_TRANSFORMER_REASON),
    },
    "wan2_2_ti2v_5b": {
        ComponentType.TRANSFORMER: ComponentSkip(_LOADER_REJECTS_PARAMS_REASON),
    },
    "fastwan2_2_ti2v_5b": {
        ComponentType.TRANSFORMER: ComponentSkip(_LOADER_REJECTS_PARAMS_REASON),
    },
    "wan2_2_i2v_a14b_2gpu": {
        ComponentType.TRANSFORMER: ComponentSkip(_LOADER_REJECTS_PARAMS_REASON),
    },
    "turbo_wan2_2_i2v_a14b_2gpu": {
        ComponentType.TRANSFORMER: ComponentSkip(_LOADER_REJECTS_PARAMS_REASON),
    },
    "turbo_wan2_1_t2v_1.3b": {
        ComponentType.TRANSFORMER: ComponentSkip(
            "Weight transfer match ratio too low for reliable comparison"
        ),
    },
    "wan2_1_i2v_14b_480P_2gpu": {
        ComponentType.TRANSFORMER: ComponentSkip(_WAN_I2V_TRANSFORMER_REASON),
    },
    "wan2_1_i2v_14b_lora_2gpu": {
        ComponentType.TRANSFORMER: ComponentSkip(_WAN_I2V_TRANSFORMER_REASON),
    },
    "wan2_1_i2v_14b_720P_2gpu": {
        ComponentType.TRANSFORMER: ComponentSkip(_WAN_I2V_TRANSFORMER_REASON),
    },
}

# TODO: If a model needs extra compatibility logic, prefer adding a skip or an
# explicit override here instead of adding more ad-hoc hacks in the engine.


def get_threshold(case_id: str, component: ComponentType) -> float:
    """Return the threshold that applies to ``component`` for ``case_id``.

    A value in ``CASE_THRESHOLDS`` for this case and component takes
    precedence; otherwise the component's ``DEFAULT_THRESHOLDS`` entry is used.

    Raises:
        KeyError: if ``component`` has no entry in ``DEFAULT_THRESHOLDS``. The
            default is looked up unconditionally, so this is raised even when
            a per-case override exists.
    """
    case_overrides = CASE_THRESHOLDS.get(case_id)
    default_threshold = DEFAULT_THRESHOLDS[component]
    if case_overrides is None:
        return default_threshold
    return case_overrides.get(component, default_threshold)


def get_skip_reason(case: DiffusionTestCase, component: ComponentType) -> Optional[str]:
    """Return why ``component`` is skipped for ``case``, or ``None`` if it is not.

    The lookup uses ``case.id`` as the key into ``SKIP_COMPONENTS``.
    """
    skips_for_case = SKIP_COMPONENTS.get(case.id, {})
    entry = skips_for_case.get(component)
    return entry.reason if entry is not None else None


def should_skip_component(case: DiffusionTestCase, component: ComponentType) -> bool:
    """Return ``True`` when ``get_skip_reason`` yields a reason for this pair."""
    reason = get_skip_reason(case, component)
    return reason is not None
60 changes: 60 additions & 0 deletions python/sglang/multimodal_gen/test/server/accuracy_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import os
from typing import Any, Dict

import torch

from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)


def seed_and_broadcast(seed: int, tensor: torch.Tensor) -> torch.Tensor:
    """Seed the global RNG and broadcast ``tensor`` from rank 0 to all ranks.

    Args:
        seed: Value passed to ``torch.manual_seed``.
        tensor: Tensor to synchronise. When a process group with more than one
            rank is initialised it is overwritten in place with rank 0's data.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    torch.manual_seed(seed)
    dist = torch.distributed
    # is_available() has to come first: when PyTorch is built without
    # distributed support the rest of the torch.distributed API is not exposed,
    # so calling is_initialized() directly would fail there.
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        dist.broadcast(tensor, src=0)
    return tensor


def log_tensor_stats(name: str, tensor: torch.Tensor | None) -> None:
    """Log shape, mean, std, min and max of ``tensor`` when debugging is enabled.

    Nothing is logged unless the ``SGLANG_DIFFUSION_ACC_DEBUG`` environment
    variable equals ``"1"``; a ``None`` tensor is ignored.

    Args:
        name: Label placed at the start of the log line.
        tensor: Tensor to summarise, or ``None``.
    """
    if tensor is None:
        return
    if os.environ.get("SGLANG_DIFFUSION_ACC_DEBUG", "0") != "1":
        return
    # Statistics are computed on a detached float32 CPU copy.
    t = tensor.detach().float().cpu()
    if t.numel() == 0:
        # min()/max() raise on an empty tensor; a debug helper must not abort
        # the run, so report only the shape.
        logger.info("[%s] stats: shape=%s (empty tensor)", name, list(t.shape))
        return
    logger.info(
        "[%s] stats: shape=%s mean=%.6f std=%.6f min=%.6f max=%.6f",
        name,
        list(t.shape),
        t.mean().item(),
        t.std().item(),
        t.min().item(),
        t.max().item(),
    )


def log_inputs(prefix: str, inputs: Dict[str, Any]) -> None:
    """Log statistics for every tensor value in ``inputs``.

    Does nothing unless the ``SGLANG_DIFFUSION_ACC_DEBUG`` environment variable
    equals ``"1"``. Non-tensor values are ignored. Each tensor is reported
    through ``log_tensor_stats`` under the label ``"<prefix>.<key>"``.
    """
    debug_enabled = os.environ.get("SGLANG_DIFFUSION_ACC_DEBUG", "0") == "1"
    if not debug_enabled:
        return
    for key, value in inputs.items():
        if not isinstance(value, torch.Tensor):
            continue
        log_tensor_stats(f"{prefix}.{key}", value)


def extract_output_tensor(output: Any) -> torch.Tensor:
    """Best-effort extraction of a tensor from model outputs.

    Candidates are tried in this order and the first match is returned:

    1. ``output`` itself, when it is a tensor.
    2. ``output.last_hidden_state``, when present and not ``None``.
    3. The last entry of ``output.hidden_states``, when present and non-empty.
    4. ``output.pooler_output``, when present and not ``None``.
    5. ``output.logits``, when present and not ``None``.
    6. The first element of a non-empty list or tuple.

    When nothing matches, ``output`` is returned unchanged, so the result is
    not guaranteed to be a tensor despite the annotation.
    """
    if isinstance(output, torch.Tensor):
        return output

    last_hidden_state = getattr(output, "last_hidden_state", None)
    if last_hidden_state is not None:
        return last_hidden_state

    hidden_states = getattr(output, "hidden_states", None)
    # Emptiness is checked with len() rather than truthiness: bool() of a
    # multi-element tensor raises RuntimeError, while len() works for tuples,
    # lists and tensors alike.
    # NOTE(review): for a tensor-valued ``hidden_states`` this returns the last
    # slice along dim 0, i.e. it assumes dim 0 is the layer axis -- confirm.
    if hidden_states is not None and len(hidden_states) > 0:
        return hidden_states[-1]

    for attr_name in ("pooler_output", "logits"):
        candidate = getattr(output, attr_name, None)
        if candidate is not None:
            return candidate

    if isinstance(output, (list, tuple)) and output:
        return output[0]
    return output
Comment thread
Ratish1 marked this conversation as resolved.
Outdated
Loading
Loading