From 508646ea51e629e8b31080abc2aff213618ed60b Mon Sep 17 00:00:00 2001 From: "Mateusz \"Serafin\" Gajewski" Date: Fri, 31 Jan 2025 20:25:03 +0100 Subject: [PATCH 1/3] Add metadata to SpooledData This is part of the contract which is not used in the JSON encoding scheme but will be in the future for other formats. --- trino/client.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/trino/client.py b/trino/client.py index a122702f..ecbb34ab 100644 --- a/trino/client.py +++ b/trino/client.py @@ -931,6 +931,7 @@ def fetch(self) -> List[Union[List[Any]], Any]: def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData: encoding = rows["encoding"] + metadata = rows["metadata"] if "metadata" in rows else None segments = [] for segment in rows["segments"]: segment_type = segment["type"] @@ -943,7 +944,7 @@ def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData: else: raise ValueError(f"Unsupported segment type: {segment_type}") - return SpooledData(encoding, segments) + return SpooledData(encoding, metadata, segments) def cancel(self) -> None: """Cancel the current query""" @@ -1024,6 +1025,7 @@ def _parse_retry_after_header(retry_after): # Trino Spooled protocol transfer objects class _SpooledProtocolResponseTO(TypedDict): encoding: Literal["json", "json+std", "json+lz4"] + metadata: _SegmentMetadataTO segments: List[_SegmentTO] @@ -1168,10 +1170,12 @@ class SpooledData: Attributes: encoding (str): The encoding format of the spooled data. + metadata (_SegmentMetadataTO): Metadata for all segments segments (List[Segment]): The list of segments in the spooled data. 
""" - def __init__(self, encoding: str, segments: List[Segment]) -> None: + def __init__(self, encoding: str, metadata: _SegmentMetadataTO, segments: List[Segment]) -> None: self._encoding = encoding + self._metadata = metadata self._segments = segments self._segments_iterator = iter(segments) @@ -1190,7 +1194,7 @@ def __next__(self) -> Tuple["SpooledData", "Segment"]: return self, next(self._segments_iterator) def __repr__(self): - return (f"SpooledData(encoding={self._encoding}, segments={list(self._segments)})") + return (f"SpooledData(encoding={self._encoding}, metadata={self._metadata}, segments={list(self._segments)})") class SegmentIterator: From ac3d3f5e9c5a7538e68c81c6877acb164034d8dc Mon Sep 17 00:00:00 2001 From: "Mateusz \"Serafin\" Gajewski" Date: Fri, 31 Jan 2025 20:55:24 +0100 Subject: [PATCH 2/3] Rename SpooledData to DecodableSegment and remove Tuple iterator This makes the API much easier to consume in the `segments` cursor style: ``` cur = conn.cursor('segment') cur.execute(sql) segments = cur.fetchall() total_row_count = 0 row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False) for segment in segments: rows = list(SegmentIterator(segment, row_mapper)) print ("rows length is " + str(len(rows)) + " " + segment.encoding) total_row_count += len(rows) print(total_row_count) ``` This will work as well: ``` cur = conn.cursor('segment') cur.execute(sql) segments = cur.fetchall() total_row_count = 0 row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False) rows = list(SegmentIterator(segments, row_mapper)) print ("rows length is " + str(len(rows))) total_row_count += len(rows) print(total_row_count) ``` --- tests/integration/test_dbapi_integration.py | 18 ++++--- trino/client.py | 60 +++++++++++---------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 
d94f97f0..276d6971 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -29,12 +29,14 @@ import trino from tests.integration.conftest import trino_version from trino import constants +from trino.client import SegmentIterator from trino.dbapi import Cursor from trino.dbapi import DescribeOutput from trino.dbapi import TimeBoundLRUCache from trino.exceptions import NotSupportedError from trino.exceptions import TrinoQueryError from trino.exceptions import TrinoUserError +from trino.mapper import RowMapperFactory from trino.transaction import IsolationLevel @@ -1876,12 +1878,16 @@ def test_segments_cursor(trino_connection): start => 1, stop => 5, step => 1)) n""") - rows = cur.fetchall() - assert len(rows) > 0 - for spooled_data, spooled_segment in rows: - assert spooled_data.encoding == trino_connection._client_session.encoding - assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}" - assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}" + segments = cur.fetchall() + assert len(segments) > 0 + row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False) + total = 0 + for segment in segments: + assert segment.encoding == trino_connection._client_session.encoding + assert isinstance(segment.segment.uri, str), f"Expected string for uri, got {segment.segment.uri}" + assert isinstance(segment.segment.ack_uri, str), f"Expected string for ack_uri, got {segment.segment.ack_uri}" + total += len(list(SegmentIterator(segment, row_mapper))) + assert total == 300875, f"Expected total rows 300875, got {total}" def get_cursor(legacy_prepared_statements, run_trino): diff --git a/trino/client.py b/trino/client.py index ecbb34ab..ddf92064 100644 --- a/trino/client.py +++ b/trino/client.py @@ -89,7 +89,7 @@ "TrinoQuery", "TrinoRequest", "PROXIES", - "SpooledData", + "DecodableSegment", 
"SpooledSegment", "InlineSegment", "Segment" @@ -920,16 +920,16 @@ def fetch(self) -> List[Union[List[Any]], Any]: if isinstance(status.rows, dict): # spooling protocol rows = cast(_SpooledProtocolResponseTO, rows) - segments = self._to_segments(rows) + spooled = self._to_segments(rows) if self._fetch_mode == "segments": - return segments - return list(SegmentIterator(segments, self._row_mapper)) + return spooled + return list(SegmentIterator(spooled, self._row_mapper)) elif isinstance(status.rows, list): return self._row_mapper.map(rows) else: raise ValueError(f"Unexpected type: {type(status.rows)}") - def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData: + def _to_segments(self, rows: _SpooledProtocolResponseTO) -> List[DecodableSegment]: encoding = rows["encoding"] metadata = rows["metadata"] if "metadata" in rows else None segments = [] @@ -944,7 +944,7 @@ def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData: else: raise ValueError(f"Unsupported segment type: {segment_type}") - return SpooledData(encoding, metadata, segments) + return list(map(lambda segment: DecodableSegment(encoding, metadata, segment), segments)) def cancel(self) -> None: """Cancel the current query""" @@ -1164,46 +1164,44 @@ def __repr__(self): ) -class SpooledData: +class DecodableSegment: """ Represents a collection of spooled segments of data, with an encoding format. Attributes: encoding (str): The encoding format of the spooled data. - metadata (_SegmentMetadataTO): Metadata for all segments - segments (List[Segment]): The list of segments in the spooled data. 
+ metadata (_SegmentMetadataTO): Metadata for all segments in the query + segment (Segment): The spooled segment data """ - def __init__(self, encoding: str, metadata: _SegmentMetadataTO, segments: List[Segment]) -> None: + def __init__(self, encoding: str, metadata: _SegmentMetadataTO, segment: Segment) -> None: self._encoding = encoding self._metadata = metadata - self._segments = segments - self._segments_iterator = iter(segments) + self._segment = segment @property def encoding(self): return self._encoding @property - def segments(self): - return self._segments - - def __iter__(self) -> Iterator[Tuple["SpooledData", "Segment"]]: - return self + def segment(self): + return self._segment - def __next__(self) -> Tuple["SpooledData", "Segment"]: - return self, next(self._segments_iterator) + @property + def metadata(self): + return self._metadata def __repr__(self): - return (f"SpooledData(encoding={self._encoding}, metadata={self._metadata}, segments={list(self._segments)})") + return (f"DecodableSegment(encoding={self._encoding}, metadata={self._metadata}, segment={self._segment})") class SegmentIterator: - def __init__(self, spooled_data: SpooledData, mapper: RowMapper) -> None: - self._segments = iter(spooled_data._segments) - self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(mapper).create(spooled_data.encoding)) + def __init__(self, segments: Union[DecodableSegment, List[DecodableSegment]], mapper: RowMapper) -> None: + self._segments = iter(segments if isinstance(segments, List) else [segments]) + self._mapper = mapper + self._decoder = None self._rows: Iterator[List[List[Any]]] = iter([]) self._finished = False - self._current_segment: Optional[Segment] = None + self._current_segment: Optional[DecodableSegment] = None def __iter__(self) -> Iterator[List[Any]]: return self @@ -1214,16 +1212,22 @@ def __next__(self) -> List[Any]: try: return next(self._rows) except StopIteration: - if self._current_segment and isinstance(self._current_segment, 
SpooledSegment): - self._current_segment.acknowledge() if self._finished: raise StopIteration self._load_next_segment() def _load_next_segment(self): try: - self._current_segment = segment = next(self._segments) - self._rows = iter(self._decoder.decode(segment)) + if self._current_segment: + segment = self._current_segment.segment + if isinstance(segment, SpooledSegment): + segment.acknowledge() + + self._current_segment = next(self._segments) + if self._decoder is None: + self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(self._mapper) + .create(self._current_segment.encoding)) + self._rows = iter(self._decoder.decode(self._current_segment.segment)) except StopIteration: self._finished = True From af35c029a5d226145045ecc2aa902aeefbbf869e Mon Sep 17 00:00:00 2001 From: "Mateusz \"Serafin\" Gajewski" Date: Fri, 31 Jan 2025 22:39:38 +0100 Subject: [PATCH 3/3] Use unauthenticated client for segment retrieval Sending authorization headers will cause S3 presigned URIs to fail due to additional headers that must not be present. 
--- trino/client.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/trino/client.py b/trino/client.py index ddf92064..637fe82b 100644 --- a/trino/client.py +++ b/trino/client.py @@ -573,6 +573,15 @@ def http_headers(self) -> CaseInsensitiveDict[str]: return headers + def unauthenticated(self): + return TrinoRequest( + host=self._host, + port=self._port, + max_attempts=self.max_attempts, + request_timeout=self._request_timeout, + handle_retry=self._handle_retry, + client_session=ClientSession(user=self._client_session.user)) + @property def max_attempts(self) -> int: return self._max_attempts @@ -940,7 +949,7 @@ def _to_segments(self, rows: _SpooledProtocolResponseTO) -> List[DecodableSegmen segments.append(InlineSegment(inline_segment)) elif segment_type == SegmentType.SPOOLED: spooled_segment = cast(_SpooledSegmentTO, segment) - segments.append(SpooledSegment(spooled_segment, self._request)) + segments.append(SpooledSegment(spooled_segment, self._request.unauthenticated())) else: raise ValueError(f"Unsupported segment type: {segment_type}")