Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 34 additions & 25 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_select_query(trino_connection):
assert columns["coordinator"] == "boolean"
assert columns["state"] == "varchar"
assert cur.query_id is not None
assert cur.query == "SELECT * FROM system.runtime.nodes"
assert cur.stats is not None


Expand Down Expand Up @@ -135,37 +136,45 @@ def test_string_query_param(trino_connection):


def test_execute_many(trino_connection):
cur = trino_connection.cursor()
cur.execute("CREATE TABLE memory.default.test_execute_many (key int, value varchar)")
cur.fetchall()
operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)"
cur.executemany(operation, [(1, "value1")])
cur.fetchall()
cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key")
rows = cur.fetchall()
assert len(list(rows)) == 1
assert rows[0] == [1, "value1"]
try:
cur = trino_connection.cursor()
cur.execute("CREATE TABLE memory.default.test_execute_many (key int, value varchar)")
cur.fetchall()
operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)"
cur.executemany(operation, [(1, "value1")])
cur.fetchall()
cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key")
rows = cur.fetchall()
assert len(list(rows)) == 1
assert rows[0] == [1, "value1"]

operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)"
cur.executemany(operation, [(2, "value2"), (3, "value3")])
cur.fetchall()
operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)"
cur.executemany(operation, [(2, "value2"), (3, "value3")])
cur.fetchall()

cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key")
rows = cur.fetchall()
assert len(list(rows)) == 3
assert rows[0] == [1, "value1"]
assert rows[1] == [2, "value2"]
assert rows[2] == [3, "value3"]
cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key")
rows = cur.fetchall()
assert len(list(rows)) == 3
assert rows[0] == [1, "value1"]
assert rows[1] == [2, "value2"]
assert rows[2] == [3, "value3"]
finally:
cur = trino_connection.cursor()
cur.execute("DROP TABLE IF EXISTS memory.default.test_execute_many")


def test_execute_many_without_params(trino_connection):
cur = trino_connection.cursor()
cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)")
cur.fetchall()
with pytest.raises(TrinoUserError) as e:
cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
try:
cur = trino_connection.cursor()
cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)")
cur.fetchall()
assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value)
with pytest.raises(TrinoUserError) as e:
cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
cur.fetchall()
assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value)
finally:
cur = trino_connection.cursor()
cur.execute("DROP TABLE IF EXISTS memory.default.test_execute_many_without_param")


def test_execute_many_select(trino_connection):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def json(self):
with mock.patch.object(req, 'post', return_value=MockResponse()) as mock_post:
query = TrinoQuery(
request=req,
sql=sql
query=sql
)
result = query.execute(additional_http_headers=additional_headers)

Expand Down
19 changes: 13 additions & 6 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,11 +732,10 @@ class TrinoQuery(object):
def __init__(
self,
request: TrinoRequest,
sql: str,
query: str,
legacy_primitive_types: bool = False,
) -> None:
self.query_id: Optional[str] = None

self._query_id: Optional[str] = None
self._stats: Dict[Any, Any] = {}
self._info_uri: Optional[str] = None
self._warnings: List[Dict[Any, Any]] = []
Expand All @@ -747,11 +746,19 @@ def __init__(
self._update_type = None
self._update_count = None
self._next_uri = None
self._sql = sql
self._query = query
self._result: Optional[TrinoResult] = None
self._legacy_primitive_types = legacy_primitive_types
self._row_mapper: Optional[RowMapper] = None

@property
def query_id(self) -> Optional[str]:
return self._query_id

@property
def query(self) -> Optional[str]:
return self._query

@property
def columns(self):
if self.query_id:
Expand Down Expand Up @@ -796,10 +803,10 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
if self.cancelled:
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)

response = self._request.post(self._sql, additional_http_headers)
response = self._request.post(self._query, additional_http_headers)
status = self._request.process(response)
self._info_uri = status.info_uri
self.query_id = status.id
self._query_id = status.id
self._stats.update({"queryId": self.query_id})
self._update_state(status)
self._warnings = getattr(status, "warnings", [])
Expand Down
16 changes: 11 additions & 5 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ def query_id(self) -> Optional[str]:
return self._query.query_id
return None

@property
def query(self) -> Optional[str]:
if self._query is not None:
return self._query.query
return None

@property
def warnings(self):
if self._query is not None:
Expand All @@ -364,7 +370,7 @@ def _prepare_statement(self, statement: str, name: str) -> None:
:param name: name that will be assigned to the prepared statement.
"""
sql = f"PREPARE {name} FROM {statement}"
query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql,
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
legacy_primitive_types=self._legacy_primitive_types)
query.execute()

Expand All @@ -374,7 +380,7 @@ def _execute_prepared_statement(
params
):
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
return trino.client.TrinoQuery(self._request, sql=sql, legacy_primitive_types=self._legacy_primitive_types)
return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types)

def _format_prepared_param(self, param):
"""
Expand Down Expand Up @@ -454,7 +460,7 @@ def _format_prepared_param(self, param):

def _deallocate_prepared_statement(self, statement_name: str) -> None:
sql = 'DEALLOCATE PREPARE ' + statement_name
query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql,
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
legacy_primitive_types=self._legacy_primitive_types)
query.execute()

Expand Down Expand Up @@ -486,7 +492,7 @@ def execute(self, operation, params=None):
self._deallocate_prepared_statement(statement_name)

else:
self._query = trino.client.TrinoQuery(self._request, sql=operation,
self._query = trino.client.TrinoQuery(self._request, query=operation,
legacy_primitive_types=self._legacy_primitive_types)
self._iterator = iter(self._query.execute())
return self
Expand Down Expand Up @@ -582,7 +588,7 @@ def describe(self, sql: str) -> List[DescribeOutput]:
sql = f"DESCRIBE OUTPUT {statement_name}"
self._query = trino.client.TrinoQuery(
self._request,
sql=sql,
query=sql,
legacy_primitive_types=self._legacy_primitive_types,
)
result = self._query.execute()
Expand Down