Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions examples/modular-transformers/configuration_my_new_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ class MyNewModelConfig(PretrainedConfig):

model_type = "my_new_model"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `MyNewModelModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
Expand Down
10 changes: 10 additions & 0 deletions examples/modular-transformers/configuration_my_new_model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ class MyNewModel2Config(PretrainedConfig):

model_type = "my_new_model2"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `MyNewModel2Model`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
Expand Down
71 changes: 17 additions & 54 deletions examples/modular-transformers/modeling_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
Expand Down Expand Up @@ -150,25 +149,7 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj


Expand Down Expand Up @@ -264,31 +245,14 @@ def forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
Expand Down Expand Up @@ -330,12 +294,7 @@ def forward(

attn_output = attn_output.reshape(bsz, q_len, -1)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
Expand Down Expand Up @@ -508,9 +467,10 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
Expand Down Expand Up @@ -794,7 +754,10 @@ def __init__(self, config: DummyConfig):
)
self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DummyRotaryEmbedding(config=config)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
Expand Down Expand Up @@ -874,7 +837,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

Expand Down
5 changes: 4 additions & 1 deletion examples/modular-transformers/modeling_my_new_model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,10 @@ def __init__(self, config: MyNewModel2Config):
[MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
Expand Down Expand Up @@ -752,7 +755,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

Expand Down
69 changes: 16 additions & 53 deletions examples/modular-transformers/modeling_super.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
Expand Down Expand Up @@ -150,25 +149,7 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj


Expand Down Expand Up @@ -264,31 +245,14 @@ def forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
Expand Down Expand Up @@ -330,12 +294,7 @@ def forward(

attn_output = attn_output.reshape(bsz, q_len, -1)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
Expand Down Expand Up @@ -508,9 +467,10 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
Expand Down Expand Up @@ -794,7 +754,10 @@ def __init__(self, config: SuperConfig):
)
self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = SuperRotaryEmbedding(config=config)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
Expand Down
13 changes: 8 additions & 5 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ def update_body(self, existing_body, new_statements):
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
if target in self.deleted_targets:
logger.warning(f"Deleted the assign for {target}")
continue
if target in self.all_assign_target:
stmt = self.all_assign_target[target]
Expand Down Expand Up @@ -773,6 +772,8 @@ def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: di
self.object_dependency_mapping.update(
{obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}
)
# Add them to global nodes
self.global_nodes.update(self.functions)

def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
"""Update the global nodes with the assignment from the modular file.
Expand All @@ -786,6 +787,8 @@ def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping
self.assignments[assignment] = node
if assignment in object_mapping:
self.object_dependency_mapping[assignment] = object_mapping[assignment]
# Add them to global nodes
self.global_nodes.update(self.assignments)

def _merge_classes(self, classes: dict[str, cst.CSTNode]):
"""Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
Expand Down Expand Up @@ -813,10 +816,7 @@ def merge_modular_dependencies(self, classes, functions, assignments, object_map
self._merge_classes(classes)
self.modular_file_start_lines = start_lines

# Correctly re-set the global nodes at this point
self.global_nodes.update(self.functions)
self.global_nodes.update(self.assignments)
# Restrict the dependency mappings to the know entities to avoid Python's built-ins
# Restrict the dependency mappings to the known entities to avoid Python's built-ins and imports
self._restrict_dependencies_to_known_entities()
# Create the global mapping of recursive dependencies for functions and assignments
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
Expand Down Expand Up @@ -1024,14 +1024,17 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
import_ref_count[name] = ref_count

imports_to_keep = []
existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly
for node in all_imports:
if m.matches(node, m.If()): # handle safe imports
new_statements = []
for stmt_node in node.body.body:
append_new_import_node(stmt_node, unused_imports, new_statements)
new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
if len(new_statements) > 0:
new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
imports_to_keep.append(new_node)
existing_protected_statements.update({str(stmt) for stmt in new_statements})
else:
append_new_import_node(node, unused_imports, imports_to_keep)

Expand Down