Skip to content

Commit 28fd42d

Browse files
committed
added tests
1 parent 09d45cb commit 28fd42d

File tree

3 files changed

+67
-38
lines changed

3 files changed

+67
-38
lines changed

tests/unit/v1/test_async_pipeline.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import mock
1616
import pytest
1717

18+
from google.cloud.firestore_v1 import pipeline_stages as stages
19+
1820

1921
def _make_async_pipeline(*args, client=mock.Mock()):
2022
from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
@@ -54,11 +56,9 @@ def test_async_pipeline_repr_single_stage():
5456

5557

5658
def test_async_pipeline_repr_multiple_stage():
57-
from google.cloud.firestore_v1.pipeline_stages import GenericStage, Collection
58-
59-
stage_1 = Collection("path")
60-
stage_2 = GenericStage("second", 2)
61-
stage_3 = GenericStage("third", 3)
59+
stage_1 = stages.Collection("path")
60+
stage_2 = stages.GenericStage("second", 2)
61+
stage_3 = stages.GenericStage("third", 3)
6262
ppl = _make_async_pipeline(stage_1, stage_2, stage_3)
6363
repr_str = repr(ppl)
6464
assert repr_str == (
@@ -71,10 +71,8 @@ def test_async_pipeline_repr_multiple_stage():
7171

7272

7373
def test_async_pipeline_repr_long():
74-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
75-
7674
num_stages = 100
77-
stage_list = [GenericStage("custom", i) for i in range(num_stages)]
75+
stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)]
7876
ppl = _make_async_pipeline(*stage_list)
7977
repr_str = repr(ppl)
8078
assert repr_str.count("GenericStage") == num_stages
@@ -83,10 +81,9 @@ def test_async_pipeline_repr_long():
8381

8482
def test_async_pipeline__to_pb():
8583
from google.cloud.firestore_v1.types.pipeline import StructuredPipeline
86-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
8784

88-
stage_1 = GenericStage("first")
89-
stage_2 = GenericStage("second")
85+
stage_1 = stages.GenericStage("first")
86+
stage_2 = stages.GenericStage("second")
9087
ppl = _make_async_pipeline(stage_1, stage_2)
9188
pb = ppl._to_pb()
9289
assert isinstance(pb, StructuredPipeline)
@@ -96,11 +93,9 @@ def test_async_pipeline__to_pb():
9693

9794
def test_async_pipeline_append():
9895
"""append should create a new pipeline with the additional stage"""
99-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
100-
101-
stage_1 = GenericStage("first")
96+
stage_1 = stages.GenericStage("first")
10297
ppl_1 = _make_async_pipeline(stage_1, client=object())
103-
stage_2 = GenericStage("second")
98+
stage_2 = stages.GenericStage("second")
10499
ppl_2 = ppl_1._append(stage_2)
105100
assert ppl_1 != ppl_2
106101
assert len(ppl_1.stages) == 1
@@ -118,15 +113,14 @@ async def test_async_pipeline_execute_empty():
118113
"""
119114
from google.cloud.firestore_v1.types import ExecutePipelineResponse
120115
from google.cloud.firestore_v1.types import ExecutePipelineRequest
121-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
122116

123117
client = mock.Mock()
124118
client.project = "A"
125119
client._database = "B"
126120
mock_rpc = mock.AsyncMock()
127121
client._firestore_api.execute_pipeline = mock_rpc
128122
mock_rpc.return_value = _async_it([ExecutePipelineResponse()])
129-
ppl_1 = _make_async_pipeline(GenericStage("s"), client=client)
123+
ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client)
130124

131125
results = [r async for r in ppl_1.execute()]
132126
assert results == []
@@ -145,7 +139,6 @@ async def test_async_pipeline_execute_no_doc_ref():
145139
from google.cloud.firestore_v1.types import Document
146140
from google.cloud.firestore_v1.types import ExecutePipelineResponse
147141
from google.cloud.firestore_v1.types import ExecutePipelineRequest
148-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
149142
from google.cloud.firestore_v1.pipeline_result import PipelineResult
150143

151144
client = mock.Mock()
@@ -156,7 +149,7 @@ async def test_async_pipeline_execute_no_doc_ref():
156149
mock_rpc.return_value = _async_it(
157150
[ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})]
158151
)
159-
ppl_1 = _make_async_pipeline(GenericStage("s"), client=client)
152+
ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client)
160153

161154
results = [r async for r in ppl_1.execute()]
162155
assert len(results) == 1
@@ -315,3 +308,20 @@ async def test_async_pipeline_execute_with_transaction():
315308
assert request.structured_pipeline == ppl_1._to_pb()
316309
assert request.database == "projects/A/databases/B"
317310
assert request.transaction == b"123"
311+
312+
@pytest.mark.parametrize("method,args,result_cls", [
313+
("select", (), stages.Select),
314+
("where", (mock.Mock(),), stages.Where),
315+
("sort", (), stages.Sort),
316+
("offset", (1,), stages.Offset),
317+
("limit", (1,), stages.Limit),
318+
319+
])
320+
def test_async_pipeline_methods(method, args, result_cls):
321+
start_ppl = _make_async_pipeline()
322+
method_ptr = getattr(start_ppl, method)
323+
result_ppl = method_ptr(*args)
324+
assert result_ppl != start_ppl
325+
assert len(start_ppl.stages) == 0
326+
assert len(result_ppl.stages) == 1
327+
assert isinstance(result_ppl.stages[0], result_cls)

tests/unit/v1/test_pipeline.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License
1414

1515
import mock
16+
import pytest
17+
18+
from google.cloud.firestore_v1 import pipeline_stages as stages
1619

1720

1821
def _make_pipeline(*args, client=mock.Mock()):
@@ -48,11 +51,9 @@ def test_pipeline_repr_single_stage():
4851

4952

5053
def test_pipeline_repr_multiple_stage():
51-
from google.cloud.firestore_v1.pipeline_stages import GenericStage, Collection
52-
53-
stage_1 = Collection("path")
54-
stage_2 = GenericStage("second", 2)
55-
stage_3 = GenericStage("third", 3)
54+
stage_1 = stages.Collection("path")
55+
stage_2 = stages.GenericStage("second", 2)
56+
stage_3 = stages.GenericStage("third", 3)
5657
ppl = _make_pipeline(stage_1, stage_2, stage_3)
5758
repr_str = repr(ppl)
5859
assert repr_str == (
@@ -65,10 +66,8 @@ def test_pipeline_repr_multiple_stage():
6566

6667

6768
def test_pipeline_repr_long():
68-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
69-
7069
num_stages = 100
71-
stage_list = [GenericStage("custom", i) for i in range(num_stages)]
70+
stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)]
7271
ppl = _make_pipeline(*stage_list)
7372
repr_str = repr(ppl)
7473
assert repr_str.count("GenericStage") == num_stages
@@ -77,10 +76,9 @@ def test_pipeline_repr_long():
7776

7877
def test_pipeline__to_pb():
7978
from google.cloud.firestore_v1.types.pipeline import StructuredPipeline
80-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
8179

82-
stage_1 = GenericStage("first")
83-
stage_2 = GenericStage("second")
80+
stage_1 = stages.GenericStage("first")
81+
stage_2 = stages.GenericStage("second")
8482
ppl = _make_pipeline(stage_1, stage_2)
8583
pb = ppl._to_pb()
8684
assert isinstance(pb, StructuredPipeline)
@@ -90,11 +88,9 @@ def test_pipeline__to_pb():
9088

9189
def test_pipeline_append():
9290
"""append should create a new pipeline with the additional stage"""
93-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
94-
95-
stage_1 = GenericStage("first")
91+
stage_1 = stages.GenericStage("first")
9692
ppl_1 = _make_pipeline(stage_1, client=object())
97-
stage_2 = GenericStage("second")
93+
stage_2 = stages.GenericStage("second")
9894
ppl_2 = ppl_1._append(stage_2)
9995
assert ppl_1 != ppl_2
10096
assert len(ppl_1.stages) == 1
@@ -111,14 +107,13 @@ def test_pipeline_execute_empty():
111107
"""
112108
from google.cloud.firestore_v1.types import ExecutePipelineResponse
113109
from google.cloud.firestore_v1.types import ExecutePipelineRequest
114-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
115110

