[diffusion] chore: align LTX-2 with official #24313
Conversation
Code Review
This pull request focuses on precision alignment for the LTX-2.3 model, specifically addressing consistency issues in the LoRA warmup and HQ pipelines. Key improvements include ensuring that bias and reduction ordering in RowParallelLinear layers matches the base implementation when LoRA is disabled, and adopting NumPy-based double-precision RoPE frequency generation to align with official implementations. The res2s SDE logic was also refined to maintain scheduler dtypes during main-step updates. Review feedback identifies critical risks associated with in-place tensor modifications that could corrupt global sigma schedules or cause side effects for callers, and suggests optimizing Grouped Query Attention (GQA) by using native SDPA support instead of manual tensor expansion.
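As background on the RoPE precision point above, here is a minimal, illustrative sketch of generating RoPE inverse frequencies in NumPy double precision versus native float32; the function names and head dimension are assumptions for the example, not the actual LTX-2 code.

import numpy as np
import torch

def rope_inv_freq_fp64(head_dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Compute RoPE inverse frequencies in NumPy float64 and cast once at the
    # end, instead of accumulating float32 rounding error in torch.
    exponents = np.arange(0, head_dim, 2, dtype=np.float64) / head_dim
    inv_freq = 1.0 / (theta ** exponents)   # float64 throughout
    return torch.from_numpy(inv_freq)       # torch.float64 tensor

def rope_inv_freq_fp32(head_dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Same math done natively in float32; the per-element error is later
    # amplified when frequencies are multiplied by large position indices.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (theta ** exponents)

if __name__ == "__main__":
    fp64 = rope_inv_freq_fp64(128)
    fp32 = rope_inv_freq_fp32(128)
    print((fp64 - fp32.double()).abs().max())  # non-zero: the precision gap
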
I am having trouble creating individual review comments. See my feedback below.
python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py (378-379)
In-place modification of sigma_down can lead to critical bugs. Specifically, if sigma_down was None initially, it is assigned a reference to sigma_next at line 372. Modifying it in-place at line 379 will then modify sigma_next. If sigma_next is a view into the scheduler's sigmas tensor, this will corrupt the global sigma schedule. Using torch.where provides a safe, out-of-place alternative.
sigma_down = torch.where(torch.isnan(sigma_down), sigma_next.to(sigma_down.dtype), sigma_down)
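To make the aliasing hazard concrete, here is a minimal, hypothetical sketch; the variable names mirror the comment, and the scheduler tensor is a stand-in, not the actual stage code.

import torch

# Stand-in for the scheduler's global sigma schedule.
sigmas = torch.linspace(1.0, 0.0, 5)
sigma_next = sigmas[2]        # a view into the scheduler's tensor
sigma_down = None

if sigma_down is None:
    sigma_down = sigma_next   # alias, not a copy

# An in-place NaN fixup such as
#   sigma_down[torch.isnan(sigma_down)] = 0.0
# would write through the alias into `sigmas`, corrupting the schedule.

# Out-of-place alternative: the schedule stays untouched.
sigma_down = torch.where(
    torch.isnan(sigma_down), sigma_next.to(sigma_down.dtype), sigma_down
)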
python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py (319-338)
Manual expansion of key and value tensors for Grouped Query Attention (GQA) is inefficient and increases peak memory usage. torch.nn.functional.scaled_dot_product_attention natively supports GQA when the head counts are compatible (i.e., query heads are a multiple of KV heads). Unless this manual expansion is strictly required for bit-exact alignment with the official implementation, it should be removed to leverage SDPA's optimized kernels.
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attn_mask,
    dropout_p=0.0,
    is_causal=False,
    scale=self.scaling,
    enable_gqa=True,
)
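For reference, assuming PyTorch 2.5 or newer, SDPA's enable_gqa flag makes the manual key/value expansion unnecessary. The sketch below is illustrative only; the shapes and head counts are invented and do not come from the Gemma-3 config.

import torch
import torch.nn.functional as F

# Illustrative shapes only (not the actual Gemma-3 configuration).
bsz, seq, n_q_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
query = torch.randn(bsz, n_q_heads, seq, head_dim)
key = torch.randn(bsz, n_kv_heads, seq, head_dim)
value = torch.randn(bsz, n_kv_heads, seq, head_dim)

# Manual GQA: expand K/V so every query head has a matching KV head.
n_rep = n_q_heads // n_kv_heads
key_exp = key.repeat_interleave(n_rep, dim=1)
value_exp = value.repeat_interleave(n_rep, dim=1)
out_manual = F.scaled_dot_product_attention(query, key_exp, value_exp)

# Native GQA (PyTorch >= 2.5): let SDPA broadcast the KV heads itself.
out_native = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

print(torch.allclose(out_manual, out_native, atol=1e-6))  # expected: True
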
python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py (364)
In-place modification of the sigma_up argument using clamp_ is a dangerous side effect. If the caller intends to reuse the tensor passed as sigma_up, its values will be unexpectedly modified. It is safer to use an out-of-place operation like torch.minimum.
sigma_up = torch.minimum(sigma_up, sigma_next * 0.9999)
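A small, hypothetical illustration of the difference; the function names are invented for the example.

import torch

def eta_step_inplace(sigma_up, sigma_next):
    # In-place clamp mutates the tensor the caller passed in.
    sigma_up.clamp_(max=(sigma_next * 0.9999).item())
    return sigma_up

def eta_step_safe(sigma_up, sigma_next):
    # Out-of-place: the caller's tensor is left untouched.
    return torch.minimum(sigma_up, sigma_next * 0.9999)

caller_sigma_up = torch.tensor([0.8])
eta_step_inplace(caller_sigma_up, torch.tensor([0.5]))
print(caller_sigma_up)   # ~0.5: silently changed in place

caller_sigma_up = torch.tensor([0.8])
eta_step_safe(caller_sigma_up, torch.tensor([0.5]))
print(caller_sigma_up)   # 0.8: unchanged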
Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

/tag-and-rerun-ci
This reverts commit c5f2381.
* main: (894 commits)
  [Bug Fix] Fix RunAI streamer: corrupted weights, missing quant init, and broken URIs for multimodal models (sgl-project#22715)
  [Kernel] Deprecate DeepGemm in sgl kernel and apply custom wheel sgl-deep-gemm (sgl-project#24268)
  propagate pytest exit code from test __main__ entries (sgl-project#24487)
  [R3] Avoid implicit CUDA sync in routed experts DP slicing (sgl-project#24550)
  Add ChatCompletionRequest-style support to /v1/tokenize (sgl-project#23981)
  Support Triton MLA FP8 KV cache (sgl-project#20479)
  [diffusion] chore: align LTX-2 with official (sgl-project#24313)
  Expand support matrix for pypi wheel release (sgl-project#24565)
  [codex] Optimize Z-Image packed QKV (sgl-project#24117)
  [Misc] Fix breaking weight checker test (sgl-project#24553)
  [LoRA] Fix qkv_proj LoRA buffer sizing when tp_size > num_key_value_heads (sgl-project#24420)
  ci: bump test_mimo_models.py est_time 330 → 610 (sgl-project#24551)
  [CI] Temporarily disable marco/mcdse-2b-v1 in test_embedding_models (sgl-project#24279)
  Improve metrics, observability, and PD deploy tooling (sgl-project#24521)
  Fix diffusion fallback guards and validation (sgl-project#23335)
  [PD] Prevent update_status to Failed from cleared entries (sgl-project#24539)
  [CP] Register KV cache allgather buffer with symmetric memory (sgl-project#24040)
  Support getting checksums in weight checker (sgl-project#24537)
  Refactor buffer patterns in weight checker (sgl-project#24538)
  Add unit and end-to-end tests for weight checker (sgl-project#24536)
  ...

# Conflicts:
#   python/sglang/srt/managers/scheduler.py
#   python/sglang/srt/model_executor/model_runner.py
Motivation
Align native LTX text-encoder attention behavior with the official implementation while preserving high-performance attention backends outside the text encoder path. Keep CI consistency gates honest by using official GT only for cases whose request semantics are currently comparable.
Modifications
- Merged origin/main, including the component attention backend override support from "[diffusion] cli: support component attention backend overrides" (#24320).
- enable_gqa=True.
- Scope torch_sdpa to the text_encoder component for all LTX2 native configs, including LTX-2.0 and LTX-2.3, via component_attention_backends instead of forcing a global attention_backend=torch_sdpa.
- 42 to match refreshed official GT generation.
- Pinned ci-data to 6e7b99e16b857c98285277fe3b4ffef30559bde9.
- Use official_generated GT for the currently comparable LTX official-aligned cases: ltx_2.3_one_stage_ti2v, ltx_2.3_two_stage_t2v_2gpus.
- Keep sglang_generated GT instead of hiding gaps behind very loose official thresholds for: ltx_2_two_stage_t2v, ltx_2_3_hq_pipeline, ltx_2_3_two_stage_ti2v_2gpus.
- Relax the ltx_2_two_stage_t2v SSIM threshold to 0.89; keep its CLIP/PSNR/MAD video defaults.
- Avoid depending on raw.githubusercontent.com HEAD, which is flaky.
Notes
- ltx_2_two_stage_t2v against official_generated at the pinned ci-data revision is only clip=0.6489, ssim=0.1460, psnr=5.8351, mean_abs_diff=112.2637. It uses refreshed native sglang_generated GT plus a modest SSIM relaxation instead of hiding the gap behind extremely loose official thresholds.
- ltx_2_3_two_stage_ti2v_2gpus still has mid/last PSNR around 11.1 even with transformer=torch_sdpa, so the DIT attention backend is not the root cause.
- num_frames=24; actual output extraction still differs from the official GT contract.