-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
base.py
1461 lines (1264 loc) · 58.4 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# TODO: https://github.com/apache/beam/issues/21822
# mypy: ignore-errors
"""An extensible run inference transform.
Users of this module can extend the ModelHandler class for any machine learning
framework. A ModelHandler implementation is a required parameter of
RunInference.
The transform handles standard inference functionality, like metric
collection, sharing model between threads, and batching elements.
"""
import logging
import os
import pickle
import sys
import threading
import time
import uuid
from collections import OrderedDict
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
from typing import Union
import apache_beam as beam
from apache_beam.utils import multi_process_shared
from apache_beam.utils import shared
try:
# pylint: disable=wrong-import-order, wrong-import-position
import resource
except ImportError:
resource = None # type: ignore[assignment]
# Divisors for converting nanosecond readings (e.g. from time.time_ns()).
_NANOSECOND_TO_MILLISECOND = 10**6
_NANOSECOND_TO_MICROSECOND = 10**3

# Type variables that parameterize the generic model-handler classes.
ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
PreProcessT = TypeVar('PreProcessT')
PredictionT = TypeVar('PredictionT')
PostProcessT = TypeVar('PostProcessT')

# Field types of PredictionResult.
_INPUT_TYPE = TypeVar('_INPUT_TYPE')
_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')

# Key type used by the keyed model handlers.
KeyT = TypeVar('KeyT')
# PredictionResult subclasses a functionally-declared NamedTuple. Generic
# NamedTuple classes are not supported before Python 3.11, so the optional
# model_id field gets its default through __new__ rather than class-body
# default syntax; this keeps the type working on older Python versions.
class PredictionResult(NamedTuple('PredictionResult',
                                  [('example', _INPUT_TYPE),
                                   ('inference', _OUTPUT_TYPE),
                                   ('model_id', Optional[str])])):
  # Empty slots: instances carry no per-instance __dict__.
  __slots__ = ()

  def __new__(cls, example, inference, model_id=None):
    # Forward to the tuple constructor, defaulting model_id to None.
    return super().__new__(
        cls, example=example, inference=inference, model_id=model_id)


PredictionResult.__doc__ = """A NamedTuple containing both input and output
from the inference."""
PredictionResult.example.__doc__ = """The input example."""
PredictionResult.inference.__doc__ = """Results for the inference on the model
for the given example."""
PredictionResult.model_id.__doc__ = """Model ID used to run the prediction."""
class ModelMetadata(NamedTuple):
  """Pairs a model's identifier with a human-readable model name."""
  # Identifier (e.g. a path or URL) used to locate and load the model.
  model_id: str
  # Display name for the model, e.g. in RunInference metrics.
  model_name: str
class RunInferenceDLQ(NamedTuple):
  """Groups the PCollections of failures from the stages of RunInference.

  NOTE(review): the element format of each PCollection is defined by the
  code that builds this tuple, which is outside this view.
  """
  # Failures raised while running inference.
  failed_inferences: beam.PCollection
  # Failures from the preprocessing functions, one collection per stage.
  failed_preprocessing: Sequence[beam.PCollection]
  # Failures from the postprocessing functions, one collection per stage.
  failed_postprocessing: Sequence[beam.PCollection]
class _ModelLoadStats(NamedTuple):
  """Result of loading a model through the model manager."""
  # Unique tag under which the model is held in memory.
  model_tag: str
  # Time taken to load the model; None when no load was performed.
  load_latency: Optional[int]
  # Memory growth observed during the load; None when no load was performed.
  byte_size: Optional[int]
