diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py
index b4972f753..efcc33d34 100644
--- a/ibis-server/app/model/data_source.py
+++ b/ibis-server/app/model/data_source.py
@@ -109,6 +109,18 @@ def get_connection_info(
                     info.settings = {}
                 if "max_execution_time" not in info.settings:
                     info.settings["max_execution_time"] = int(session_timeout)
+            case DataSource.trino:
+                # Apply the per-request statement timeout as a Trino session
+                # property unless the caller already supplied one.
+                session_timeout = headers.get(X_WREN_DB_STATEMENT_TIMEOUT, 180)
+                if info.kwargs is None:
+                    info.kwargs = {}
+                session_properties = info.kwargs.get("session_properties") or {}
+                if "query_max_execution_time" not in session_properties:
+                    session_properties["query_max_execution_time"] = (
+                        f"{session_timeout}s"
+                    )
+                info.kwargs["session_properties"] = session_properties
         return info
 
     def _build_connection_info(self, data: dict) -> ConnectionInfo:
diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py
index ca4ca801b..c34d1938c 100644
--- a/ibis-server/app/util.py
+++ b/ibis-server/app/util.py
@@ -8,6 +8,7 @@
 import pandas as pd
 import psycopg
 import pyarrow as pa
+import trino
 import wren_core
 from fastapi import Header
 from ibis.expr.datatypes import Decimal
@@ -271,6 +272,13 @@ async def execute_with_timeout(operation, operation_name: str):
     except clickhouse_connect.driver.exceptions.DatabaseError as e:
         if "TIMEOUT_EXCEEDED" in str(e):
             raise DatabaseTimeoutError(f"{operation_name} was cancelled: {e}")
+        # Not a timeout: re-raise instead of silently returning None.
+        raise
+    except trino.exceptions.TrinoQueryError as e:
+        if e.error_name == "EXCEEDED_TIME_LIMIT":
+            raise DatabaseTimeoutError(f"{operation_name} was cancelled: {e}")
+        # Not a timeout: re-raise instead of silently returning None.
+        raise
     except psycopg.errors.QueryCanceled as e:
         raise DatabaseTimeoutError(f"{operation_name} was cancelled: {e}")
 
diff --git a/ibis-server/tests/resource/trino_etc/config.properties b/ibis-server/tests/resource/trino_etc/config.properties
new file mode 100644
index 000000000..9b563576c
--- /dev/null
+++ b/ibis-server/tests/resource/trino_etc/config.properties
@@ -0,0 +1,26 @@
+# sample nodeId to provide consistency across test runs
+node.id=ffffffff-ffff-ffff-ffff-ffffffffffff
+node.environment=test
+node.internal-address=localhost
+experimental.concurrent-startup=true
+http-server.http.port=8080
+
+
+discovery.uri=http://localhost:8080
+
+exchange.http-client.max-connections-per-server=1000
+exchange.http-client.connect-timeout=1m
+exchange.http-client.idle-timeout=1m
+
+scheduler.http-client.max-connections-per-server=1000
+scheduler.http-client.connect-timeout=1m
+scheduler.http-client.idle-timeout=1m
+
+query.client.timeout=5m
+query.min-expire-age=30m
+
+node-scheduler.include-coordinator=true
+
+sql.default-function-catalog=memory
+sql.default-function-schema=default
+sql.path=memory.default
diff --git a/ibis-server/tests/routers/v2/connector/test_trino.py b/ibis-server/tests/routers/v2/connector/test_trino.py
index dcecbcac7..08c7cc8f5 100644
--- a/ibis-server/tests/routers/v2/connector/test_trino.py
+++ b/ibis-server/tests/routers/v2/connector/test_trino.py
@@ -6,6 +6,7 @@
 from testcontainers.trino import TrinoContainer
 from trino.dbapi import connect
 
+from app.model.data_source import X_WREN_DB_STATEMENT_TIMEOUT
 from app.model.validator import rules
 
 pytestmark = pytest.mark.trino
@@ -84,6 +85,20 @@ def trino(request) -> TrinoContainer:
         "COMMENT ON COLUMN memory.default.orders.comment IS 'This is a comment'"
     )
 
+    # Python UDF that sleeps for `x` seconds; the timeout test calls it.
+    cur.execute("""
+CREATE FUNCTION memory.default.sleep(x integer)
+    RETURNS integer
+    LANGUAGE PYTHON
+    WITH (handler = 'sleep')
+    AS $$
+    def sleep(a):
+        import time
+        time.sleep(a)
+        return 0
+    $$
+    """)
+
     request.addfinalizer(db.stop)
     return db
 
@@ -427,6 +442,36 @@ async def test_metadata_db_version(client, trino: TrinoContainer):
     assert response.text is not None
 
 
+async def test_connection_timeout(client, manifest_str, trino: TrinoContainer):
+    """A query exceeding the statement-timeout header must return HTTP 504."""
+    connection_info = _to_connection_info(trino)
+    # Set a very short timeout to force a timeout error
+    response = await client.post(
+        url=f"{base_url}/query",
+        json={
+            "connectionInfo": connection_info,
+            "manifestStr": manifest_str,
+            "sql": "SELECT memory.default.sleep(3)",  # outlasts the 1 s timeout
+        },
+        headers={X_WREN_DB_STATEMENT_TIMEOUT: "1"},  # Set timeout to 1 second
+    )
+    assert response.status_code == 504  # Gateway Timeout
+    assert "Query was cancelled:" in response.text
+
+    connection_url = _to_connection_url(trino)
+    response = await client.post(
+        url=f"{base_url}/query",
+        json={
+            "connectionInfo": {"connectionUrl": connection_url},
+            "manifestStr": manifest_str,
+            "sql": "SELECT memory.default.sleep(3)",  # outlasts the 1 s timeout
+        },
+        headers={X_WREN_DB_STATEMENT_TIMEOUT: "1"},  # Set timeout to 1 second
+    )
+    assert response.status_code == 504  # Gateway Timeout
+    assert "Query was cancelled:" in response.text
+
+
 def _to_connection_info(trino: TrinoContainer):
     return {
         "host": trino.get_container_host_ip(),