1515import  mock 
1616import  pytest 
1717
18+ from  google .cloud .firestore_v1  import  pipeline_stages  as  stages 
19+ 
1820
1921def  _make_async_pipeline (* args , client = mock .Mock ()):
2022    from  google .cloud .firestore_v1 .async_pipeline  import  AsyncPipeline 
@@ -54,11 +56,9 @@ def test_async_pipeline_repr_single_stage():
5456
5557
5658def  test_async_pipeline_repr_multiple_stage ():
57-     from  google .cloud .firestore_v1 .pipeline_stages  import  GenericStage , Collection 
58- 
59-     stage_1  =  Collection ("path" )
60-     stage_2  =  GenericStage ("second" , 2 )
61-     stage_3  =  GenericStage ("third" , 3 )
59+     stage_1  =  stages .Collection ("path" )
60+     stage_2  =  stages .GenericStage ("second" , 2 )
61+     stage_3  =  stages .GenericStage ("third" , 3 )
6262    ppl  =  _make_async_pipeline (stage_1 , stage_2 , stage_3 )
6363    repr_str  =  repr (ppl )
6464    assert  repr_str  ==  (
@@ -71,10 +71,8 @@ def test_async_pipeline_repr_multiple_stage():
7171
7272
7373def  test_async_pipeline_repr_long ():
74-     from  google .cloud .firestore_v1 .pipeline_stages  import  GenericStage 
75- 
7674    num_stages  =  100 
77-     stage_list  =  [GenericStage ("custom" , i ) for  i  in  range (num_stages )]
75+     stage_list  =  [stages . GenericStage ("custom" , i ) for  i  in  range (num_stages )]
7876    ppl  =  _make_async_pipeline (* stage_list )
7977    repr_str  =  repr (ppl )
8078    assert  repr_str .count ("GenericStage" ) ==  num_stages 
@@ -83,10 +81,9 @@ def test_async_pipeline_repr_long():
8381
8482def  test_async_pipeline__to_pb ():
8583    from  google .cloud .firestore_v1 .types .pipeline  import  StructuredPipeline 
86-     from  google .cloud .firestore_v1 .pipeline_stages  import  GenericStage 
8784
88-     stage_1  =  GenericStage ("first" )
89-     stage_2  =  GenericStage ("second" )
85+     stage_1  =  stages . GenericStage ("first" )
86+     stage_2  =  stages . GenericStage ("second" )
9087    ppl  =  _make_async_pipeline (stage_1 , stage_2 )
9188    pb  =  ppl ._to_pb ()
9289    assert  isinstance (pb , StructuredPipeline )
@@ -96,11 +93,9 @@ def test_async_pipeline__to_pb():
9693
9794def  test_async_pipeline_append ():
9895    """append should create a new pipeline with the additional stage""" 
99-     from  google .cloud .firestore_v1 .pipeline_stages  import  GenericStage 
100- 
101-     stage_1  =  GenericStage ("first" )
96+     stage_1  =  stages .GenericStage ("first" )
10297    ppl_1  =  _make_async_pipeline (stage_1 , client = object ())
103-     stage_2  =  GenericStage ("second" )
98+     stage_2  =  stages . GenericStage ("second" )
10499    ppl_2  =  ppl_1 ._append (stage_2 )
105100    assert  ppl_1  !=  ppl_2 
106101    assert  len (ppl_1 .stages ) ==  1 
@@ -118,15 +113,14 @@ async def test_async_pipeline_execute_empty():
118113    """ 
119114    from  google .cloud .firestore_v1 .types  import  ExecutePipelineResponse 
120115    from  google .cloud .firestore_v1 .types  import  ExecutePipelineRequest 
121-     from  google .cloud .firestore_v1 .pipeline_stages  import  GenericStage 
122116
123117    client  =  mock .Mock ()
124118    client .project  =  "A" 
125119    client ._database  =  "B" 
126120    mock_rpc  =  mock .AsyncMock ()
127121    client ._firestore_api .execute_pipeline  =  mock_rpc 
128122    mock_rpc .return_value  =  _async_it ([ExecutePipelineResponse ()])
129-     ppl_1  =  _make_async_pipeline (GenericStage ("s" ), client = client )
123+     ppl_1  =  _make_async_pipeline (stages . GenericStage ("s" ), client = client )
130124
131125    results  =  [r  async  for  r  in  ppl_1 .execute ()]
132126    assert  results  ==  []
@@ -145,7 +139,6 @@ async def test_async_pipeline_execute_no_doc_ref():
145139    from  google .cloud .firestore_v1 .types  import  Document 
146140    from  google .cloud .firestore_v1 .types  import  ExecutePipelineResponse 
147141    from  google .cloud .firestore_v1 .types  import  ExecutePipelineRequest 
148-     from  google .cloud .firestore_v1 .pipeline_stages  import  GenericStage 
149142    from  google .cloud .firestore_v1 .pipeline_result  import  PipelineResult 
150143
151144    client  =  mock .Mock ()
@@ -156,7 +149,7 @@ async def test_async_pipeline_execute_no_doc_ref():
156149    mock_rpc .return_value  =  _async_it (
157150        [ExecutePipelineResponse (results = [Document ()], execution_time = {"seconds" : 9 })]
158151    )
159-     ppl_1  =  _make_async_pipeline (GenericStage ("s" ), client = client )
152+     ppl_1  =  _make_async_pipeline (stages . GenericStage ("s" ), client = client )
160153
161154    results  =  [r  async  for  r  in  ppl_1 .execute ()]
162155    assert  len (results ) ==  1 
@@ -315,3 +308,20 @@ async def test_async_pipeline_execute_with_transaction():
315308    assert  request .structured_pipeline  ==  ppl_1 ._to_pb ()
316309    assert  request .database  ==  "projects/A/databases/B" 
317310    assert  request .transaction  ==  b"123" 
311+ 
312+ @pytest .mark .parametrize ("method,args,result_cls" , [ 
313+     ("select" , (), stages .Select ), 
314+     ("where" , (mock .Mock (),), stages .Where ), 
315+     ("sort" , (), stages .Sort ), 
316+     ("offset" , (1 ,), stages .Offset ), 
317+     ("limit" , (1 ,), stages .Limit ), 
318+ 
319+ ]) 
320+ def  test_async_pipeline_methods (method , args , result_cls ):
321+     start_ppl  =  _make_async_pipeline ()
322+     method_ptr  =  getattr (start_ppl , method )
323+     result_ppl  =  method_ptr (* args )
324+     assert  result_ppl  !=  start_ppl 
325+     assert  len (start_ppl .stages ) ==  0 
326+     assert  len (result_ppl .stages ) ==  1 
327+     assert  isinstance (result_ppl .stages [0 ], result_cls )
0 commit comments