# Attach per-field documentation to ModelMetadata's field accessors.
ModelMetadata.model_name.__doc__ = """Human-readable name for the model. This
can be used to identify the model in the metrics generated by the
RunInference transform."""
ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can be
a file path or a URL where the model can be accessed. It is used to load
the model for inference."""
def _to_milliseconds(time_ns: int) -> int:
  """Converts a nanosecond value to whole milliseconds, truncating toward 0.

  Exact integer division is used instead of true division (``/``): the
  latter produces a correctly-rounded float, and for epoch-scale values from
  ``time.time_ns()`` that float can round up across a millisecond boundary,
  making the truncated result off by one.

  Args:
    time_ns: A timestamp or duration in nanoseconds.

  Returns:
    The value in milliseconds, truncated toward zero (same sign convention
    as ``int()``).
  """
  whole = int(abs(time_ns) // _NANOSECOND_TO_MILLISECOND)
  return whole if time_ns >= 0 else -whole
def _to_microseconds(time_ns: int) -> int:
  """Converts a nanosecond value to whole microseconds, truncating toward 0.

  Exact integer division is used instead of true division (``/``): near
  ``time.time_ns()`` magnitudes the float quotient has a resolution of only
  0.25 microseconds, so it can round up across a microsecond boundary and
  make the truncated result off by one.

  Args:
    time_ns: A timestamp or duration in nanoseconds.

  Returns:
    The value in microseconds, truncated toward zero (same sign convention
    as ``int()``).
  """
  whole = int(abs(time_ns) // _NANOSECOND_TO_MICROSECOND)
  return whole if time_ns >= 0 else -whole
@dataclass(frozen=True)
class KeyModelPathMapping(Generic[KeyT]):
  """
  Immutable record that points one or more keys at a single new model path.

  It is used together with a KeyedModelHandler that wraps many model
  handlers, to move a set of keys onto an updated model. For example,
  `KeyModelPathMapping(keys=['key1', 'key2'], update_path='updated/path',
  model_id='id1')` causes the model handler serving `key1` and `key2` to
  have its update_model_path called with 'updated/path', and their metrics
  to be reported under 'id1'. See the KeyedModelHandler documentation
  https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
  and the website section on model updates
  https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
  """
  # Keys whose model should be replaced.
  keys: List[KeyT]
  # Location of the replacement model.
  update_path: str
  # Identifier used when reporting metrics for the updated model.
  model_id: str = ''
class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
  """Has the ability to load and apply an ML model."""
  def __init__(self):
    """Environment variables are set using a dict named 'env_vars' before
    loading the model. Child classes can accept this dict as a kwarg."""
    self._env_vars = {}

  def load_model(self) -> ModelT:
    """Loads and initializes a model for processing."""
    raise NotImplementedError(type(self))

  def run_inference(
      self,
      batch: Sequence[ExampleT],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
    """Runs inferences on a batch of examples.

    Args:
      batch: A sequence of examples or features.
      model: The model used to make inferences.
      inference_args: Extra arguments for models whose inference call requires
        extra parameters.

    Returns:
      An Iterable of Predictions.
    """
    raise NotImplementedError(type(self))

  def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
    """
    Returns:
      The number of bytes of data for a batch. The default implementation
      uses the length of the pickled batch.
    """
    return len(pickle.dumps(batch))

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'RunInference'

  def get_resource_hints(self) -> dict:
    """
    Returns:
      Resource hints for the transform. The default is no hints.
    """
    return {}

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    """
    Returns:
      kwargs suitable for beam.BatchElements. The default is no kwargs.
    """
    return {}

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    """Validates inference_args passed in the inference call.

    Because most frameworks do not need extra arguments in their predict() call,
    the default behavior is to error out if inference_args are present.

    Raises:
      ValueError: if inference_args is non-empty.
    """
    if inference_args:
      raise ValueError(
          'inference_args were provided, but should be None because this '
          'framework does not expect extra arguments on inferences.')

  def update_model_path(self, model_path: Optional[str] = None):
    """
    Update the model path produced by side inputs. update_model_path should be
    used when a ModelHandler represents a single model, not multiple models.
    This will be true in most cases. The default implementation does nothing.
    For more information see the website section on model updates
    https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
    """
    pass

  def update_model_paths(
      self,
      model: ModelT,
      model_paths: Optional[Union[str, List[KeyModelPathMapping]]] = None):
    """
    Update the model paths produced by side inputs. update_model_paths should
    be used when updating multiple models at once (e.g. when using a
    KeyedModelHandler that holds multiple models). The default implementation
    does nothing. For more information see the KeyedModelHandler documentation
    https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
    documentation and the website section on model updates
    https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
    """
    pass

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    """Gets all preprocessing functions to be run before batching/inference.
    Functions are in order that they should be applied."""
    return []

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    """Gets all postprocessing functions to be run after inference.
    Functions are in order that they should be applied."""
    return []

  def set_environment_vars(self):
    """Sets environment variables using a dictionary provided via kwargs.

    Keys are the env variable name, and values are the env variable value.
    Child ModelHandler classes should set _env_vars via kwargs in __init__,
    or else call super().__init__()."""
    # getattr tolerates subclasses that never called super().__init__().
    env_vars = getattr(self, '_env_vars', {})
    for env_variable, env_value in env_vars.items():
      os.environ[env_variable] = env_value

  def with_preprocess_fn(
      self, fn: Callable[[PreProcessT], ExampleT]
  ) -> 'ModelHandler[PreProcessT, PredictionT, ModelT]':
    """Returns a new ModelHandler with a preprocessing function
    associated with it. The preprocessing function will be run
    before batching/inference and should map your input PCollection
    to the base ModelHandler's input type. If you apply multiple
    preprocessing functions, they will be run on your original
    PCollection in order from last applied to first applied."""
    return _PreProcessingModelHandler(self, fn)

  def with_postprocess_fn(
      self, fn: Callable[[PredictionT], PostProcessT]
  ) -> 'ModelHandler[ExampleT, PostProcessT, ModelT]':
    """Returns a new ModelHandler with a postprocessing function
    associated with it. The postprocessing function will be run
    after inference and should map the base ModelHandler's output
    type to your desired output type. If you apply multiple
    postprocessing functions, they will be run on your original
    inference result in order from first applied to last applied."""
    return _PostProcessingModelHandler(self, fn)

  def share_model_across_processes(self) -> bool:
    """Returns a boolean representing whether or not a model should
    be shared across multiple processes instead of being loaded per process.
    This is primarily useful for large models that can't fit multiple copies
    in memory. Multi-process support may vary by runner, but this will fall
    back to loading per process as necessary. See
    https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
    return False

  def override_metrics(self, metrics_namespace: str = '') -> bool:
    """Returns a boolean representing whether or not a model handler will
    override metrics reporting. If True, RunInference will not report any
    metrics."""
    return False
