diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py index 32b874c6535..2d370aea19c 100644 --- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py +++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py @@ -36,6 +36,7 @@ OmniGen2Transformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, @@ -673,41 +674,7 @@ def __init__( self.device ) - transformer_config_path = os.path.join(model, "transformer", "config.json") - transformer_kwargs = {} - - if os.path.exists(transformer_config_path): - with open(transformer_config_path) as f: - transformer_config = json.load(f) - - param_mapping = { - "patch_size": "patch_size", - "in_channels": "in_channels", - "out_channels": "out_channels", - "hidden_size": "hidden_size", - "num_layers": "num_layers", - "num_refiner_layers": "num_refiner_layers", - "num_attention_heads": "num_attention_heads", - "num_kv_heads": "num_kv_heads", - "multiple_of": "multiple_of", - "ffn_dim_multiplier": "ffn_dim_multiplier", - "norm_eps": "norm_eps", - "axes_dim_rope": "axes_dim_rope", - "axes_lens": "axes_lens", - "text_feat_dim": "text_feat_dim", - "timestep_scale": "timestep_scale", - } - - for config_key, param_name in param_mapping.items(): - if config_key in transformer_config: - value = transformer_config[config_key] - # Handle tuple parameters (axes_dim_rope, axes_lens) - if isinstance(value, list) and param_name in ( - "axes_dim_rope", - "axes_lens", - ): - value = tuple(value) - transformer_kwargs[param_name] = value + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, OmniGen2Transformer2DModel) self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs) self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained( model, subfolder="mllm", local_files_only=local_files_only