diff --git a/graphdatascience/arrow_client/authenticated_flight_client.py b/graphdatascience/arrow_client/authenticated_flight_client.py index 2cd7e4efe..00d93f5cf 100644 --- a/graphdatascience/arrow_client/authenticated_flight_client.py +++ b/graphdatascience/arrow_client/authenticated_flight_client.py @@ -180,7 +180,7 @@ def do_action(self, endpoint: str, payload: bytes | dict[str, Any]) -> Iterator[ return self._flight_client.do_action(Action(endpoint, payload_bytes)) # type: ignore - def do_action_with_retry(self, endpoint: str, payload: bytes | dict[str, Any]) -> Iterator[Result]: + def do_action_with_retry(self, endpoint: str, payload: bytes | dict[str, Any]) -> list[Result]: @retry( reraise=True, before=before_log("Send action", self._logger, logging.DEBUG), @@ -188,8 +188,10 @@ def do_action_with_retry(self, endpoint: str, payload: bytes | dict[str, Any]) - stop=self._retry_config.stop, wait=self._retry_config.wait, ) - def run_with_retry() -> Iterator[Result]: - return self.do_action(endpoint, payload) + def run_with_retry() -> list[Result]: + # the Flight response error code is only checked on iterator consumption + # we eagerly collect iterator here to trigger retry in case of an error + return list(self.do_action(endpoint, payload)) return run_with_retry() diff --git a/graphdatascience/arrow_client/v2/data_mapper_utils.py b/graphdatascience/arrow_client/v2/data_mapper_utils.py index 6cd3ebc1f..073557728 100644 --- a/graphdatascience/arrow_client/v2/data_mapper_utils.py +++ b/graphdatascience/arrow_client/v2/data_mapper_utils.py @@ -1,10 +1,10 @@ import json -from typing import Any, Iterator +from typing import Any from pyarrow._flight import Result -def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]: +def deserialize_single(input_stream: list[Result]) -> dict[str, Any]: rows = deserialize(input_stream) if len(rows) != 1: raise ValueError(f"Expected exactly one result, got {len(rows)}") @@ -12,8 +12,8 @@ def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]: return rows[0] -def deserialize(input_stream: Iterator[Result]) -> list[dict[str, Any]]: +def deserialize(input_stream: list[Result]) -> list[dict[str, Any]]: def deserialize_row(row: Result): # type:ignore return json.loads(row.body.to_pybytes().decode()) - return [deserialize_row(row) for row in list(input_stream)] + return [deserialize_row(row) for row in input_stream] diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/graph_creation_helper.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/graph_creation_helper.py index 1324622c8..e0833fcbd 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/arrow/graph_creation_helper.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/graph_creation_helper.py @@ -15,8 +15,7 @@ def create_graph( arrow_client: AuthenticatedArrowClient, graph_name: str, gdl: str, undirected: tuple[str, str] | None = None ) -> Generator[GraphV2, Any, None]: try: - raw_res = arrow_client.do_action("v2/graph.fromGDL", {"graphName": graph_name, "gdlGraph": gdl}) - deserialize_single(raw_res) + raw_res = list(arrow_client.do_action("v2/graph.fromGDL", {"graphName": graph_name, "gdlGraph": gdl})) if undirected is not None: JobClient.run_job_and_wait( @@ -26,9 +25,11 @@ def create_graph( show_progress=False, ) - raw_res = arrow_client.do_action( - "v2/graph.relationships.drop", - {"graphName": graph_name, "relationshipType": undirected[0]}, + raw_res = list( + arrow_client.do_action( + "v2/graph.relationships.drop", + {"graphName": graph_name, "relationshipType": undirected[0]}, + ) ) deserialize_single(raw_res) diff --git a/graphdatascience/tests/unit/arrow_client/V2/test_data_mapper_utils.py b/graphdatascience/tests/unit/arrow_client/V2/test_data_mapper_utils.py index 0465dc565..91544bfe5 100644 --- a/graphdatascience/tests/unit/arrow_client/V2/test_data_mapper_utils.py +++ b/graphdatascience/tests/unit/arrow_client/V2/test_data_mapper_utils.py @@ -1,26 +1,20 @@ -from typing import Iterator - import pytest -from pyarrow._flight import Result from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult def test_deserialize_single_success() -> None: - input_stream = iter([ArrowTestResult({"key": "value"})]) expected = {"key": "value"} - actual = deserialize_single(input_stream) + actual = deserialize_single([ArrowTestResult({"key": "value"})]) assert expected == actual def test_deserialize_single_raises_on_empty_stream() -> None: - input_stream: Iterator[Result] = iter([]) with pytest.raises(ValueError, match="Expected exactly one result, got 0"): - deserialize_single(input_stream) + deserialize_single([]) def test_deserialize_single_raises_on_multiple_results() -> None: - input_stream = iter([ArrowTestResult({"key1": "value1"}), ArrowTestResult({"key2": "value2"})]) with pytest.raises(ValueError, match="Expected exactly one result, got 2"): - deserialize_single(input_stream) + deserialize_single([ArrowTestResult({"key1": "value1"}), ArrowTestResult({"key2": "value2"})])