Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions graphdatascience/arrow_client/authenticated_flight_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,18 @@ def do_action(self, endpoint: str, payload: bytes | dict[str, Any]) -> Iterator[

return self._flight_client.do_action(Action(endpoint, payload_bytes)) # type: ignore

def do_action_with_retry(self, endpoint: str, payload: bytes | dict[str, Any]) -> Iterator[Result]:
def do_action_with_retry(self, endpoint: str, payload: bytes | dict[str, Any]) -> list[Result]:
@retry(
reraise=True,
before=before_log("Send action", self._logger, logging.DEBUG),
retry=self._retry_config.retry,
stop=self._retry_config.stop,
wait=self._retry_config.wait,
)
def run_with_retry() -> Iterator[Result]:
return self.do_action(endpoint, payload)
def run_with_retry() -> list[Result]:
# the Flight response error code is only checked on iterator consumption
# we eagerly collect iterator here to trigger retry in case of an error
return list(self.do_action(endpoint, payload))

return run_with_retry()

Expand Down
8 changes: 4 additions & 4 deletions graphdatascience/arrow_client/v2/data_mapper_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import json
from typing import Any, Iterator
from typing import Any

from pyarrow._flight import Result


def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]:
def deserialize_single(input_stream: list[Result]) -> dict[str, Any]:
rows = deserialize(input_stream)
if len(rows) != 1:
raise ValueError(f"Expected exactly one result, got {len(rows)}")

return rows[0]


def deserialize(input_stream: Iterator[Result]) -> list[dict[str, Any]]:
def deserialize(input_stream: list[Result]) -> list[dict[str, Any]]:
def deserialize_row(row: Result): # type:ignore
return json.loads(row.body.to_pybytes().decode())

return [deserialize_row(row) for row in list(input_stream)]
return [deserialize_row(row) for row in input_stream]
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ def create_graph(
arrow_client: AuthenticatedArrowClient, graph_name: str, gdl: str, undirected: tuple[str, str] | None = None
) -> Generator[GraphV2, Any, None]:
try:
raw_res = arrow_client.do_action("v2/graph.fromGDL", {"graphName": graph_name, "gdlGraph": gdl})
deserialize_single(raw_res)
raw_res = list(arrow_client.do_action("v2/graph.fromGDL", {"graphName": graph_name, "gdlGraph": gdl}))

if undirected is not None:
JobClient.run_job_and_wait(
Expand All @@ -26,9 +25,11 @@ def create_graph(
show_progress=False,
)

raw_res = arrow_client.do_action(
"v2/graph.relationships.drop",
{"graphName": graph_name, "relationshipType": undirected[0]},
raw_res = list(
arrow_client.do_action(
"v2/graph.relationships.drop",
{"graphName": graph_name, "relationshipType": undirected[0]},
)
)
deserialize_single(raw_res)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
from typing import Iterator

import pytest
from pyarrow._flight import Result

from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult


def test_deserialize_single_success() -> None:
input_stream = iter([ArrowTestResult({"key": "value"})])
expected = {"key": "value"}
actual = deserialize_single(input_stream)
actual = deserialize_single([ArrowTestResult({"key": "value"})])
assert expected == actual


def test_deserialize_single_raises_on_empty_stream() -> None:
input_stream: Iterator[Result] = iter([])
with pytest.raises(ValueError, match="Expected exactly one result, got 0"):
deserialize_single(input_stream)
deserialize_single([])


def test_deserialize_single_raises_on_multiple_results() -> None:
input_stream = iter([ArrowTestResult({"key1": "value1"}), ArrowTestResult({"key2": "value2"})])
with pytest.raises(ValueError, match="Expected exactly one result, got 2"):
deserialize_single(input_stream)
deserialize_single([ArrowTestResult({"key1": "value1"}), ArrowTestResult({"key2": "value2"})])