diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index d32e1d2dbe7c..2ec6e183a706 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -230,6 +230,25 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal return updated_node +class RenameAttributeTransformer(cst.CSTTransformer): + """ + This Transformer is used to rename class attributes throughout a class definition. + For example, renaming `self.dinov2` to `self.dinov3` in all assignments and references. + """ + + def __init__(self, rename_mapping: dict[str, str]): + self.rename_mapping = rename_mapping + + def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.Attribute: + """Rename attribute accesses like `self.old_name` to `self.new_name`""" + if m.matches(updated_node.value, m.Name("self")) and m.matches(updated_node.attr, m.Name()): + old_name = updated_node.attr.value + if old_name in self.rename_mapping: + new_name = self.rename_mapping[old_name] + return updated_node.with_changes(attr=cst.Name(new_name)) + return updated_node + + class ReplaceSuperCallTransformer(cst.CSTTransformer): """ This Transformer is used to unravel all calls to `super().func(...)` in class methods by the explicit parent's @@ -860,7 +879,11 @@ def common_partial_suffix(str1: str, str2: str) -> str: def replace_class_node( - mapper: ModelFileMapper, modular_class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str + mapper: ModelFileMapper, + modular_class_node: cst.ClassDef, + renamed_super_class: str, + original_super_class: str, + renamer: Optional[ReplaceNameTransformer] = None, ) -> cst.ClassDef: """ Replace a class node which inherits from another modeling class. This function works in the following way: @@ -946,6 +969,26 @@ def replace_class_node( elif m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()])): modular_class_attributes[node.body[0].target.value] = node + # rename attributes + attribute_renames = {} + if "_attribute_renames" in modular_class_attributes: + rename_node = modular_class_attributes["_attribute_renames"] + if m.matches(rename_node, m.SimpleStatementLine(body=[m.Assign()])): + value_node = rename_node.body[0].value + if m.matches(value_node, m.Dict()): + for element in value_node.elements: + if m.matches(element, m.DictElement()): + old_key = element.key.evaluated_value if hasattr(element.key, "evaluated_value") else None + new_val = element.value.evaluated_value if hasattr(element.value, "evaluated_value") else None + if old_key and new_val: + transformed_key = ( + preserve_case_replace(old_key, renamer.patterns, renamer.cased_new_name) + if renamer + else old_key + ) + attribute_renames[transformed_key] = new_val + modular_class_attributes.pop("_attribute_renames") + # Use all original modeling attributes, and potentially override some with values in the modular new_class_attributes = list({**original_modeling_class_attributes, **modular_class_attributes}.values()) @@ -1047,6 +1090,12 @@ def replace_class_node( ) new_class_body = new_replacement_class.body[0].body # get the indented block + if attribute_renames: + result_node = original_modeling_node.with_changes(body=new_class_body) + temp_module = cst.Module(body=[result_node]) + new_replacement_class = temp_module.visit(RenameAttributeTransformer(attribute_renames)) + new_class_body = new_replacement_class.body[0].body + return original_modeling_node.with_changes( body=new_class_body, decorators=new_class_decorators, bases=new_class_bases, name=new_class_name ) @@ -1564,7 +1613,7 @@ class node based on the inherited classes if needed. Also returns any new import renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name) # Create the new class node - updated_node = replace_class_node(mapper, node, renamed_super_class, super_class) + updated_node = replace_class_node(mapper, node, renamed_super_class, super_class, renamer) # Grab all immediate dependencies of the new node new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)