diff --git a/crates/ruff_python_ast/generate.py b/crates/ruff_python_ast/generate.py index 24e344908a9e0..f6afab68cc174 100644 --- a/crates/ruff_python_ast/generate.py +++ b/crates/ruff_python_ast/generate.py @@ -40,6 +40,23 @@ } +@dataclass +class VisitorInfo: + name: str + accepts_sequence: bool = False + + +# Map of AST node types to their corresponding visitor information. +# Only visitors that are different from the default `visit_*` method are included. +# These visitors either have a different name or accept a sequence of items. +type_to_visitor_function: dict[str, VisitorInfo] = { + "TypeParams": VisitorInfo("visit_type_params", True), + "Parameters": VisitorInfo("visit_parameters", True), + "Stmt": VisitorInfo("visit_body", True), + "Arguments": VisitorInfo("visit_arguments", True), +} + + def rustfmt(code: str) -> str: return check_output(["rustfmt", "--emit=stdout"], input=code, text=True) @@ -202,6 +219,7 @@ def extract_type_argument(rust_type_str: str) -> str: if close_bracket_index == -1 or close_bracket_index <= open_bracket_index: raise ValueError(f"Brackets are not balanced for type {rust_type_str}") inner_type = rust_type_str[open_bracket_index + 1 : close_bracket_index].strip() + inner_type = inner_type.replace("crate::", "") return inner_type @@ -766,39 +784,6 @@ def write_node(out: list[str], ast: Ast) -> None: # Source order visitor -@dataclass -class VisitorInfo: - name: str - accepts_sequence: bool = False - - -# Map of AST node types to their corresponding visitor information -type_to_visitor_function: dict[str, VisitorInfo] = { - "Decorator": VisitorInfo("visit_decorator"), - "Identifier": VisitorInfo("visit_identifier"), - "crate::TypeParams": VisitorInfo("visit_type_params", True), - "crate::Parameters": VisitorInfo("visit_parameters", True), - "Expr": VisitorInfo("visit_expr"), - "Stmt": VisitorInfo("visit_body", True), - "Arguments": VisitorInfo("visit_arguments", True), - "crate::Arguments": VisitorInfo("visit_arguments", True), - "Operator": VisitorInfo("visit_operator"), - "ElifElseClause": VisitorInfo("visit_elif_else_clause"), - "WithItem": VisitorInfo("visit_with_item"), - "MatchCase": VisitorInfo("visit_match_case"), - "ExceptHandler": VisitorInfo("visit_except_handler"), - "Alias": VisitorInfo("visit_alias"), - "UnaryOp": VisitorInfo("visit_unary_op"), - "DictItem": VisitorInfo("visit_dict_item"), - "Comprehension": VisitorInfo("visit_comprehension"), - "CmpOp": VisitorInfo("visit_cmp_op"), - "FStringValue": VisitorInfo("visit_f_string_value"), - "StringLiteralValue": VisitorInfo("visit_string_literal"), - "BytesLiteralValue": VisitorInfo("visit_bytes_literal"), -} -annotation_visitor_function = VisitorInfo("visit_annotation") - - def write_source_order(out: list[str], ast: Ast) -> None: for group in ast.groups: for node in group.nodes: @@ -816,24 +801,33 @@ def write_source_order(out: list[str], ast: Ast) -> None: fields_list += "range: _,\n" for field in node.fields_in_source_order(): - visitor = type_to_visitor_function[field.parsed_ty.inner] + visitor_name = ( + type_to_visitor_function.get( + field.parsed_ty.inner, VisitorInfo("") + ).name + or f"visit_{to_snake_case(field.parsed_ty.inner)}" + ) + visits_sequence = type_to_visitor_function.get( + field.parsed_ty.inner, VisitorInfo("") + ).accepts_sequence + if field.is_annotation: - visitor = annotation_visitor_function + visitor_name = "visit_annotation" if field.parsed_ty.optional: body += f""" if let Some({field.name}) = {field.name} {{ - visitor.{visitor.name}({field.name}); + visitor.{visitor_name}({field.name}); }}\n """ - elif not visitor.accepts_sequence and field.parsed_ty.seq: + elif not visits_sequence and field.parsed_ty.seq: body += f""" for elm in {field.name} {{ - visitor.{visitor.name}(elm); + visitor.{visitor_name}(elm); }} """ else: - body += f"visitor.{visitor.name}({field.name});\n" + body += f"visitor.{visitor_name}({field.name});\n" visitor_arg_name = "visitor" if len(node.fields_in_source_order()) == 0: