From 51b8cedfdc048ecfeb3b202d0f340d68dfd18ee9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 19 Nov 2024 12:30:14 +0100 Subject: [PATCH 1/4] Modular fix --- .../configuration_my_new_model.py | 10 +++ .../configuration_my_new_model2.py | 10 +++ .../modular-transformers/modeling_dummy.py | 69 +++++-------------- .../modeling_my_new_model2.py | 3 + .../modular-transformers/modeling_super.py | 69 +++++-------------- utils/modular_model_converter.py | 12 ++-- 6 files changed, 63 insertions(+), 110 deletions(-) diff --git a/examples/modular-transformers/configuration_my_new_model.py b/examples/modular-transformers/configuration_my_new_model.py index aa0aac55ba91..7042c586cbb6 100644 --- a/examples/modular-transformers/configuration_my_new_model.py +++ b/examples/modular-transformers/configuration_my_new_model.py @@ -130,6 +130,16 @@ class MyNewModelConfig(PretrainedConfig): model_type = "my_new_model" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `MyNewModelModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, diff --git a/examples/modular-transformers/configuration_my_new_model2.py b/examples/modular-transformers/configuration_my_new_model2.py index f05ace94b622..eddd7fe47973 100644 --- a/examples/modular-transformers/configuration_my_new_model2.py +++ b/examples/modular-transformers/configuration_my_new_model2.py @@ -33,6 +33,16 @@ class MyNewModel2Config(PretrainedConfig): model_type = "my_new_model2" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `MyNewModel2Model` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index ed7e3c64d7a8..6fec1793e883 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F from torch import nn from ...activations import ACT2FN @@ -150,25 +149,7 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = 
self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -264,31 +245,14 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -330,12 +294,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None @@ -508,9 +467,10 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -794,7 +754,10 @@ def __init__(self, config: DummyConfig): ) self.norm = 
DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DummyRotaryEmbedding(config=config) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 16f9e525a05e..77294d89abfd 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -667,7 +667,10 @@ def __init__(self, config: MyNewModel2Config): [MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 7df04bcc2a99..7ad606280dcc 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F from torch import nn from ...activations import ACT2FN @@ -150,25 +149,7 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -264,31 +245,14 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = 
torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -330,12 +294,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None @@ -508,9 +467,10 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -794,7 +754,10 @@ def __init__(self, config: SuperConfig): ) self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = SuperRotaryEmbedding(config=config) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index e5f6e34ece0e..1ca1b9b4c36b 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -773,6 +773,8 @@ def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: di self.object_dependency_mapping.update( {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()} ) + # Add them to global nodes + self.global_nodes.update(self.functions) def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): """Update the global nodes with the assignment from the modular file. 
@@ -786,6 +788,8 @@ def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping self.assignments[assignment] = node if assignment in object_mapping: self.object_dependency_mapping[assignment] = object_mapping[assignment] + # Add them to global nodes + self.global_nodes.update(self.assignments) def _merge_classes(self, classes: dict[str, cst.CSTNode]): """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and @@ -813,10 +817,7 @@ def merge_modular_dependencies(self, classes, functions, assignments, object_map self._merge_classes(classes) self.modular_file_start_lines = start_lines - # Correctly re-set the global nodes at this point - self.global_nodes.update(self.functions) - self.global_nodes.update(self.assignments) - # Restrict the dependency mappings to the know entities to avoid Python's built-ins + # Restrict the dependency mappings to the known entities to avoid Python's built-ins and imports self._restrict_dependencies_to_known_entities() # Create the global mapping of recursive dependencies for functions and assignments self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() @@ -1024,14 +1025,17 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> import_ref_count[name] = ref_count imports_to_keep = [] + existing_protected_statements = set() # str representation of the import node statements - does not work with the nodes directly for node in all_imports: if m.matches(node, m.If()): # handle safe imports new_statements = [] for stmt_node in node.body.body: append_new_import_node(stmt_node, unused_imports, new_statements) + new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements] if len(new_statements) > 0: new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) imports_to_keep.append(new_node) + existing_protected_statements.update({str(stmt) for stmt in new_statements}) else: append_new_import_node(node, unused_imports, imports_to_keep) From 540192d074ae61e8fe33061022b48f90ee83b7e9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 19 Nov 2024 12:35:37 +0100 Subject: [PATCH 2/4] style --- utils/modular_model_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 1ca1b9b4c36b..9efcab8978ac 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1025,7 +1025,7 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> import_ref_count[name] = ref_count imports_to_keep = [] - existing_protected_statements = set() # str representation of the import node statements - does not work with the nodes directly + existing_protected_statements = set() # str repr of the import node - does not work with the nodes directly for node in all_imports: if m.matches(node, m.If()): # handle safe imports new_statements = [] From 2f0eda9730dde9cd4accab615e424ad74094a2b7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 19 Nov 2024 15:55:11 +0100 Subject: [PATCH 3/4] remove logger warning --- examples/modular-transformers/modeling_dummy.py | 2 +- examples/modular-transformers/modeling_my_new_model2.py | 2 +- utils/modular_model_converter.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 6fec1793e883..0b373d4e6eab 
100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -837,7 +837,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 77294d89abfd..189e090094c7 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -755,7 +755,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 9efcab8978ac..754f8b106ca0 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -266,7 +266,6 @@ def update_body(self, existing_body, new_statements): if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): target = self.python_module.code_for_node(stmt.body[0].targets[0].target) if target in self.deleted_targets: - logger.warning(f"Deleted the assign for {target}") continue if target in self.all_assign_target: stmt = self.all_assign_target[target] From 5d04d9e0427291741b9f1034debc2bea737ce1cd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 19 Nov 2024 15:59:26 +0100 Subject: [PATCH 4/4] Update modular_model_converter.py --- utils/modular_model_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 754f8b106ca0..ccf15363de92 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1024,7 +1024,7 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> import_ref_count[name] = ref_count imports_to_keep = [] - existing_protected_statements = set() # str repr of the import node - does not work with the nodes directly + existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly for node in all_imports: if m.matches(node, m.If()): # handle safe imports new_statements = []
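
A rough illustration of the reshape change in the attention hunks of PATCH 1 above, which switch from `view(bsz, q_len, self.num_heads, self.head_dim)` to `view(bsz, q_len, -1, self.head_dim)`: under the new `base_model_tp_plan`, the `q_proj`/`k_proj`/`v_proj` layers can be sharded "colwise" across tensor-parallel ranks, so each rank only holds a fraction of the heads and the local head count can no longer be read from the config. The sketch below is a standalone toy example, not part of the patch; the sizes and the smaller plain `nn.Linear` used to stand in for one shard are assumptions made purely for illustration.

# Toy sketch (not part of the patch): why the attention reshape uses -1 for the head
# dimension once q/k/v projections may be sharded "colwise" across tensor-parallel ranks.
# All sizes below are made up for illustration.
import torch
from torch import nn

bsz, q_len, hidden_size = 2, 5, 32
num_heads, head_dim, tp_size = 8, 4, 2

hidden_states = torch.randn(bsz, q_len, hidden_size)

# Unsharded projection: output features == num_heads * head_dim.
q_proj_full = nn.Linear(hidden_size, num_heads * head_dim, bias=False)

# One rank of a column-wise shard holds only num_heads // tp_size heads; a smaller
# nn.Linear stands in for the sharded weight here.
q_proj_shard = nn.Linear(hidden_size, (num_heads // tp_size) * head_dim, bias=False)

# The same view(bsz, q_len, -1, head_dim) works in both cases because -1 infers the
# locally available head count; hard-coding self.num_heads would fail on the shard.
full_heads = q_proj_full(hidden_states).view(bsz, q_len, -1, head_dim)
local_heads = q_proj_shard(hidden_states).view(bsz, q_len, -1, head_dim)

print(full_heads.shape)   # torch.Size([2, 5, 8, 4])
print(local_heads.shape)  # torch.Size([2, 5, 4, 4])

The same reasoning applies to the "rowwise" entry for `o_proj` in `base_model_tp_plan`: there it is the input features that shrink per rank, which is why `attn_output.reshape(bsz, q_len, -1)` also stays valid under sharding.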
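
Likewise, the `existing_protected_statements` set added to `get_needed_imports` in the converter hunks above deduplicates statements nested under `if ...:` guards by their string form, since two distinct libcst nodes containing identical code do not compare equal. A minimal sketch of the same idea follows, using plain strings instead of libcst nodes; `protected_blocks` is an invented example input, not anything from the converter.

# Minimal sketch of the dedup-by-string-representation idea used for "protected"
# (guarded) imports; plain strings stand in for libcst statement nodes.
protected_blocks = [
    ("if is_torch_available():", ["    import torch", "    from torch import nn"]),
    ("if is_torch_available():", ["    import torch"]),  # duplicate from another merged file
]

existing_protected_statements = set()  # string form, since node objects do not hash by content
deduped = []
for guard, statements in protected_blocks:
    new_statements = [s for s in statements if s not in existing_protected_statements]
    if new_statements:  # drop guard blocks whose statements were all seen already
        deduped.append((guard, new_statements))
        existing_protected_statements.update(new_statements)

print(deduped)
# [('if is_torch_available():', ['    import torch', '    from torch import nn'])]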