class _ModelManager:
  """
  A class for efficiently managing copies of multiple models. Will load a
  single copy of each model into a multi_process_shared object and then
  return a lookup key for that object.
  """
  def __init__(self, mh_map: Dict[str, ModelHandler]):
    """
    Args:
      mh_map: A map from keys to model handlers which can be used to load a
        model.
    """
    # Maximum number of models held at once; None means no limit. Raised via
    # increment_max_models.
    self._max_models = None
    # Map keys to model handlers
    self._mh_map: Dict[str, ModelHandler] = mh_map
    # Map keys to the last updated model path for that key
    self._key_to_last_update: Dict[str, str] = defaultdict(str)
    # Map key for a model to a unique tag that will persist for the life of
    # that model in memory. A new tag will be generated if a model is swapped
    # out of memory and reloaded. Kept as an OrderedDict ordered from least to
    # most recently used, which load() relies on for LRU eviction.
    self._tag_map: Dict[str, str] = OrderedDict()
    # Map a tag to the (MultiProcessShared handle, acquired model reference)
    # pair for that tag. Each entry of this map should last as long as the
    # corresponding entry in _tag_map.
    self._proxy_map: Dict[str, Tuple[multi_process_shared.MultiProcessShared,
                                     Any]] = {}

  def load(self, key: str) -> _ModelLoadStats:
    """
    Loads the appropriate model for the given key into memory.

    Args:
      key: the key associated with the model we'd like to load.

    Returns:
      _ModelLoadStats with tag, byte size, and latency to load the model. If
        the model was already loaded, byte size/latency will be None.

    Raises:
      Whatever the model handler's load_model raises. In that case no tag is
      left registered for the key, so a later call will retry the load.
    """
    if key in self._tag_map:
      # Already resident: mark as most recently used; no load cost to report.
      self._tag_map.move_to_end(key)
      return _ModelLoadStats(self._tag_map[key], None, None)

    # Look up the handler before registering a tag so an unknown key leaves
    # the manager's state untouched.
    mh = self._mh_map[key]
    # Map the key for a model to a unique tag that will persist until the model
    # is released. This needs to be unique between releasing/reacquiring the
    # model because otherwise the ProxyManager will try to reuse the model that
    # has been released and deleted.
    tag = uuid.uuid4().hex
    self._tag_map[key] = tag

    if self._max_models is not None and self._max_models < len(self._tag_map):
      # If we're about to exceed our LRU size, release the least recently used
      # model (the first entry of the OrderedDict).
      # NOTE(review): with _max_models == 0 this would evict the key just
      # added; confirm a hint of 0 can never reach here.
      tag_to_remove = self._tag_map.popitem(last=False)[1]
      shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
      shared_handle.release(model_to_remove)
      del self._proxy_map[tag_to_remove]

    # Load the new model
    memory_before = _get_current_process_memory_in_bytes()
    start_time = _to_milliseconds(time.time_ns())
    try:
      shared_handle = multi_process_shared.MultiProcessShared(
          mh.load_model, tag=tag)
      model_reference = shared_handle.acquire()
    except BaseException:
      # Roll back the tag: otherwise later calls would report this key as
      # loaded, and eviction/update would fail looking up its missing proxy.
      self._tag_map.pop(key, None)
      raise
    self._proxy_map[tag] = (shared_handle, model_reference)
    memory_after = _get_current_process_memory_in_bytes()
    end_time = _to_milliseconds(time.time_ns())
    return _ModelLoadStats(
        tag, end_time - start_time, memory_after - memory_before)

  def increment_max_models(self, increment: int):
    """
    Increments the number of models that this instance of a _ModelManager is
    able to hold. If it is never called, no limit is imposed.

    Args:
      increment: the amount by which we are incrementing the number of models.
    """
    if self._max_models is None:
      self._max_models = 0
    self._max_models += increment

  def update_model_handler(self, key: str, model_path: str, previous_key: str):
    """
    Updates the model path of this model handler and removes it from memory so
    that it can be reloaded with the updated path. No-ops if no model update
    needs to be applied.

    Args:
      key: the key associated with the model we'd like to update.
      model_path: the new path to the model we'd like to load.
      previous_key: the key that is associated with the old version of this
        model. This will often be the same as the current key, but sometimes
        we will want to keep both the old and new models to serve different
        cohorts. In that case, the keys should be different.
    """
    if self._key_to_last_update[key] == model_path:
      return
    self._key_to_last_update[key] = model_path
    if key not in self._mh_map:
      # New cohort: start from a copy of the handler it was split from.
      self._mh_map[key] = deepcopy(self._mh_map[previous_key])
    self._mh_map[key].update_model_path(model_path)
    if key in self._tag_map:
      # Drop the stale model so the next load() uses the updated path.
      tag_to_remove = self._tag_map[key]
      shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
      shared_handle.release(model_to_remove)
      del self._tag_map[key]
      del self._proxy_map[tag_to_remove]
