Add ability to load multiple copies of a model across processes (#31052)
* Add ability to load multiple copies of a model across processes

* push changes I had locally not remotely

* Lint

* naming + lint

* Changes from feedback
damccorm committed Apr 25, 2024
1 parent 6384a4b commit 567f7a0
Showing 8 changed files with 283 additions and 30 deletions.
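For context on what this commit adds: a model handler opts in by returning True from share_model_across_processes() and a copy count from model_copies(); RunInference then keeps that many copies loaded via multi_process_shared and round-robins batches across them. Below is a minimal usage sketch, not part of this commit; TinyModel and TinyHandler are illustrative stand-ins for a real handler.

import apache_beam as beam
from apache_beam.ml.inference import base


class TinyModel:
  """Illustrative stand-in for a real (e.g. GPU-backed) model."""
  def predict(self, example: int) -> int:
    return example + 1


class TinyHandler(base.ModelHandler[int, int, TinyModel]):
  def load_model(self) -> TinyModel:
    return TinyModel()

  def run_inference(self, batch, model, inference_args=None):
    return [model.predict(example) for example in batch]

  def share_model_across_processes(self) -> bool:
    # Hold the model in a multi_process_shared location on each worker.
    return True

  def model_copies(self) -> int:
    # Keep up to 4 copies of the model loaded at once.
    return 4


with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([1, 2, 3, 4])
      | base.RunInference(TinyHandler())
      | beam.Map(print))

The default model_copies() of 1 preserves the previous behavior, so existing handlers are unaffected.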
110 changes: 102 additions & 8 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -315,6 +315,13 @@ def share_model_across_processes(self) -> bool:
https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
return False

def model_copies(self) -> int:
"""Returns the maximum number of model copies that should be loaded at one
time. This only impacts model handlers that are using
share_model_across_processes to share their model across processes instead
of being loaded per process."""
return 1

def override_metrics(self, metrics_namespace: str = '') -> bool:
"""Returns a boolean representing whether or not a model handler will
override metrics reporting. If True, RunInference will not report any
@@ -795,6 +802,21 @@ def share_model_across_processes(self) -> bool:
return self._unkeyed.share_model_across_processes()
return True

def model_copies(self) -> int:
if self._single_model:
return self._unkeyed.model_copies()
for mh in self._id_to_mh_map.values():
if mh.model_copies() != 1:
raise ValueError(
'KeyedModelHandler cannot map records to multiple '
'models if one or more of its ModelHandlers '
'require multiple model copies (set via '
'model_copies). To fix, verify that each '
'ModelHandler is not set to load multiple copies of '
'its model.')

return 1

def override_metrics(self, metrics_namespace: str = '') -> bool:
if self._single_model:
return self._unkeyed.override_metrics(metrics_namespace)
@@ -902,6 +924,9 @@ def should_skip_batching(self) -> bool:
def share_model_across_processes(self) -> bool:
return self._unkeyed.share_model_across_processes()

def model_copies(self) -> int:
return self._unkeyed.model_copies()


class _PrebatchedModelHandler(Generic[ExampleT, PredictionT, ModelT],
ModelHandler[Sequence[ExampleT],
@@ -952,6 +977,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
def should_skip_batching(self) -> bool:
return True

def share_model_across_processes(self) -> bool:
return self._base.share_model_across_processes()

def model_copies(self) -> int:
return self._base.model_copies()

def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
return self._base.get_postprocess_fns()

@@ -1012,6 +1043,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
def should_skip_batching(self) -> bool:
return self._base.should_skip_batching()

def share_model_across_processes(self) -> bool:
return self._base.share_model_across_processes()

def model_copies(self) -> int:
return self._base.model_copies()

def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
return self._base.get_postprocess_fns()

@@ -1071,6 +1108,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
def should_skip_batching(self) -> bool:
return self._base.should_skip_batching()

def share_model_across_processes(self) -> bool:
return self._base.share_model_across_processes()

def model_copies(self) -> int:
return self._base.model_copies()

def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
return self._base.get_postprocess_fns() + [self._postprocess_fn]

@@ -1378,6 +1421,45 @@ def update(
self._inference_request_batch_byte_size.update(examples_byte_size)


class _ModelRoutingStrategy():
"""A class meant to sit in a shared location for mapping incoming batches to
different models. Currently only supports round-robin, but can be extended
to support other protocols if needed.
"""
def __init__(self):
self._cur_index = 0

def next_model_index(self, num_models):
self._cur_index = (self._cur_index + 1) % num_models
return self._cur_index


class _SharedModelWrapper():
"""A router class to map incoming calls to the correct model.
This allows us to round robin calls to models sitting in different
processes so that we can more efficiently use resources (e.g. GPUs).
"""
def __init__(self, models: List[Any], model_tag: str):
self.models = models
if len(models) > 1:
self.model_router = multi_process_shared.MultiProcessShared(
lambda: _ModelRoutingStrategy(),
tag=f'{model_tag}_counter',
always_proxy=True).acquire()

def next_model(self):
if len(self.models) == 1:
# Short circuit if there's no routing strategy needed in order to
# avoid the cross-process call
return self.models[0]

return self.models[self.model_router.next_model_index(len(self.models))]

def all_models(self):
return self.models


class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
def __init__(
self,
@@ -1408,16 +1490,19 @@ def __init__(
def _load_model(
self,
side_input_model_path: Optional[Union[str,
List[KeyModelPathMapping]]] = None):
List[KeyModelPathMapping]]] = None
) -> _SharedModelWrapper:
def load():
"""Function for constructing shared LoadedModel."""
memory_before = _get_current_process_memory_in_bytes()
start_time = _to_milliseconds(self._clock.time_ns())
if isinstance(side_input_model_path, str):
self._model_handler.update_model_path(side_input_model_path)
else:
self._model_handler.update_model_paths(
self._model, side_input_model_path)
if self._model is not None:
models = self._model.all_models()
for m in models:
self._model_handler.update_model_paths(m, side_input_model_path)
model = self._model_handler.load_model()
end_time = _to_milliseconds(self._clock.time_ns())
memory_after = _get_current_process_memory_in_bytes()
@@ -1434,19 +1519,27 @@ def load():
if isinstance(side_input_model_path, str) and side_input_model_path != '':
model_tag = side_input_model_path
if self._model_handler.share_model_across_processes():
model = multi_process_shared.MultiProcessShared(
load, tag=model_tag, always_proxy=True).acquire()
models = []
for i in range(self._model_handler.model_copies()):
models.append(
multi_process_shared.MultiProcessShared(
load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
model_wrapper = _SharedModelWrapper(models, model_tag)
else:
model = self._shared_model_handle.acquire(load, tag=model_tag)
model_wrapper = _SharedModelWrapper([model], model_tag)
# Since shared_model_handle is shared across threads, the model path
# might not get updated in the model handler, because we may get the
# cached weak-ref model directly from the shared cache instead of
# calling load(). As a sanity check, call update_model_path again.
if isinstance(side_input_model_path, str):
self._model_handler.update_model_path(side_input_model_path)
else:
self._model_handler.update_model_paths(self._model, side_input_model_path)
return model
if self._model is not None:
models = self._model.all_models()
for m in models:
self._model_handler.update_model_paths(m, side_input_model_path)
return model_wrapper

def get_metrics_collector(self, prefix: str = ''):
"""
@@ -1476,8 +1569,9 @@ def update_model(
def _run_inference(self, batch, inference_args):
start_time = _to_microseconds(self._clock.time_ns())
try:
model = self._model.next_model()
result_generator = self._model_handler.run_inference(
batch, self._model, inference_args)
batch, model, inference_args)
except BaseException as e:
if self._metrics_collector:
self._metrics_collector.failed_batches_counter.inc()
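Taken together, the base.py changes work like this: when share_model_across_processes() is true, _load_model() acquires model_copies() separate MultiProcessShared instances, each under its own tag (the model tag plus the copy index), wraps them in a _SharedModelWrapper, and _run_inference() asks the wrapper for the next copy on every batch. A standalone sketch of just the round-robin routing follows; it runs outside Beam, drops the cross-process proxying, and uses illustrative names rather than the committed classes.

from typing import Any, List


class RoundRobin:
  """Mirrors the cycling logic of _ModelRoutingStrategy."""
  def __init__(self):
    self._cur_index = 0

  def next_model_index(self, num_models: int) -> int:
    # Advance first, then return, so successive calls cycle 1, 2, ..., 0, 1, ...
    self._cur_index = (self._cur_index + 1) % num_models
    return self._cur_index


class LocalModelWrapper:
  """Simplified, single-process analogue of _SharedModelWrapper."""
  def __init__(self, models: List[Any]):
    self._models = models
    self._router = RoundRobin()

  def next_model(self) -> Any:
    if len(self._models) == 1:
      # Skip routing entirely when there is only one copy.
      return self._models[0]
    return self._models[self._router.next_model_index(len(self._models))]


wrapper = LocalModelWrapper(['copy_a', 'copy_b', 'copy_c'])
print([wrapper.next_model() for _ in range(6)])
# ['copy_b', 'copy_c', 'copy_a', 'copy_b', 'copy_c', 'copy_a']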
71 changes: 71 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -63,6 +63,15 @@ def increment_state(self, amount: int):
self._state += amount


class FakeIncrementingModel:
def __init__(self):
self._state = 0

def predict(self, example: int) -> int:
self._state += 1
return self._state


class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def __init__(
self,
@@ -71,6 +80,8 @@ def __init__(
max_batch_size=9999,
multi_process_shared=False,
state=None,
incrementing=False,
max_copies=1,
num_bytes_per_element=None,
**kwargs):
self._fake_clock = clock
@@ -79,11 +90,16 @@ def __init__(
self._env_vars = kwargs.get('env_vars', {})
self._multi_process_shared = multi_process_shared
self._state = state
self._incrementing = incrementing
self._max_copies = max_copies
self._num_bytes_per_element = num_bytes_per_element

def load_model(self):
assert (not self._incrementing or self._state is None)
if self._fake_clock:
self._fake_clock.current_time_ns += 500_000_000 # 500ms
if self._incrementing:
return FakeIncrementingModel()
if self._state is not None:
return FakeStatefulModel(self._state)
return FakeModel()
@@ -116,6 +132,9 @@ def batch_elements_kwargs(self):
def share_model_across_processes(self):
return self._multi_process_shared

def model_copies(self):
return self._max_copies

def get_num_bytes(self, batch: Sequence[int]) -> int:
if self._num_bytes_per_element:
return self._num_bytes_per_element * len(batch)
@@ -258,6 +277,58 @@ def test_run_inference_impl_simple_examples_multi_process_shared(self):
FakeModelHandler(multi_process_shared=True))
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_run_inference_impl_simple_examples_multi_process_shared_multi_copy(
self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
FakeModelHandler(multi_process_shared=True, max_copies=4))
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_run_inference_impl_multi_process_shared_incrementing_multi_copy(
self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
FakeModelHandler(
multi_process_shared=True,
max_copies=4,
incrementing=True,
max_batch_size=1))
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_run_inference_impl_mps_nobatch_incrementing_multi_copy(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
batched_examples = [[example] for example in examples]
pcoll = pipeline | 'start' >> beam.Create(batched_examples)
actual = pcoll | base.RunInference(
FakeModelHandler(
multi_process_shared=True, max_copies=4,
incrementing=True).with_no_batching())
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_run_inference_impl_keyed_mps_incrementing_multi_copy(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
keyed_examples = [('abc', example) for example in examples]
expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
keyed_expected = [('abc', val) for val in expected]
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
actual = pcoll | base.RunInference(
base.KeyedModelHandler(
FakeModelHandler(
multi_process_shared=True,
max_copies=4,
incrementing=True,
max_batch_size=1)))
assert_that(actual, equal_to(keyed_expected), label='assert:inferences')

def test_run_inference_impl_with_keyed_examples(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
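The multi-copy tests above rely on each of the four copies holding its own state: MultiProcessShared deduplicates by tag, so acquiring the same tag returns the same shared object, while the distinct per-copy tags generated in _load_model yield distinct objects. The sketch below illustrates that behavior in isolation; it assumes the documented semantics of apache_beam.utils.multi_process_shared, and the 'demo' tags are arbitrary.

from apache_beam.utils import multi_process_shared


class Counter:
  def __init__(self):
    self._n = 0

  def bump(self) -> int:
    self._n += 1
    return self._n


# Same tag -> same underlying object, so state is shared.
a = multi_process_shared.MultiProcessShared(
    Counter, tag='demo', always_proxy=True).acquire()
b = multi_process_shared.MultiProcessShared(
    Counter, tag='demo', always_proxy=True).acquire()
print(a.bump(), b.bump())  # 1 2

# Distinct tags -> distinct objects, which is how multiple model copies coexist.
c = multi_process_shared.MultiProcessShared(
    Counter, tag='demo0', always_proxy=True).acquire()
d = multi_process_shared.MultiProcessShared(
    Counter, tag='demo1', always_proxy=True).acquire()
print(c.bump(), d.bump())  # 1 1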