116111
client = mock.Mock()
117112
client.project = "A"
118113
client._database = "B"
119114
mock_rpc = client._firestore_api.execute_pipeline
120115
mock_rpc.return_value = [ExecutePipelineResponse()]
121-
ppl_1 = _make_pipeline(GenericStage("s"), client=client)
116+
ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client)
122117

123118
results = list(ppl_1.execute())
124119
assert results == []
@@ -136,7 +131,6 @@ def test_pipeline_execute_no_doc_ref():
136131
from google.cloud.firestore_v1.types import Document
137132
from google.cloud.firestore_v1.types import ExecutePipelineResponse
138133
from google.cloud.firestore_v1.types import ExecutePipelineRequest
139-
from google.cloud.firestore_v1.pipeline_stages import GenericStage
140134
from google.cloud.firestore_v1.pipeline_result import PipelineResult
141135

142136
client = mock.Mock()
@@ -146,7 +140,7 @@ def test_pipeline_execute_no_doc_ref():
146140
mock_rpc.return_value = [
147141
ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})
148142
]
149-
ppl_1 = _make_pipeline(GenericStage("s"), client=client)
143+
ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client)
150144

151145
results = list(ppl_1.execute())
152146
assert len(results) == 1
@@ -295,3 +289,20 @@ def test_pipeline_execute_with_transaction():
295289
assert request.structured_pipeline == ppl_1._to_pb()
296290
assert request.database == "projects/A/databases/B"
297291
assert request.transaction == b"123"
292+
293+
@pytest.mark.parametrize("method,args,result_cls", [
294+
("select", (), stages.Select),
295+
("where", (mock.Mock(),), stages.Where),
296+
("sort", (), stages.Sort),
297+
("offset", (1,), stages.Offset),
298+
("limit", (1,), stages.Limit),
299+
300+
])
301+
def test_pipeline_methods(method, args, result_cls):
302+
start_ppl = _make_pipeline()
303+
method_ptr = getattr(start_ppl, method)
304+
result_ppl = method_ptr(*args)
305+
assert result_ppl != start_ppl
306+
assert len(start_ppl.stages) == 0
307+
assert len(result_ppl.stages) == 1
308+
assert isinstance(result_ppl.stages[0], result_cls)

tests/unit/v1/test_pipeline_source.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ def test_collection(self):
4444
assert isinstance(first_stage, stages.Collection)
4545
assert first_stage.path == "/path"
4646

47+
def test_collection_group(self):
48+
instance = self._make_client().pipeline()
49+
ppl = instance.collection_group("id")
50+
assert isinstance(ppl, self._expected_pipeline_type)
51+
assert len(ppl.stages) == 1
52+
first_stage = ppl.stages[0]
53+
assert isinstance(first_stage, stages.CollectionGroup)
54+
assert first_stage.collection_id == "id"
4755

4856
class TestPipelineSourceWithAsyncClient(TestPipelineSource):
4957
"""

0 commit comments

Comments
 (0)