# KeyModelMapping is a plain class rather than a NamedTuple because
# NamedTuples and generics don't mix well across all supported versions:
# https://github.com/python/typing/issues/653
class KeyModelMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
  """
  Associates one or more keys with a single model handler. Given
  `KeyModelMapping(['key1', 'key2'], myMh)`, every example keyed by `key1`
  or `key2` is run against the model defined by the `myMh` ModelHandler.
  """
  def __init__(
      self, keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]):
    # The keys routed to `mh`, followed by the handler that serves them.
    self.keys, self.mh = keys, mh
class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Tuple[KeyT, ExampleT],
Tuple[KeyT, PredictionT],
Union[ModelT, _ModelManager]]):
def __init__(
self,
unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
List[KeyModelMapping[KeyT, ExampleT, PredictionT,
ModelT]]],
max_models_per_worker_hint: Optional[int] = None):
"""A ModelHandler that takes keyed examples and returns keyed predictions.
For example, if the original model is used with RunInference to take a
PCollection[E] to a PCollection[P], this ModelHandler would take a
PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible
to use the key to associate the outputs with the inputs. KeyedModelHandler
is able to accept either a single unkeyed ModelHandler or many different
model handlers corresponding to the keys for which that ModelHandler should
be used. For example, the following configuration could be used to map keys
1-3 to ModelHandler1 and keys 4-5 to ModelHandler2:
k1 = ['k1', 'k2', 'k3']
k2 = ['k4', 'k5']
KeyedModelHandler([KeyModelMapping(k1, mh1), KeyModelMapping(k2, mh2)])
Note that a single copy of each of these models may all be held in memory
at the same time; be careful not to load too many large models or your
pipeline may cause Out of Memory exceptions.
KeyedModelHandlers support Automatic Model Refresh to update your model
to a newer version without stopping your streaming pipeline. For an
overview of this feature, see
https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
To use this feature with a KeyedModelHandler that has many models per key,
you can pass in a list of KeyModelPathMapping objects to define your new
model paths. For example, passing in the side input of
[KeyModelPathMapping(keys=['k1', 'k2'], update_path='update/path/1'),
KeyModelPathMapping(keys=['k3'], update_path='update/path/2')]
will update the model corresponding to keys 'k1' and 'k2' with path
'update/path/1' and the model corresponding to 'k3' with 'update/path/2'.
In order to do a side input update: (1) all restrictions mentioned in
https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
must be met, (2) all update_paths must be non-empty, even if they are not
being updated from their original values, and (3) the set of keys
originally defined cannot change. This means that if originally you have
defined model handlers for 'key1', 'key2', and 'key3', all 3 of those keys
must appear in your list of KeyModelPathMappings exactly once. No
additional keys can be added.
When using many models defined per key, metrics about inference and model
loading will be gathered on an aggregate basis for all keys. These will be
reported with no prefix. Metrics will also be gathered on a per key basis.
Since some keys can share the same model, only one set of metrics will be
reported per key 'cohort'. These will be reported in the form:
`<cohort_key>-<metric_name>`, where `<cohort_key>` can be any key selected
from the cohort. When model updates occur, the metrics will be reported in
the form `<cohort_key>-<model id>-<metric_name>`.
Args:
unkeyed: Either (a) an implementation of ModelHandler that does not
require keys or (b) a list of KeyModelMappings mapping lists of keys to
unkeyed ModelHandlers.
max_models_per_worker_hint: A hint to the runner indicating how many
models can be held in memory at one time per worker process. For
example, if your worker has 8 GB of memory provisioned and your workers
take up 1 GB each, you should set this to 7 to allow all models to sit
in memory with some buffer.
"""
self._metrics_collectors: Dict[str, _MetricsCollector] = {}
self._default_metrics_collector: _MetricsCollector = None
self._metrics_namespace = ''
self._single_model = not isinstance(unkeyed, list)
if self._single_model:
if len(unkeyed.get_preprocess_fns()) or len(
unkeyed.get_postprocess_fns()):
raise Exception(
'Cannot make make an unkeyed model handler with pre or '
'postprocessing functions defined into a keyed model handler. All '
'pre/postprocessing functions must be defined on the outer model'
'handler.')
self._env_vars = unkeyed._env_vars
self._unkeyed = unkeyed
return
self._max_models_per_worker_hint = max_models_per_worker_hint
# To maintain an efficient representation, we will map all keys in a given
# KeyModelMapping to a single id (the first key in the KeyModelMapping
# list). We will then map that key to a ModelHandler. This will allow us to
# quickly look up the appropriate ModelHandler for any given key.
self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT,
ModelT]] = {}
self._key_to_id_map: Dict[str, str] = {}
for mh_tuple in unkeyed:
mh = mh_tuple.mh
keys = mh_tuple.keys
if len(mh.get_preprocess_fns()) or len(mh.get_postprocess_fns()):
raise ValueError(
'Cannot use an unkeyed model handler with pre or '
'postprocessing functions defined in a keyed model handler. All '
'pre/postprocessing functions must be defined on the outer model'
'handler.')
hints = mh.get_resource_hints()
if len(hints) > 0:
logging.warning(
'mh %s defines the following resource hints, which will be'
'ignored: %s. Resource hints are not respected when more than one '
'model handler is used in a KeyedModelHandler. If you would like '
'to specify resource hints, you can do so by overriding the '
'KeyedModelHandler.get_resource_hints() method.',
mh,
hints)
batch_kwargs = mh.batch_elements_kwargs()
if len(batch_kwargs) > 0:
logging.warning(
'mh %s defines the following batching kwargs which will be '
'ignored %s. Batching kwargs are not respected when '
'more than one model handler is used in a KeyedModelHandler. If '
'you would like to specify resource hints, you can do so by '
'overriding the KeyedModelHandler.batch_elements_kwargs() method.',
hints,
batch_kwargs)
env_vars = mh._env_vars
if len(env_vars) > 0:
logging.warning(
'mh %s defines the following _env_vars which will be ignored %s. '
'_env_vars are not respected when more than one model handler is '
'used in a KeyedModelHandler. If you need env vars set at '
'inference time, you can do so with '
'a custom inference function.',
mh,
env_vars)
if len(keys) == 0:
raise ValueError(
f'Empty list maps to model handler {mh}. All model handlers must '
'have one or more associated keys.')
self._id_to_mh_map[keys[0]] = mh
for key in keys:
if key in self._key_to_id_map:
raise ValueError(
f'key {key} maps to multiple model handlers. All keys must map '
'to exactly one model handler.')
self._key_to_id_map[key] = keys[0]
def load_model(self) -> Union[ModelT, _ModelManager]:
if self._single_model:
return self._unkeyed.load_model()
return _ModelManager(self._id_to_mh_map)
def run_inference(
self,
batch: Sequence[Tuple[KeyT, ExampleT]],
model: Union[ModelT, _ModelManager],
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[Tuple[KeyT, PredictionT]]:
if self._single_model:
keys, unkeyed_batch = zip(*batch)
return zip(
keys,
self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
# The first time a MultiProcessShared ModelManager is used for inference
# from this process, we should increment its max model count
if self._max_models_per_worker_hint is not None:
lock = threading.Lock()
if lock.acquire(blocking=False):
model.increment_max_models(self._max_models_per_worker_hint)
self._max_models_per_worker_hint = None
batch_by_key = defaultdict(list)
key_by_id = defaultdict(set)
for key, example in batch:
batch_by_key[key].append(example)
key_by_id[self._key_to_id_map[key]].add(key)
predictions = []
for id, keys in key_by_id.items():
mh = self._id_to_mh_map[id]
loaded_model = model.load(id)
keyed_model_tag = loaded_model.model_tag
if loaded_model.byte_size is not None:
self._metrics_collectors[id].update_load_model_metrics(
loaded_model.load_latency, loaded_model.byte_size)
self._default_metrics_collector.update_load_model_metrics(
loaded_model.load_latency, loaded_model.byte_size)
keyed_model_shared_handle = multi_process_shared.MultiProcessShared(
mh.load_model, tag=keyed_model_tag)
keyed_model = keyed_model_shared_handle.acquire()
start_time = _to_microseconds(time.time_ns())
num_bytes = 0
num_elements = 0
try:
for key in keys:
unkeyed_batches = batch_by_key[key]
try:
for inf in mh.run_inference(unkeyed_batches,
keyed_model,
inference_args):
predictions.append((key, inf))
except BaseException as e:
self._metrics_collectors[id].failed_batches_counter.inc()
self._default_metrics_collector.failed_batches_counter.inc()
raise e
num_bytes += mh.get_num_bytes(unkeyed_batches)
num_elements += len(unkeyed_batches)
finally:
keyed_model_shared_handle.release(keyed_model)
end_time = _to_microseconds(time.time_ns())
inference_latency = end_time - start_time
self._metrics_collectors[id].update(
num_elements, num_bytes, inference_latency)
self._default_metrics_collector.update(
num_elements, num_bytes, inference_latency)
return predictions
def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
keys, unkeyed_batch = zip(*batch)
batch_bytes = len(pickle.dumps(keys))
if self._single_model:
return batch_bytes + self._unkeyed.get_num_bytes(unkeyed_batch)
batch_by_key = defaultdict(list)
for key, examples in batch:
batch_by_key[key].append(examples)
for key, examples in batch_by_key.items():
mh_id = self._key_to_id_map[key]
batch_bytes += self._id_to_mh_map[mh_id].get_num_bytes(examples)
return batch_bytes
def get_metrics_namespace(self) -> str:
if self._single_model:
return self._unkeyed.get_metrics_namespace()
return 'BeamML_KeyedModels'
def get_resource_hints(self):
if self._single_model:
return self._unkeyed.get_resource_hints()
return {}
def batch_elements_kwargs(self):
if self._single_model:
return self._unkeyed.batch_elements_kwargs()
return {}
def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
if self._single_model:
return self._unkeyed.validate_inference_args(inference_args)
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
def update_model_paths(
self,
model: Union[ModelT, _ModelManager],
model_paths: List[KeyModelPathMapping[KeyT]] = None):
# When there are many models, the keyed model handler is responsible for
# reorganizing the model handlers into cohorts and telling the model
# manager to update every cohort's associated model handler. The model
# manager is responsible for performing the updates and tracking which
# updates have already been applied.
if model_paths is None or len(model_paths) == 0 or model is None:
return
if self._single_model:
raise RuntimeError(
'Invalid model update: sent many model paths to '
'update, but KeyedModelHandler is wrapping a single '
'model.')
# Map cohort ids to a dictionary mapping new model paths to the keys that
# were originally in that cohort. We will use this to construct our new
# cohorts.
# cohort_path_mapping will be structured as follows:
# {
# original_cohort_id: {
# 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
# 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
# }
# }
cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
key_modelid_mapping: Dict[KeyT, str] = {}
seen_keys = set()
for mp in model_paths:
keys = mp.keys
update_path = mp.update_path
model_id = mp.model_id
if len(update_path) == 0:
raise ValueError(f'Invalid model update, path for {keys} is empty')
for key in keys:
if key in seen_keys:
raise ValueError(
f'Invalid model update: {key} appears in multiple '
'update lists. A single model update must provide exactly one '
'updated path per key.')
seen_keys.add(key)
if key not in self._key_to_id_map:
raise ValueError(
f'Invalid model update: {key} appears in '
'update, but not in the original configuration.')
key_modelid_mapping[key] = model_id
cohort_id = self._key_to_id_map[key]
if cohort_id not in cohort_path_mapping:
cohort_path_mapping[cohort_id] = defaultdict(list)
cohort_path_mapping[cohort_id][update_path].append(key)
for key in self._key_to_id_map:
if key not in seen_keys:
raise ValueError(
f'Invalid model update: {key} appears in the '
'original configuration, but not the update.')
# We now have our new set of cohorts. For each one, update our local model
# handler configuration and send the results to the ModelManager
for old_cohort_id, path_key_mapping in cohort_path_mapping.items():
for updated_path, keys in path_key_mapping.items():
cohort_id = old_cohort_id
if old_cohort_id not in keys:
# Create new cohort
cohort_id = keys[0]
for key in keys:
self._key_to_id_map[key] = cohort_id
mh = self._id_to_mh_map[old_cohort_id]
self._id_to_mh_map[cohort_id] = deepcopy(mh)
self._id_to_mh_map[cohort_id].update_model_path(updated_path)
model.update_model_handler(cohort_id, updated_path, old_cohort_id)
model_id = key_modelid_mapping[cohort_id]
self._metrics_collectors[cohort_id] = _MetricsCollector(
self._metrics_namespace, f'{cohort_id}-{model_id}-')
def update_model_path(self, model_path: Optional[str] = None):
  """Forwards a model path update to the wrapped handler when possible.

  If this handler wraps a single unkeyed ModelHandler, the update is
  delegated to it and that call's result is returned. Otherwise a ``None``
  path is accepted as a no-op and any other path is rejected.

  Args:
    model_path: The new model path, or None for no update.

  Raises:
    RuntimeError: if ``model_path`` is not None while multiple per-key
      ModelHandlers are configured.
  """
  if not self._single_model:
    if model_path is None:
      return None
    raise RuntimeError(
        'Model updates are currently not supported for '
        'KeyedModelHandlers with multiple different per-key '
        'ModelHandlers.')
  return self._unkeyed.update_model_path(model_path=model_path)
def share_model_across_processes(self) -> bool:
  """Reports whether the loaded model should be shared across processes.

  A single wrapped handler answers for itself; when multiple per-key
  handlers are configured, the answer is always True.
  """
  return (
      self._unkeyed.share_model_across_processes()
      if self._single_model else True)
def override_metrics(self, metrics_namespace: str = '') -> bool:
  """Configures metrics collectors for ``metrics_namespace``.

  With a single wrapped handler, the call is delegated and that handler's
  boolean result is returned. Otherwise this records the namespace,
  rebuilds the default collector, installs one collector per known cohort
  id (with prefix ``'<cohort_id>-'``), and returns True.
  """
  if not self._single_model:
    namespace = metrics_namespace
    self._metrics_namespace = namespace
    self._default_metrics_collector = _MetricsCollector(namespace)
    per_cohort = self._metrics_collectors
    for cohort in self._id_to_mh_map:
      per_cohort[cohort] = _MetricsCollector(namespace, f'{cohort}-')
    return True
  return self._unkeyed.override_metrics(metrics_namespace)
class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                             ModelHandler[Union[ExampleT, Tuple[KeyT,
                                                                ExampleT]],
                                          Union[PredictionT,
                                                Tuple[KeyT, PredictionT]],
                                          ModelT]):
  # run_inference and get_num_bytes strip keys from keyed batches before
  # calling the wrapped handler. Every other method forwards directly to it.
  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
    """A ModelHandler that takes examples that might have keys and returns
    predictions that might have keys.

    For example, if the original model is used with RunInference to take a
    PCollection[E] to a PCollection[P], this ModelHandler would take either
    PCollection[E] to a PCollection[P] or PCollection[Tuple[K, E]] to a
    PCollection[Tuple[K, P]], depending on whether the elements are
    tuples. This pattern makes it possible to associate the outputs with the
    inputs based on the key.

    Note that you cannot use this ModelHandler if E is a tuple type.
    In addition, either all examples should be keyed, or none of them.

    Args:
      unkeyed: An implementation of ModelHandler that does not require keys.

    Raises:
      ValueError: if ``unkeyed`` already has pre- or postprocessing
        functions attached; those must be attached to this outer handler.
    """
    # list() so that handlers returning a non-sized iterable are accepted.
    if (list(unkeyed.get_preprocess_fns()) or
        list(unkeyed.get_postprocess_fns())):
      raise ValueError(
          'Cannot make an unkeyed model handler with pre or '
          'postprocessing functions defined into a keyed model handler. All '
          'pre/postprocessing functions must be defined on the outer model '
          'handler.')
    self._unkeyed = unkeyed
    # The wrapped handler may not define _env_vars (presumably a subclass
    # that skipped the base __init__); fall back to no environment variables.
    self._env_vars = getattr(unkeyed, '_env_vars', {})

  @staticmethod
  def _is_keyed_batch(batch: Sequence[Any]) -> bool:
    """Returns whether ``batch`` holds (key, example) tuples.

    Only the first element is inspected, so a batch must not mix keyed and
    unkeyed examples. An empty batch is treated as unkeyed instead of
    raising IndexError.
    """
    # len() rather than truthiness: truthiness is ambiguous for some
    # array-like sequences.
    return len(batch) > 0 and isinstance(batch[0], tuple)

  def load_model(self) -> ModelT:
    return self._unkeyed.load_model()

  def run_inference(
      self,
      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
    """Runs inference on the examples, re-attaching keys for keyed batches.

    Returns:
      The wrapped handler's results for an unkeyed batch, or an iterable of
      (key, prediction) pairs built by zipping the keys with those results.
    """
    # Really the input should be
    # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
    # but there's not a good way to express (or check) that.
    is_keyed = self._is_keyed_batch(batch)
    if is_keyed:
      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
    else:
      unkeyed_batch = batch  # type: ignore[assignment]
    unkeyed_results = self._unkeyed.run_inference(
        unkeyed_batch, model, inference_args)
    if is_keyed:
      return zip(keys, unkeyed_results)
    return unkeyed_results

  def get_num_bytes(
      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
    """Returns the size of ``batch``; keyed batches add their pickled keys."""
    # MyPy can't follow the branching logic.
    if self._is_keyed_batch(batch):
      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
      return len(
          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
    return self._unkeyed.get_num_bytes(batch)  # type: ignore[arg-type]

  def get_metrics_namespace(self) -> str:
    return self._unkeyed.get_metrics_namespace()

  def get_resource_hints(self):
    return self._unkeyed.get_resource_hints()

  def batch_elements_kwargs(self):
    return self._unkeyed.batch_elements_kwargs()

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    return self._unkeyed.validate_inference_args(inference_args)

  def update_model_path(self, model_path: Optional[str] = None):
    return self._unkeyed.update_model_path(model_path=model_path)

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._unkeyed.get_preprocess_fns()

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._unkeyed.get_postprocess_fns()

  def share_model_across_processes(self) -> bool:
    return self._unkeyed.share_model_across_processes()

  def override_metrics(self, metrics_namespace: str = '') -> bool:
    """Forwards to the wrapped handler so an override it implements is kept.

    Without this, the inherited default would be used regardless of what the
    wrapped handler does.
    """
    return self._unkeyed.override_metrics(metrics_namespace)
class _PreProcessingModelHandler(Generic[ExampleT,
                                         PredictionT,
                                         ModelT,
                                         PreProcessT],
                                 ModelHandler[PreProcessT, PredictionT,
                                              ModelT]):
  # Every method other than get_preprocess_fns forwards to the base handler,
  # including share_model_across_processes and override_metrics, so that
  # wrapping a handler does not silently reset those choices.
  def __init__(
      self,
      base: ModelHandler[ExampleT, PredictionT, ModelT],
      preprocess_fn: Callable[[PreProcessT], ExampleT]):
    """A ModelHandler that has a preprocessing function associated with it.

    Args:
      base: An implementation of the underlying model handler.
      preprocess_fn: the preprocessing function to use.
    """
    self._base = base
    # base may not define _env_vars (presumably a subclass that skipped the
    # base __init__); fall back to no environment variables.
    self._env_vars = getattr(base, '_env_vars', {})
    self._preprocess_fn = preprocess_fn

  def load_model(self) -> ModelT:
    return self._base.load_model()

  def run_inference(
      self,
      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
    return self._base.run_inference(batch, model, inference_args)

  def get_num_bytes(
      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
    return self._base.get_num_bytes(batch)

  def get_metrics_namespace(self) -> str:
    return self._base.get_metrics_namespace()

  def get_resource_hints(self):
    return self._base.get_resource_hints()

  def batch_elements_kwargs(self):
    return self._base.batch_elements_kwargs()

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    return self._base.validate_inference_args(inference_args)

  def update_model_path(self, model_path: Optional[str] = None):
    return self._base.update_model_path(model_path=model_path)

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    """Returns this wrapper's function followed by the base's functions."""
    # Unpacking accepts any iterable from the base, not only a list.
    return [self._preprocess_fn, *self._base.get_preprocess_fns()]

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._base.get_postprocess_fns()

  def share_model_across_processes(self) -> bool:
    return self._base.share_model_across_processes()

  def override_metrics(self, metrics_namespace: str = '') -> bool:
    return self._base.override_metrics(metrics_namespace)
class _PostProcessingModelHandler(Generic[ExampleT,
                                          PredictionT,
                                          ModelT,
                                          PostProcessT],
                                  ModelHandler[ExampleT, PostProcessT, ModelT]):
  # Every method other than get_postprocess_fns forwards to the base handler,
  # including share_model_across_processes and override_metrics, so that
  # wrapping a handler does not silently reset those choices.
  def __init__(
      self,
      base: ModelHandler[ExampleT, PredictionT, ModelT],
      postprocess_fn: Callable[[PredictionT], PostProcessT]):
    """A ModelHandler that has a postprocessing function associated with it.

    Args:
      base: An implementation of the underlying model handler.
      postprocess_fn: the postprocessing function to use.
    """
    self._base = base
    # base may not define _env_vars (presumably a subclass that skipped the
    # base __init__); fall back to no environment variables.
    self._env_vars = getattr(base, '_env_vars', {})
    self._postprocess_fn = postprocess_fn

  def load_model(self) -> ModelT:
    return self._base.load_model()

  def run_inference(
      self,
      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
    return self._base.run_inference(batch, model, inference_args)

  def get_num_bytes(
      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
    return self._base.get_num_bytes(batch)

  def get_metrics_namespace(self) -> str:
    return self._base.get_metrics_namespace()

  def get_resource_hints(self):
    return self._base.get_resource_hints()

  def batch_elements_kwargs(self):
    return self._base.batch_elements_kwargs()

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    return self._base.validate_inference_args(inference_args)

  def update_model_path(self, model_path: Optional[str] = None):
    return self._base.update_model_path(model_path=model_path)

  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    return self._base.get_preprocess_fns()

  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
    """Returns the base's functions followed by this wrapper's function."""
    # Unpacking accepts any iterable from the base, not only a list.
    return [*self._base.get_postprocess_fns(), self._postprocess_fn]

  def share_model_across_processes(self) -> bool:
    return self._base.share_model_across_processes()

  def override_metrics(self, metrics_namespace: str = '') -> bool:
    return self._base.override_metrics(metrics_namespace)
class RunInference(beam.PTransform[beam.PCollection[ExampleT],
beam.PCollection[PredictionT]]):
def __init__(
self,
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock=time,
inference_args: Optional[Dict[str, Any]] = None,
metrics_namespace: Optional[str] = None,
*,
model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
watch_model_pattern: Optional[str] = None,
**kwargs):
"""
A transform that takes a PCollection of examples (or features) for use