diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 425137e302e6..125a96ab8230 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -184,7 +184,7 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: @classmethod def execute_with_cursor( - cls, cursor: Any, sql: str, query: Query, session: Session + cls, cursor: Cursor, sql: str, query: Query, session: Session ) -> None: """ Trigger execution of a query and handle the resulting cursor. @@ -193,34 +193,40 @@ def execute_with_cursor( in another thread and invoke `handle_cursor` to poll for the query ID to appear on the cursor in parallel. """ + # Fetch the query ID beforehand, since it might fail inside the thread due to + # how the SQLAlchemy session is handled. + query_id = query.id + execute_result: dict[str, Any] = {} + execute_event = threading.Event() - def _execute(results: dict[str, Any]) -> None: - logger.debug("Query %d: Running query: %s", query.id, sql) + def _execute(results: dict[str, Any], event: threading.Event) -> None: + logger.debug("Query %d: Running query: %s", query_id, sql) - # Pass result / exception information back to the parent thread try: cls.execute(cursor, sql) - results["complete"] = True except Exception as ex: # pylint: disable=broad-except - results["complete"] = True results["error"] = ex + finally: + event.set() - execute_thread = threading.Thread(target=_execute, args=(execute_result,)) + execute_thread = threading.Thread( + target=_execute, + args=(execute_result, execute_event), + ) execute_thread.start() # Wait for a query ID to be available before handling the cursor, as # it's required by that method; it may never become available on error. - while not cursor.query_id and not execute_result.get("complete"): + while not cursor.query_id and not execute_event.is_set(): time.sleep(0.1) - logger.debug("Query %d: Handling cursor", query.id) + logger.debug("Query %d: Handling cursor", query_id) cls.handle_cursor(cursor, query, session) # Block until the query completes; same behaviour as the client itself - logger.debug("Query %d: Waiting for query to complete", query.id) - while not execute_result.get("complete"): - time.sleep(0.5) + logger.debug("Query %d: Waiting for query to complete", query_id) + execute_event.wait() # Unfortunately we'll mangle the stack trace due to the thread, but # throwing the original exception allows mapping database errors as normal @@ -234,7 +240,7 @@ def prepare_cancel_query(cls, query: Query, session: Session) -> None: session.commit() @classmethod - def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: + def cancel_query(cls, cursor: Cursor, query: Query, cancel_query_id: str) -> bool: """ Cancel query in the underlying database.