Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 33 additions & 39 deletions crates/ruff_python_ast/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@
}


@dataclass
class VisitorInfo:
name: str
accepts_sequence: bool = False


# Map of AST node types to their corresponding visitor information.
# Only visitors that are different from the default `visit_*` method are included.
# These visitors either have a different name or accept a sequence of items.
type_to_visitor_function: dict[str, VisitorInfo] = {
"TypeParams": VisitorInfo("visit_type_params", True),
"Parameters": VisitorInfo("visit_parameters", True),
"Stmt": VisitorInfo("visit_body", True),
"Arguments": VisitorInfo("visit_arguments", True),
}


def rustfmt(code: str) -> str:
return check_output(["rustfmt", "--emit=stdout"], input=code, text=True)

Expand Down Expand Up @@ -202,6 +219,7 @@ def extract_type_argument(rust_type_str: str) -> str:
if close_bracket_index == -1 or close_bracket_index <= open_bracket_index:
raise ValueError(f"Brackets are not balanced for type {rust_type_str}")
inner_type = rust_type_str[open_bracket_index + 1 : close_bracket_index].strip()
inner_type = inner_type.replace("crate::", "")
return inner_type


Expand Down Expand Up @@ -766,39 +784,6 @@ def write_node(out: list[str], ast: Ast) -> None:
# Source order visitor


@dataclass
class VisitorInfo:
name: str
accepts_sequence: bool = False


# Map of AST node types to their corresponding visitor information
type_to_visitor_function: dict[str, VisitorInfo] = {
"Decorator": VisitorInfo("visit_decorator"),
"Identifier": VisitorInfo("visit_identifier"),
"crate::TypeParams": VisitorInfo("visit_type_params", True),
"crate::Parameters": VisitorInfo("visit_parameters", True),
"Expr": VisitorInfo("visit_expr"),
"Stmt": VisitorInfo("visit_body", True),
"Arguments": VisitorInfo("visit_arguments", True),
"crate::Arguments": VisitorInfo("visit_arguments", True),
"Operator": VisitorInfo("visit_operator"),
"ElifElseClause": VisitorInfo("visit_elif_else_clause"),
"WithItem": VisitorInfo("visit_with_item"),
"MatchCase": VisitorInfo("visit_match_case"),
"ExceptHandler": VisitorInfo("visit_except_handler"),
"Alias": VisitorInfo("visit_alias"),
"UnaryOp": VisitorInfo("visit_unary_op"),
"DictItem": VisitorInfo("visit_dict_item"),
"Comprehension": VisitorInfo("visit_comprehension"),
"CmpOp": VisitorInfo("visit_cmp_op"),
"FStringValue": VisitorInfo("visit_f_string_value"),
"StringLiteralValue": VisitorInfo("visit_string_literal"),
"BytesLiteralValue": VisitorInfo("visit_bytes_literal"),
}
annotation_visitor_function = VisitorInfo("visit_annotation")


def write_source_order(out: list[str], ast: Ast) -> None:
for group in ast.groups:
for node in group.nodes:
Expand All @@ -816,24 +801,33 @@ def write_source_order(out: list[str], ast: Ast) -> None:
fields_list += "range: _,\n"

for field in node.fields_in_source_order():
visitor = type_to_visitor_function[field.parsed_ty.inner]
visitor_name = (
type_to_visitor_function.get(
field.parsed_ty.inner, VisitorInfo("")
).name
or f"visit_{to_snake_case(field.parsed_ty.inner)}"
)
visits_sequence = type_to_visitor_function.get(
field.parsed_ty.inner, VisitorInfo("")
).accepts_sequence

if field.is_annotation:
visitor = annotation_visitor_function
visitor_name = "visit_annotation"

if field.parsed_ty.optional:
body += f"""
if let Some({field.name}) = {field.name} {{
visitor.{visitor.name}({field.name});
visitor.{visitor_name}({field.name});
}}\n
"""
elif not visitor.accepts_sequence and field.parsed_ty.seq:
elif not visits_sequence and field.parsed_ty.seq:
body += f"""
for elm in {field.name} {{
visitor.{visitor.name}(elm);
visitor.{visitor_name}(elm);
}}
"""
else:
body += f"visitor.{visitor.name}({field.name});\n"
body += f"visitor.{visitor_name}({field.name});\n"

visitor_arg_name = "visitor"
if len(node.fields_in_source_order()) == 0:
Expand Down
Loading