Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,54 @@ async def open(self) -> None:
"""Opening an object for write , should do it's state lookup
to know what's the persisted size is.
"""
raise NotImplementedError(
"open() is not implemented yet in _AsyncWriteObjectStream"
if self._is_stream_open:
raise ValueError("Stream is already open")

# Create a new object or overwrite existing one if generation_number
# is None. This makes it consistent with GCS JSON API behavior.
# Created object type would be Appendable Object.
if self.generation_number is None:
self.first_bidi_write_req = _storage_v2.BidiWriteObjectRequest(
write_object_spec=_storage_v2.WriteObjectSpec(
resource=_storage_v2.Object(
name=self.object_name, bucket=self._full_bucket_name
),
appendable=True,
),
)
else:
self.first_bidi_write_req = _storage_v2.BidiWriteObjectRequest(
append_object_spec=_storage_v2.AppendObjectSpec(
bucket=self._full_bucket_name,
object=self.object_name,
generation=self.generation_number,
),
state_lookup=True,
)

self.socket_like_rpc = AsyncBidiRpc(
self.rpc, initial_request=self.first_bidi_write_req, metadata=self.metadata
)

await self.socket_like_rpc.open() # this is actually 1 send
response = await self.socket_like_rpc.recv()
self._is_stream_open = True

if not response.resource:
raise ValueError(
"Failed to obtain object resource after opening the stream"
)
if not response.resource.generation:
raise ValueError(
"Failed to obtain object generation after opening the stream"
)
self.generation_number = response.resource.generation

if not response.write_handle:
raise ValueError("Failed to obtain write_handle after opening the stream")

self.write_handle = response.write_handle

async def close(self) -> None:
"""Closes the bidi-gRPC connection."""
raise NotImplementedError(
Expand Down Expand Up @@ -132,3 +176,7 @@ async def recv(self) -> _storage_v2.BidiWriteObjectResponse:
raise NotImplementedError(
"recv() is not implemented yet in _AsyncWriteObjectStream"
)

@property
def is_stream_open(self) -> bool:
return self._is_stream_open
164 changes: 161 additions & 3 deletions tests/unit/asyncio/test_async_write_object_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

BUCKET = "my-bucket"
OBJECT = "my-object"
GENERATION = 12345
WRITE_HANDLE = b"test-handle"


@pytest.fixture
Expand Down Expand Up @@ -91,13 +93,169 @@ def test_async_write_object_stream_init_raises_value_error():


@pytest.mark.asyncio
async def test_unimplemented_methods_raise_error(mock_client):
"""Test that unimplemented methods raise NotImplementedError."""
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_open_for_new_object(mock_async_bidi_rpc, mock_client):
"""Test opening a stream for a new object."""
# Arrange
socket_like_rpc = mock.AsyncMock()
mock_async_bidi_rpc.return_value = socket_like_rpc
socket_like_rpc.open = mock.AsyncMock()

mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
mock_response.resource.generation = GENERATION
mock_response.write_handle = WRITE_HANDLE
socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response)

stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)

with pytest.raises(NotImplementedError):
# Act
await stream.open()

# Assert
assert stream._is_stream_open
socket_like_rpc.open.assert_called_once()
socket_like_rpc.recv.assert_called_once()
assert stream.generation_number == GENERATION
assert stream.write_handle == WRITE_HANDLE


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_open_for_existing_object(mock_async_bidi_rpc, mock_client):
"""Test opening a stream for an existing object."""
# Arrange
socket_like_rpc = mock.AsyncMock()
mock_async_bidi_rpc.return_value = socket_like_rpc
socket_like_rpc.open = mock.AsyncMock()

mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
mock_response.resource.generation = GENERATION
mock_response.write_handle = WRITE_HANDLE
socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response)

stream = _AsyncWriteObjectStream(
mock_client, BUCKET, OBJECT, generation_number=GENERATION
)

# Act
await stream.open()

# Assert
assert stream._is_stream_open
socket_like_rpc.open.assert_called_once()
socket_like_rpc.recv.assert_called_once()
assert stream.generation_number == GENERATION
assert stream.write_handle == WRITE_HANDLE


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_open_when_already_open_raises_error(mock_async_bidi_rpc, mock_client):
"""Test that opening an already open stream raises a ValueError."""
# Arrange
socket_like_rpc = mock.AsyncMock()
mock_async_bidi_rpc.return_value = socket_like_rpc
socket_like_rpc.open = mock.AsyncMock()

mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
mock_response.resource.generation = GENERATION
mock_response.write_handle = WRITE_HANDLE
socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response)

stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
await stream.open()

# Act & Assert
with pytest.raises(ValueError, match="Stream is already open"):
await stream.open()


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_open_raises_error_on_missing_object_resource(
mock_async_bidi_rpc, mock_client
):
"""Test that open raises ValueError if object_resource is not in the response."""
socket_like_rpc = mock.AsyncMock()
mock_async_bidi_rpc.return_value = socket_like_rpc

mock_reponse = mock.AsyncMock()
type(mock_reponse).resource = mock.PropertyMock(return_value=None)
socket_like_rpc.recv.return_value = mock_reponse

# Note: Don't use below code as unittest library automatically assigns an
# `AsyncMock` object to an attribute, if not set.
# socket_like_rpc.recv.return_value = mock.AsyncMock(
# return_value=_storage_v2.BidiWriteObjectResponse(resource=None)
# )

stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
with pytest.raises(
ValueError, match="Failed to obtain object resource after opening the stream"
):
await stream.open()


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_open_raises_error_on_missing_generation(
mock_async_bidi_rpc, mock_client
):
"""Test that open raises ValueError if generation is not in the response."""
socket_like_rpc = mock.AsyncMock()
mock_async_bidi_rpc.return_value = socket_like_rpc

# Configure the mock response object
mock_response = mock.AsyncMock()
type(mock_response.resource).generation = mock.PropertyMock(return_value=None)
socket_like_rpc.recv.return_value = mock_response

stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
with pytest.raises(
ValueError, match="Failed to obtain object generation after opening the stream"
):
await stream.open()
# assert stream.generation_number is None


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_open_raises_error_on_missing_write_handle(
mock_async_bidi_rpc, mock_client
):
"""Test that open raises ValueError if write_handle is not in the response."""
socket_like_rpc = mock.AsyncMock()
mock_async_bidi_rpc.return_value = socket_like_rpc
socket_like_rpc.recv = mock.AsyncMock(
return_value=_storage_v2.BidiWriteObjectResponse(
resource=_storage_v2.Object(generation=GENERATION), write_handle=None
)
)
stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
with pytest.raises(ValueError, match="Failed to obtain write_handle"):
await stream.open()


@pytest.mark.asyncio
async def test_unimplemented_methods_raise_error(mock_client):
"""Test that unimplemented methods raise NotImplementedError."""
stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)

with pytest.raises(NotImplementedError):
await stream.close()

Expand Down