From 3a6a3b634f6531c16df56cde4bec576d7f1b57c6 Mon Sep 17 00:00:00 2001 From: Joshna Medisetty Date: Mon, 16 Mar 2026 11:56:29 -0700 Subject: [PATCH 1/4] Load transformer config via get_hf_file_to_dict for local and HF models Signed-off-by: Joshna Medisetty --- .../diffusion/models/omnigen2/pipeline_omnigen2.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py index 32b874c653..1b1ad5a6ef 100644 --- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py +++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py @@ -27,6 +27,7 @@ from diffusers.utils.torch_utils import randn_tensor from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.transformers_utils.config import get_hf_file_to_dict from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device @@ -673,13 +674,11 @@ def __init__( self.device ) - transformer_config_path = os.path.join(model, "transformer", "config.json") + transformer_config = get_hf_file_to_dict( + "transformer/config.json", model, revision=getattr(od_config, "revision", None) + ) transformer_kwargs = {} - - if os.path.exists(transformer_config_path): - with open(transformer_config_path) as f: - transformer_config = json.load(f) - + if isinstance(transformer_config, dict): param_mapping = { "patch_size": "patch_size", "in_channels": "in_channels", @@ -697,7 +696,6 @@ def __init__( "text_feat_dim": "text_feat_dim", "timestep_scale": "timestep_scale", } - for config_key, param_name in param_mapping.items(): if config_key in transformer_config: value = transformer_config[config_key] From 398660225e92da7f6d58195849e966329af97746 Mon Sep 17 00:00:00 2001 From: Joshna Medisetty Date: Wed, 18 Mar 2026 20:43:49 +0000 Subject: [PATCH 2/4] OmniGen2: use entrypoint tf_model_config.params for transformer kwargs Signed-off-by: Joshna Medisetty --- .../models/omnigen2/pipeline_omnigen2.py | 61 +++++++++---------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py index 1b1ad5a6ef..52b108a33b 100644 --- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py +++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py @@ -27,7 +27,6 @@ from diffusers.utils.torch_utils import randn_tensor from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader -from vllm.transformers_utils.config import get_hf_file_to_dict from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device @@ -674,38 +673,36 @@ def __init__( self.device ) - transformer_config = get_hf_file_to_dict( - "transformer/config.json", model, revision=getattr(od_config, "revision", None) - ) + # Transformer config is loaded at entrypoint into od_config.tf_model_config.params. + transformer_config = od_config.tf_model_config.params transformer_kwargs = {} - if isinstance(transformer_config, dict): - param_mapping = { - "patch_size": "patch_size", - "in_channels": "in_channels", - "out_channels": "out_channels", - "hidden_size": "hidden_size", - "num_layers": "num_layers", - "num_refiner_layers": "num_refiner_layers", - "num_attention_heads": "num_attention_heads", - "num_kv_heads": "num_kv_heads", - "multiple_of": "multiple_of", - "ffn_dim_multiplier": "ffn_dim_multiplier", - "norm_eps": "norm_eps", - "axes_dim_rope": "axes_dim_rope", - "axes_lens": "axes_lens", - "text_feat_dim": "text_feat_dim", - "timestep_scale": "timestep_scale", - } - for config_key, param_name in param_mapping.items(): - if config_key in transformer_config: - value = transformer_config[config_key] - # Handle tuple parameters (axes_dim_rope, axes_lens) - if isinstance(value, list) and param_name in ( - "axes_dim_rope", - "axes_lens", - ): - value = tuple(value) - transformer_kwargs[param_name] = value + param_mapping = { + "patch_size": "patch_size", + "in_channels": "in_channels", + "out_channels": "out_channels", + "hidden_size": "hidden_size", + "num_layers": "num_layers", + "num_refiner_layers": "num_refiner_layers", + "num_attention_heads": "num_attention_heads", + "num_kv_heads": "num_kv_heads", + "multiple_of": "multiple_of", + "ffn_dim_multiplier": "ffn_dim_multiplier", + "norm_eps": "norm_eps", + "axes_dim_rope": "axes_dim_rope", + "axes_lens": "axes_lens", + "text_feat_dim": "text_feat_dim", + "timestep_scale": "timestep_scale", + } + for config_key, param_name in param_mapping.items(): + if config_key in transformer_config: + value = transformer_config[config_key] + # Handle tuple parameters (axes_dim_rope, axes_lens) + if isinstance(value, list) and param_name in ( + "axes_dim_rope", + "axes_lens", + ): + value = tuple(value) + transformer_kwargs[param_name] = value self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs) self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained( model, subfolder="mllm", local_files_only=local_files_only From 2a45a1413641e54cc51164403a31c2c758e89497 Mon Sep 17 00:00:00 2001 From: Joshna Medisetty Date: Wed, 18 Mar 2026 20:50:40 +0000 Subject: [PATCH 3/4] OmniGen2: use entrypoint tf_model_config.params for transformer kwargs Signed-off-by: Joshna Medisetty --- .../models/omnigen2/pipeline_omnigen2.py | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py index 52b108a33b..915e5d82e3 100644 --- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py +++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py @@ -676,33 +676,34 @@ def __init__( # Transformer config is loaded at entrypoint into od_config.tf_model_config.params. transformer_config = od_config.tf_model_config.params transformer_kwargs = {} - param_mapping = { - "patch_size": "patch_size", - "in_channels": "in_channels", - "out_channels": "out_channels", - "hidden_size": "hidden_size", - "num_layers": "num_layers", - "num_refiner_layers": "num_refiner_layers", - "num_attention_heads": "num_attention_heads", - "num_kv_heads": "num_kv_heads", - "multiple_of": "multiple_of", - "ffn_dim_multiplier": "ffn_dim_multiplier", - "norm_eps": "norm_eps", - "axes_dim_rope": "axes_dim_rope", - "axes_lens": "axes_lens", - "text_feat_dim": "text_feat_dim", - "timestep_scale": "timestep_scale", - } - for config_key, param_name in param_mapping.items(): - if config_key in transformer_config: - value = transformer_config[config_key] - # Handle tuple parameters (axes_dim_rope, axes_lens) - if isinstance(value, list) and param_name in ( - "axes_dim_rope", - "axes_lens", - ): - value = tuple(value) - transformer_kwargs[param_name] = value + if isinstance(transformer_config, dict): + param_mapping = { + "patch_size": "patch_size", + "in_channels": "in_channels", + "out_channels": "out_channels", + "hidden_size": "hidden_size", + "num_layers": "num_layers", + "num_refiner_layers": "num_refiner_layers", + "num_attention_heads": "num_attention_heads", + "num_kv_heads": "num_kv_heads", + "multiple_of": "multiple_of", + "ffn_dim_multiplier": "ffn_dim_multiplier", + "norm_eps": "norm_eps", + "axes_dim_rope": "axes_dim_rope", + "axes_lens": "axes_lens", + "text_feat_dim": "text_feat_dim", + "timestep_scale": "timestep_scale", + } + for config_key, param_name in param_mapping.items(): + if config_key in transformer_config: + value = transformer_config[config_key] + # Handle tuple parameters (axes_dim_rope, axes_lens) + if isinstance(value, list) and param_name in ( + "axes_dim_rope", + "axes_lens", + ): + value = tuple(value) + transformer_kwargs[param_name] = value self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs) self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained( model, subfolder="mllm", local_files_only=local_files_only From f5feb6b7adf8511695048778be740d51a2d2966e Mon Sep 17 00:00:00 2001 From: gcanlin Date: Thu, 19 Mar 2026 01:00:12 +0000 Subject: [PATCH 4/4] Use get_transformer_config_kwargs tool Signed-off-by: gcanlin --- .../models/omnigen2/pipeline_omnigen2.py | 33 ++----------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py index 915e5d82e3..2d370aea19 100644 --- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py +++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py @@ -36,6 +36,7 @@ OmniGen2Transformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, @@ -673,37 +674,7 @@ def __init__( self.device ) - # Transformer config is loaded at entrypoint into od_config.tf_model_config.params. - transformer_config = od_config.tf_model_config.params - transformer_kwargs = {} - if isinstance(transformer_config, dict): - param_mapping = { - "patch_size": "patch_size", - "in_channels": "in_channels", - "out_channels": "out_channels", - "hidden_size": "hidden_size", - "num_layers": "num_layers", - "num_refiner_layers": "num_refiner_layers", - "num_attention_heads": "num_attention_heads", - "num_kv_heads": "num_kv_heads", - "multiple_of": "multiple_of", - "ffn_dim_multiplier": "ffn_dim_multiplier", - "norm_eps": "norm_eps", - "axes_dim_rope": "axes_dim_rope", - "axes_lens": "axes_lens", - "text_feat_dim": "text_feat_dim", - "timestep_scale": "timestep_scale", - } - for config_key, param_name in param_mapping.items(): - if config_key in transformer_config: - value = transformer_config[config_key] - # Handle tuple parameters (axes_dim_rope, axes_lens) - if isinstance(value, list) and param_name in ( - "axes_dim_rope", - "axes_lens", - ): - value = tuple(value) - transformer_kwargs[param_name] = value + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, OmniGen2Transformer2DModel) self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs) self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained( model, subfolder="mllm", local_files_only=local_files_only