From dfe52479506e4c6adddf72998627e0b1c756c120 Mon Sep 17 00:00:00 2001 From: Parth Bansal Date: Sat, 4 Apr 2026 10:27:03 +0000 Subject: [PATCH 1/3] [Feature] Add support for quant_config in NextStep Signed-off-by: Parth Bansal --- .../models/nextstep_1_1/modeling_nextstep.py | 5 +++-- .../models/nextstep_1_1/modeling_nextstep_llama.py | 14 +++++++++----- .../models/nextstep_1_1/pipeline_nextstep_1_1.py | 4 +++- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py index ded3079265e..b9d35f683a5 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py @@ -114,7 +114,7 @@ def from_json(cls, path: str) -> NextStepConfig: class NextStepModel(nn.Module): - def __init__(self, config: NextStepConfig): + def __init__(self, config: NextStepConfig, quant_config=None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -122,7 +122,8 @@ def __init__(self, config: NextStepConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [LlamaDecoderLayer(config, layer_idx, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers)] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py index 7b367b6ff49..fce34359a06 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py @@ -103,7 +103,7 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Te class LlamaAttention(nn.Module): - def __init__(self, config, layer_idx: int): + def __init__(self, config, layer_idx: int, quant_config=None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -122,12 +122,14 @@ def __init__(self, config, layer_idx: int): total_num_heads=self.num_heads, total_num_kv_heads=self.num_key_value_heads, bias=config.attention_bias, + quant_config=quant_config, ) # TP-aware: row-parallel output projection self.o_proj = RowParallelLinear( self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, "o_attention_bias", config.attention_bias), + quant_config=quant_config, ) def forward( @@ -205,7 +207,7 @@ def forward( class LlamaMLP(nn.Module): - def __init__(self, config): + def __init__(self, config, quant_config=None): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -215,12 +217,14 @@ def __init__(self, config): self.hidden_size, [self.intermediate_size] * 2, bias=config.mlp_bias, + quant_config=quant_config, ) # TP-aware: row-parallel down projection self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=config.mlp_bias, + quant_config=quant_config, ) self.act_fn = nn.SiLU() @@ -237,11 +241,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LlamaDecoderLayer(nn.Module): - def __init__(self, config, layer_idx: int): + def __init__(self, config, layer_idx: int, quant_config=None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) - self.mlp = LlamaMLP(config) + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx, quant_config=quant_config) + self.mlp = LlamaMLP(config, quant_config=quant_config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py index 4fa56ea9313..bd11b65f4b0 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py @@ -33,6 +33,7 @@ NextStepConfig, NextStepModel, ) +from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils import ( @@ -169,7 +170,8 @@ def __init__( # Load model from local TP-aware code (weights loaded later via load_weights) config = NextStepConfig.from_json(os.path.join(model_path, "config.json")) - self.model = NextStepModel(config) + quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config) + self.model = NextStepModel(config, quant_config=quant_config) self.model.eval() # Load config From 5d4326a226da01d165a84bcdbd512a222fc2bbbf Mon Sep 17 00:00:00 2001 From: Parth Bansal Date: Sat, 4 Apr 2026 11:53:06 +0000 Subject: [PATCH 2/3] update Signed-off-by: Parth Bansal --- .../diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py index bd11b65f4b0..f71ec4cf253 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py @@ -33,7 +33,6 @@ NextStepConfig, NextStepModel, ) -from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils import ( @@ -170,8 +169,7 @@ def __init__( # Load model from local TP-aware code (weights loaded later via load_weights) config = NextStepConfig.from_json(os.path.join(model_path, "config.json")) - quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config) - self.model = NextStepModel(config, quant_config=quant_config) + self.model = NextStepModel(config, quant_config=od_config.quantization_config) self.model.eval() # Load config From 46940682b29b1cde59ca1eab89b9ba6b7aa7f8b0 Mon Sep 17 00:00:00 2001 From: Parth Bansal Date: Sat, 4 Apr 2026 12:13:38 +0000 Subject: [PATCH 3/3] update Signed-off-by: Parth Bansal --- .../diffusion/models/nextstep_1_1/modeling_nextstep.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py index b9d35f683a5..95eeb14de2b 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py @@ -122,8 +122,10 @@ def __init__(self, config: NextStepConfig, quant_config=None): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [LlamaDecoderLayer(config, layer_idx, quant_config=quant_config) - for layer_idx in range(config.num_hidden_layers)] + [ + LlamaDecoderLayer(config, layer_idx, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config)