diff --git a/python/sglang/multimodal_gen/configs/sample/flux.py b/python/sglang/multimodal_gen/configs/sample/flux.py index 7f95cd98d759..692df8332ac2 100644 --- a/python/sglang/multimodal_gen/configs/sample/flux.py +++ b/python/sglang/multimodal_gen/configs/sample/flux.py @@ -20,3 +20,9 @@ def __post_init__(self): # FIXME # self.height = default_sample_size * vae_scale_factor # self.width = default_sample_size * vae_scale_factor + + +@dataclass +class Flux2KleinSamplingParams(FluxSamplingParams): + # Klein is step-distilled, so default to 4 steps + num_inference_steps: int = 4 diff --git a/python/sglang/multimodal_gen/registry.py b/python/sglang/multimodal_gen/registry.py index 94110ec47d5d..2f09062afd15 100644 --- a/python/sglang/multimodal_gen/registry.py +++ b/python/sglang/multimodal_gen/registry.py @@ -61,7 +61,10 @@ Wan2_2_T2V_A14B_Config, Wan2_2_TI2V_5B_Config, ) -from sglang.multimodal_gen.configs.sample.flux import FluxSamplingParams +from sglang.multimodal_gen.configs.sample.flux import ( + Flux2KleinSamplingParams, + FluxSamplingParams, +) from sglang.multimodal_gen.configs.sample.glmimage import GlmImageSamplingParams from sglang.multimodal_gen.configs.sample.hunyuan import ( FastHunyuanSamplingParam, @@ -538,7 +541,7 @@ def _register_configs(): model_detectors=[lambda hf_id: "flux.1" in hf_id.lower()], ) register_configs( - sampling_param_cls=FluxSamplingParams, + sampling_param_cls=Flux2KleinSamplingParams, pipeline_config_cls=Flux2KleinPipelineConfig, hf_model_paths=[ "black-forest-labs/FLUX.2-klein-4B", diff --git a/python/sglang/multimodal_gen/test/server/perf_baselines.json b/python/sglang/multimodal_gen/test/server/perf_baselines.json index 474315b54210..3f58e3a5bbf1 100644 --- a/python/sglang/multimodal_gen/test/server/perf_baselines.json +++ b/python/sglang/multimodal_gen/test/server/perf_baselines.json @@ -304,6 +304,27 @@ "expected_avg_denoise_ms": 520.09, "expected_median_denoise_ms": 528.0 }, + "flux_2_klein_image_t2i": { + "stages_ms": { + "InputValidationStage": 0.05, + "TextEncodingStage": 530.93, + "ImageVAEEncodingStage": 0.0, + "ConditioningStage": 0.02, + "LatentPreparationStage": 12.71, + "TimestepPreparationStage": 2.91, + "DenoisingStage": 2112.24, + "DecodingStage": 489.8 + }, + "denoise_step_ms": { + "0": 511.3, + "1": 541.19, + "2": 518.93, + "3": 541.2 + }, + "expected_e2e_ms": 3148.75, + "expected_avg_denoise_ms": 528.06, + "expected_median_denoise_ms": 526.56 + }, "flux_2_image_t2i_layerwise_offload": { "stages_ms": { "InputValidationStage": 0.06, diff --git a/python/sglang/multimodal_gen/test/server/testcase_configs.py b/python/sglang/multimodal_gen/test/server/testcase_configs.py index 3281ace17f22..8f9b7d4b5100 100644 --- a/python/sglang/multimodal_gen/test/server/testcase_configs.py +++ b/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -349,6 +349,14 @@ def from_req_perf_record( ), T2I_sampling_params, ), + DiffusionTestCase( + "flux_2_klein_image_t2i", + DiffusionServerArgs( + model_path="black-forest-labs/FLUX.2-klein-4B", + modality="image", + ), + T2I_sampling_params, + ), # TODO: replace with a faster model to test the --dit-layerwise-offload # TODO: currently, we don't support sending more than one request in test, and setting `num_outputs_per_prompt` to 2 doesn't guarantee the denoising be executed twice, # so we do one warmup and send one request instead