2
2
from functools import partial
3
3
from itertools import groupby
4
4
from pathlib import Path
5
- from typing import Iterable , List , Optional
5
+ from typing import Iterable , List , Optional , Union
6
6
7
7
from splitgill .manager import SplitgillClient , SplitgillDatabase
8
8
from splitgill .model import Record
@@ -141,16 +141,20 @@ def get_view(self, name: str) -> Optional[View]:
141
141
return view
142
142
return None
143
143
144
- def get_splitgill_database (self , view : View ) -> SplitgillDatabase :
144
+ def get_database (self , view : Union [ str , View ] ) -> Optional [ SplitgillDatabase ] :
145
145
"""
146
146
Returns a new SplitgillDatabase instance for the given view. If the view doesn't
147
- have an associated SplitgillDatabase name, then a ValueError is raised .
147
+ have an associated SplitgillDatabase name, then None is returned .
148
148
149
- :param view: a view
150
- :return: a SplitgillDatabase instance
149
+ :param view: a View instance or a view's name
150
+ :return: a SplitgillDatabase instance or None
151
151
"""
152
+ if isinstance (view , str ):
153
+ view = self .get_view (view )
154
+ if view is None :
155
+ return None
152
156
if not view .has_database :
153
- raise ValueError ( "View does not have a sg_name" )
157
+ return None
154
158
return SplitgillDatabase (view .sg_name , self .client )
155
159
156
160
def queue_changes (self , records : Iterable [SourceRecord ], store_name : str ):
@@ -266,7 +270,7 @@ def add_to_mongo(self, view_name: str, everything: bool = False) -> Optional[int
266
270
self .release_records (now ())
267
271
268
272
view = self .get_view (view_name )
269
- database = self .get_splitgill_database (view )
273
+ database = self .get_database (view )
270
274
271
275
if everything :
272
276
changed_records = view .iter_all ()
@@ -303,7 +307,7 @@ def sync_to_elasticsearch(self, view_name: str, resync: bool = False):
303
307
haven't changed
304
308
"""
305
309
view = self .get_view (view_name )
306
- database = self .get_splitgill_database (view )
310
+ database = self .get_database (view )
307
311
database .sync (resync = resync )
308
312
309
313
def force_merge (self , view_name : str ) -> dict :
@@ -315,7 +319,7 @@ def force_merge(self, view_name: str) -> dict:
315
319
:return:
316
320
"""
317
321
view = self .get_view (view_name )
318
- database = self .get_splitgill_database (view )
322
+ database = self .get_database (view )
319
323
client = self .client .elasticsearch
320
324
return client .options (request_timeout = None ).indices .forcemerge (
321
325
index = database .indices .wildcard ,
0 commit comments