Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,25 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
return updated_node


class RenameAttributeTransformer(cst.CSTTransformer):
"""
This Transformer is used to rename class attributes throughout a class definition.
For example, renaming `self.dinov2` to `self.dinov3` in all assignments and references.
"""

def __init__(self, rename_mapping: dict[str, str]):
self.rename_mapping = rename_mapping

def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.Attribute:
"""Rename attribute accesses like `self.old_name` to `self.new_name`"""
if m.matches(updated_node.value, m.Name("self")) and m.matches(updated_node.attr, m.Name()):
old_name = updated_node.attr.value
if old_name in self.rename_mapping:
new_name = self.rename_mapping[old_name]
return updated_node.with_changes(attr=cst.Name(new_name))
return updated_node


class ReplaceSuperCallTransformer(cst.CSTTransformer):
"""
This Transformer is used to unravel all calls to `super().func(...)` in class methods by the explicit parent's
Expand Down Expand Up @@ -860,7 +879,11 @@ def common_partial_suffix(str1: str, str2: str) -> str:


def replace_class_node(
mapper: ModelFileMapper, modular_class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
mapper: ModelFileMapper,
modular_class_node: cst.ClassDef,
renamed_super_class: str,
original_super_class: str,
renamer: Optional[ReplaceNameTransformer] = None,
) -> cst.ClassDef:
"""
Replace a class node which inherits from another modeling class. This function works in the following way:
Expand Down Expand Up @@ -946,6 +969,26 @@ def replace_class_node(
elif m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()])):
modular_class_attributes[node.body[0].target.value] = node

# rename attributes
attribute_renames = {}
if "_attribute_renames" in modular_class_attributes:
rename_node = modular_class_attributes["_attribute_renames"]
if m.matches(rename_node, m.SimpleStatementLine(body=[m.Assign()])):
value_node = rename_node.body[0].value
if m.matches(value_node, m.Dict()):
for element in value_node.elements:
if m.matches(element, m.DictElement()):
old_key = element.key.evaluated_value if hasattr(element.key, "evaluated_value") else None
new_val = element.value.evaluated_value if hasattr(element.value, "evaluated_value") else None
if old_key and new_val:
transformed_key = (
preserve_case_replace(old_key, renamer.patterns, renamer.cased_new_name)
if renamer
else old_key
)
attribute_renames[transformed_key] = new_val
modular_class_attributes.pop("_attribute_renames")

# Use all original modeling attributes, and potentially override some with values in the modular
new_class_attributes = list({**original_modeling_class_attributes, **modular_class_attributes}.values())

Expand Down Expand Up @@ -1047,6 +1090,12 @@ def replace_class_node(
)
new_class_body = new_replacement_class.body[0].body # get the indented block

if attribute_renames:
result_node = original_modeling_node.with_changes(body=new_class_body)
temp_module = cst.Module(body=[result_node])
new_replacement_class = temp_module.visit(RenameAttributeTransformer(attribute_renames))
new_class_body = new_replacement_class.body[0].body

return original_modeling_node.with_changes(
body=new_class_body, decorators=new_class_decorators, bases=new_class_bases, name=new_class_name
)
Expand Down Expand Up @@ -1564,7 +1613,7 @@ class node based on the inherited classes if needed. Also returns any new import
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name)

# Create the new class node
updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)
updated_node = replace_class_node(mapper, node, renamed_super_class, super_class, renamer)

# Grab all immediate dependencies of the new node
new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)
Expand Down