From bbd8f51d57da96bc07229995192a45514237eb4d Mon Sep 17 00:00:00 2001 From: Daniel Harding Date: Thu, 21 Nov 2024 16:59:34 +0300 Subject: [PATCH] Generalize RecursionError handling. There are many functions that are implemented using recursion besides TokenList.flatten() that could raise RecursionError. Move the try/except block handling RecursionError from TokenList.flatten() to FilterStack.run() to avoid any of them resulting in an unwanted RecursionError. --- sqlparse/engine/filter_stack.py | 32 ++++++++++++++++++-------------- sqlparse/sql.py | 14 +++++--------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/sqlparse/engine/filter_stack.py b/sqlparse/engine/filter_stack.py index 3feba377..415d3fc9 100644 --- a/sqlparse/engine/filter_stack.py +++ b/sqlparse/engine/filter_stack.py @@ -10,6 +10,7 @@ from sqlparse import lexer from sqlparse.engine import grouping from sqlparse.engine.statement_splitter import StatementSplitter +from sqlparse.exceptions import SQLParseError from sqlparse.filters import StripTrailingSemicolonFilter @@ -26,22 +27,25 @@ def enable_grouping(self): self._grouping = True def run(self, sql, encoding=None): - stream = lexer.tokenize(sql, encoding) - # Process token stream - for filter_ in self.preprocess: - stream = filter_.process(stream) + try: + stream = lexer.tokenize(sql, encoding) + # Process token stream + for filter_ in self.preprocess: + stream = filter_.process(stream) - stream = StatementSplitter().process(stream) + stream = StatementSplitter().process(stream) - # Output: Stream processed Statements - for stmt in stream: - if self._grouping: - stmt = grouping.group(stmt) + # Output: Stream processed Statements + for stmt in stream: + if self._grouping: + stmt = grouping.group(stmt) - for filter_ in self.stmtprocess: - filter_.process(stmt) + for filter_ in self.stmtprocess: + filter_.process(stmt) - for filter_ in self.postprocess: - stmt = filter_.process(stmt) + for filter_ in self.postprocess: + stmt = filter_.process(stmt) - yield stmt + yield stmt + except RecursionError as err: + raise SQLParseError('Maximum recursion depth exceeded') from err diff --git a/sqlparse/sql.py b/sqlparse/sql.py index 10373751..be74694c 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -10,7 +10,6 @@ import re from sqlparse import tokens as T -from sqlparse.exceptions import SQLParseError from sqlparse.utils import imt, remove_quotes @@ -211,14 +210,11 @@ def flatten(self): This method is recursively called for all child tokens. """ - try: - for token in self.tokens: - if token.is_group: - yield from token.flatten() - else: - yield token - except RecursionError as err: - raise SQLParseError('Maximum recursion depth exceeded') from err + for token in self.tokens: + if token.is_group: + yield from token.flatten() + else: + yield token def get_sublists(self): for token in self.tokens: