diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py index ded3079265e..95eeb14de2b 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py @@ -114,7 +114,7 @@ def from_json(cls, path: str) -> NextStepConfig: class NextStepModel(nn.Module): - def __init__(self, config: NextStepConfig): + def __init__(self, config: NextStepConfig, quant_config=None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -122,7 +122,10 @@ def __init__(self, config: NextStepConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [ + LlamaDecoderLayer(config, layer_idx, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py index 7b367b6ff49..fce34359a06 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py @@ -103,7 +103,7 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Te class LlamaAttention(nn.Module): - def __init__(self, config, layer_idx: int): + def __init__(self, config, layer_idx: int, quant_config=None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -122,12 +122,14 @@ def __init__(self, config, layer_idx: int): total_num_heads=self.num_heads, total_num_kv_heads=self.num_key_value_heads, bias=config.attention_bias, + quant_config=quant_config, ) # TP-aware: row-parallel output projection self.o_proj = RowParallelLinear( self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, "o_attention_bias", config.attention_bias), + quant_config=quant_config, ) def forward( @@ -205,7 +207,7 @@ def forward( class LlamaMLP(nn.Module): - def __init__(self, config): + def __init__(self, config, quant_config=None): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -215,12 +217,14 @@ def __init__(self, config): self.hidden_size, [self.intermediate_size] * 2, bias=config.mlp_bias, + quant_config=quant_config, ) # TP-aware: row-parallel down projection self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=config.mlp_bias, + quant_config=quant_config, ) self.act_fn = nn.SiLU() @@ -237,11 +241,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LlamaDecoderLayer(nn.Module): - def __init__(self, config, layer_idx: int): + def __init__(self, config, layer_idx: int, quant_config=None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) - self.mlp = LlamaMLP(config) + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx, quant_config=quant_config) + self.mlp = LlamaMLP(config, quant_config=quant_config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py index 4fa56ea9313..f71ec4cf253 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py @@ -169,7 +169,7 @@ def __init__( # Load model from local TP-aware code (weights loaded later via load_weights) config = NextStepConfig.from_json(os.path.join(model_path, "config.json")) - self.model = NextStepModel(config) + self.model = NextStepModel(config, quant_config=od_config.quantization_config) self.model.eval() # Load config