Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,18 @@ def from_json(cls, path: str) -> NextStepConfig:


class NextStepModel(nn.Module):
def __init__(self, config: NextStepConfig):
def __init__(self, config: NextStepConfig, quant_config=None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
[
LlamaDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Te


class LlamaAttention(nn.Module):
def __init__(self, config, layer_idx: int):
def __init__(self, config, layer_idx: int, quant_config=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
Expand All @@ -122,12 +122,14 @@ def __init__(self, config, layer_idx: int):
total_num_heads=self.num_heads,
total_num_kv_heads=self.num_key_value_heads,
bias=config.attention_bias,
quant_config=quant_config,
)
# TP-aware: row-parallel output projection
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.hidden_size,
bias=getattr(config, "o_attention_bias", config.attention_bias),
quant_config=quant_config,
)

def forward(
Expand Down Expand Up @@ -205,7 +207,7 @@ def forward(


class LlamaMLP(nn.Module):
def __init__(self, config):
def __init__(self, config, quant_config=None):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
Expand All @@ -215,12 +217,14 @@ def __init__(self, config):
self.hidden_size,
[self.intermediate_size] * 2,
bias=config.mlp_bias,
quant_config=quant_config,
)
# TP-aware: row-parallel down projection
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=config.mlp_bias,
quant_config=quant_config,
)
self.act_fn = nn.SiLU()

Expand All @@ -237,11 +241,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class LlamaDecoderLayer(nn.Module):
def __init__(self, config, layer_idx: int):
def __init__(self, config, layer_idx: int, quant_config=None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx, quant_config=quant_config)
self.mlp = LlamaMLP(config, quant_config=quant_config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(

# Load model from local TP-aware code (weights loaded later via load_weights)
config = NextStepConfig.from_json(os.path.join(model_path, "config.json"))
self.model = NextStepModel(config)
self.model = NextStepModel(config, quant_config=od_config.quantization_config)
self.model.eval()

# Load config
Expand Down
Loading