diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index dd37a14b..87d29284 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -82,6 +82,7 @@ def test_select_query(trino_connection): assert columns["coordinator"] == "boolean" assert columns["state"] == "varchar" assert cur.query_id is not None + assert cur.query == "SELECT * FROM system.runtime.nodes" assert cur.stats is not None @@ -135,37 +136,45 @@ def test_string_query_param(trino_connection): def test_execute_many(trino_connection): - cur = trino_connection.cursor() - cur.execute("CREATE TABLE memory.default.test_execute_many (key int, value varchar)") - cur.fetchall() - operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)" - cur.executemany(operation, [(1, "value1")]) - cur.fetchall() - cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key") - rows = cur.fetchall() - assert len(list(rows)) == 1 - assert rows[0] == [1, "value1"] + try: + cur = trino_connection.cursor() + cur.execute("CREATE TABLE memory.default.test_execute_many (key int, value varchar)") + cur.fetchall() + operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)" + cur.executemany(operation, [(1, "value1")]) + cur.fetchall() + cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key") + rows = cur.fetchall() + assert len(list(rows)) == 1 + assert rows[0] == [1, "value1"] - operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)" - cur.executemany(operation, [(2, "value2"), (3, "value3")]) - cur.fetchall() + operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)" + cur.executemany(operation, [(2, "value2"), (3, "value3")]) + cur.fetchall() - cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key") - rows = cur.fetchall() - assert len(list(rows)) == 3 - assert rows[0] == [1, "value1"] - assert rows[1] == [2, "value2"] - assert rows[2] == [3, "value3"] + cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key") + rows = cur.fetchall() + assert len(list(rows)) == 3 + assert rows[0] == [1, "value1"] + assert rows[1] == [2, "value2"] + assert rows[2] == [3, "value3"] + finally: + cur = trino_connection.cursor() + cur.execute("DROP TABLE IF EXISTS memory.default.test_execute_many") def test_execute_many_without_params(trino_connection): - cur = trino_connection.cursor() - cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)") - cur.fetchall() - with pytest.raises(TrinoUserError) as e: - cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", []) + try: + cur = trino_connection.cursor() + cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)") cur.fetchall() - assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value) + with pytest.raises(TrinoUserError) as e: + cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", []) + cur.fetchall() + assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value) + finally: + cur = trino_connection.cursor() + cur.execute("DROP TABLE IF EXISTS memory.default.test_execute_many_without_param") def test_execute_many_select(trino_connection): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index fa5838b6..123ddf86 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1006,7 +1006,7 @@ def json(self): with mock.patch.object(req, 'post', return_value=MockResponse()) as mock_post: query = TrinoQuery( request=req, - sql=sql + query=sql ) result = query.execute(additional_http_headers=additional_headers) diff --git a/trino/client.py b/trino/client.py index 7d6b1752..593d579b 100644 --- a/trino/client.py +++ b/trino/client.py @@ -732,11 +732,10 @@ class TrinoQuery(object): def __init__( self, request: TrinoRequest, - sql: str, + query: str, legacy_primitive_types: bool = False, ) -> None: - self.query_id: Optional[str] = None - + self._query_id: Optional[str] = None self._stats: Dict[Any, Any] = {} self._info_uri: Optional[str] = None self._warnings: List[Dict[Any, Any]] = [] @@ -747,11 +746,19 @@ def __init__( self._update_type = None self._update_count = None self._next_uri = None - self._sql = sql + self._query = query self._result: Optional[TrinoResult] = None self._legacy_primitive_types = legacy_primitive_types self._row_mapper: Optional[RowMapper] = None + @property + def query_id(self) -> Optional[str]: + return self._query_id + + @property + def query(self) -> Optional[str]: + return self._query + @property def columns(self): if self.query_id: @@ -796,10 +803,10 @@ def execute(self, additional_http_headers=None) -> TrinoResult: if self.cancelled: raise exceptions.TrinoUserError("Query has been cancelled", self.query_id) - response = self._request.post(self._sql, additional_http_headers) + response = self._request.post(self._query, additional_http_headers) status = self._request.process(response) self._info_uri = status.info_uri - self.query_id = status.id + self._query_id = status.id self._stats.update({"queryId": self.query_id}) self._update_state(status) self._warnings = getattr(status, "warnings", []) diff --git a/trino/dbapi.py b/trino/dbapi.py index d8ae5f72..ae28d088 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -343,6 +343,12 @@ def query_id(self) -> Optional[str]: return self._query.query_id return None + @property + def query(self) -> Optional[str]: + if self._query is not None: + return self._query.query + return None + @property def warnings(self): if self._query is not None: @@ -364,7 +370,7 @@ def _prepare_statement(self, statement: str, name: str) -> None: :param name: name that will be assigned to the prepared statement. """ sql = f"PREPARE {name} FROM {statement}" - query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql, + query = trino.client.TrinoQuery(self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types) query.execute() @@ -374,7 +380,7 @@ def _execute_prepared_statement( params ): sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) - return trino.client.TrinoQuery(self._request, sql=sql, legacy_primitive_types=self._legacy_primitive_types) + return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types) def _format_prepared_param(self, param): """ @@ -454,7 +460,7 @@ def _format_prepared_param(self, param): def _deallocate_prepared_statement(self, statement_name: str) -> None: sql = 'DEALLOCATE PREPARE ' + statement_name - query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql, + query = trino.client.TrinoQuery(self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types) query.execute() @@ -486,7 +492,7 @@ def execute(self, operation, params=None): self._deallocate_prepared_statement(statement_name) else: - self._query = trino.client.TrinoQuery(self._request, sql=operation, + self._query = trino.client.TrinoQuery(self._request, query=operation, legacy_primitive_types=self._legacy_primitive_types) self._iterator = iter(self._query.execute()) return self @@ -582,7 +588,7 @@ def describe(self, sql: str) -> List[DescribeOutput]: sql = f"DESCRIBE OUTPUT {statement_name}" self._query = trino.client.TrinoQuery( self._request, - sql=sql, + query=sql, legacy_primitive_types=self._legacy_primitive_types, ) result = self._query.execute()