Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for loading torchscript models #25321

Merged
merged 15 commits into from
Feb 11, 2023
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
* Add UDF metrics support for Samza portable mode.
* Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)).
This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`.
* Add support for loading TorchScript models with `PytorchModelHandler`. The TorchScript model path can be
passed to PytorchModelHandler using `torch_script_model_path=<path_to_model>`. ([#25321](https://github.com/apache/beam/pull/25321))

## Breaking Changes

Expand Down
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def update_model_path(self, model_path: Optional[str] = None):
"""Update the model paths produced by side inputs."""
pass

def validate_constructor_args(self):
damccorm marked this conversation as resolved.
Show resolved Hide resolved
"""
Validate arguments passed to the ModelHandler constructor.
"""
raise NotImplementedError


class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Tuple[KeyT, ExampleT],
Expand Down
147 changes: 117 additions & 30 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,66 @@
Iterable[PredictionResult]]


def _load_model(
model_class: torch.nn.Module, state_dict_path, device, **model_params):
model = model_class(**model_params)
def _validate_constructor_args(
state_dict_path, model_class, torch_script_model_path):
message = (
"A {param1} has been supplied to the model "
"handler, but the required {param2} is missing. "
"Please provide the {param2} in order to "
"successfully load the {param1}.")
# state_dict_path and model_class are coupled with each other
# raise RuntimeError if user forgets to pass any one of them.
if state_dict_path and not model_class:
raise RuntimeError(
message.format(param1="state_dict_path", param2="model_class"))

if not state_dict_path and model_class:
raise RuntimeError(
message.format(param1="model_class", param2="state_dict_path"))

if torch_script_model_path and state_dict_path:
raise RuntimeError(
"Please specify either torch_script_model_path or "
"(state_dict_path, model_class) to successfully load the model.")


def _load_model(
model_class: Optional[Callable[..., torch.nn.Module]],
state_dict_path: Optional[str],
device: torch.device,
model_params: Optional[Dict[str, Any]],
torch_script_model_path: Optional[str]):
if device == torch.device('cuda') and not torch.cuda.is_available():
logging.warning(
"Model handler specified a 'GPU' device, but GPUs are not available. " \
"Model handler specified a 'GPU' device, but GPUs are not available. "
"Switching to CPU.")
device = torch.device('cpu')

file = FileSystems.open(state_dict_path, 'rb')
try:
logging.info(
"Loading state_dict_path %s onto a %s device", state_dict_path, device)
state_dict = torch.load(file, map_location=device)
if not torch_script_model_path:
file = FileSystems.open(state_dict_path, 'rb')
model = model_class(**model_params)
state_dict = torch.load(file, map_location=device)
model.load_state_dict(state_dict)
else:
file = FileSystems.open(torch_script_model_path, 'rb')
model = torch.jit.load(file, map_location=device)
except RuntimeError as e:
if device == torch.device('cuda'):
message = "Loading the model onto a GPU device failed due to an " \
f"exception:\n{e}\nAttempting to load onto a CPU device instead."
logging.warning(message)
return _load_model(
model_class, state_dict_path, torch.device('cpu'), **model_params)
model_class,
state_dict_path,
torch.device('cpu'),
model_params,
torch_script_model_path)
else:
raise e

model.load_state_dict(state_dict)
model.to(device)
model.eval()
logging.info("Finished loading PyTorch model.")
Expand Down Expand Up @@ -148,19 +182,23 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
torch.nn.Module]):
def __init__(
self,
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
state_dict_path: Optional[str] = None,
model_class: Optional[Callable[..., torch.nn.Module]] = None,
model_params: Optional[Dict[str, Any]] = None,
device: str = 'CPU',
*,
inference_fn: TensorInferenceFn = default_tensor_inference_fn,
torch_script_model_path: Optional[str] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None):
"""Implementation of the ModelHandler interface for PyTorch.

Example Usage::

pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri"))
Example Usage for torch model::
pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri",
model_class="my_class"))
Example Usage for torchscript model::
pcoll | RunInference(PytorchModelHandlerTensor(
torch_script_model_path="my_uri"))

See https://pytorch.org/tutorials/beginner/saving_loading_models.html
for details
Expand All @@ -176,6 +214,10 @@ def __init__(
Otherwise, it will be CPU.
inference_fn: the inference function to use during RunInference.
default=_default_tensor_inference_fn
torch_script_model_path: Path to the torch script model.
the model will be loaded using `torch.jit.load()`.
`state_dict_path`, `model_class` and `model_params`
arguments will be disregarded.

**Supported Versions:** RunInference APIs in Apache Beam have been tested
with PyTorch 1.9 and 1.10.
Expand All @@ -188,26 +230,42 @@ def __init__(
logging.info("Device is set to CPU")
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params
self._model_params = model_params if model_params else {}
self._inference_fn = inference_fn
self._batching_kwargs = {}
if min_batch_size is not None:
self._batching_kwargs['min_batch_size'] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
self._torch_script_model_path = torch_script_model_path

self.validate_constructor_args()

def validate_constructor_args(self):
_validate_constructor_args(
state_dict_path=self._state_dict_path,
model_class=self._model_class,
torch_script_model_path=self._torch_script_model_path)

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model, device = _load_model(
self._model_class,
self._state_dict_path,
self._device,
**self._model_params)
self._model_params,
self._torch_script_model_path
)
self._device = device
return model

def update_model_path(self, model_path: Optional[str] = None):
self._state_dict_path = model_path if model_path else self._state_dict_path
if self._torch_script_model_path:
self._torch_script_model_path = (
model_path if model_path else self._torch_script_model_path)
else:
self._state_dict_path = (
model_path if model_path else self._state_dict_path)

def run_inference(
self,
Expand Down Expand Up @@ -236,9 +294,11 @@ def run_inference(
An Iterable of type PredictionResult.
"""
inference_args = {} if not inference_args else inference_args

model_id = (
self._state_dict_path
if not self._torch_script_model_path else self._torch_script_model_path)
return self._inference_fn(
batch, model, self._device, inference_args, self._state_dict_path)
batch, model, self._device, inference_args, model_id)

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
Expand Down Expand Up @@ -332,20 +392,25 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
torch.nn.Module]):
def __init__(
self,
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
state_dict_path: Optional[str] = None,
model_class: Optional[Callable[..., torch.nn.Module]] = None,
model_params: Optional[Dict[str, Any]] = None,
device: str = 'CPU',
*,
inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
torch_script_model_path: Optional[str] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None):
"""Implementation of the ModelHandler interface for PyTorch.

Example Usage::
Example Usage for torch model::
pcoll | RunInference(PytorchModelHandlerKeyedTensor(
state_dict_path="my_uri",
model_class="my_class"))

pcoll | RunInference(
PytorchModelHandlerKeyedTensor(state_dict_path="my_uri"))
Example Usage for torchscript model::
pcoll | RunInference(PytorchModelHandlerKeyedTensor(
torch_script_model_path="my_uri"))

**NOTE:** This API and its implementation are under development and
do not provide backward compatibility guarantees.
Expand All @@ -364,6 +429,10 @@ def __init__(
Otherwise, it will be CPU.
inference_fn: the function to invoke on run_inference.
default = default_keyed_tensor_inference_fn
torch_script_model_path: Path to the torch script model.
the model will be loaded using `torch.jit.load()`.
`state_dict_path`, `model_class` and `model_params`
arguments will be disregarded..

**Supported Versions:** RunInference APIs in Apache Beam have been tested
on torch>=1.9.0,<1.14.0.
Expand All @@ -376,26 +445,42 @@ def __init__(
logging.info("Device is set to CPU")
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params
self._model_params = model_params if model_params else {}
self._inference_fn = inference_fn
self._batching_kwargs = {}
if min_batch_size is not None:
self._batching_kwargs['min_batch_size'] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
self._torch_script_model_path = torch_script_model_path

self.validate_constructor_args()

def validate_constructor_args(self):
damccorm marked this conversation as resolved.
Show resolved Hide resolved
_validate_constructor_args(
state_dict_path=self._state_dict_path,
model_class=self._model_class,
torch_script_model_path=self._torch_script_model_path)

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model, device = _load_model(
self._model_class,
self._state_dict_path,
self._device,
**self._model_params)
self._model_params,
self._torch_script_model_path
)
self._device = device
return model

def update_model_path(self, model_path: Optional[str] = None):
self._state_dict_path = model_path if model_path else self._state_dict_path
if self._torch_script_model_path:
self._torch_script_model_path = (
model_path if model_path else self._torch_script_model_path)
else:
self._state_dict_path = (
model_path if model_path else self._state_dict_path)

def run_inference(
self,
Expand Down Expand Up @@ -424,9 +509,11 @@ def run_inference(
An Iterable of type PredictionResult.
"""
inference_args = {} if not inference_args else inference_args

model_id = (
self._state_dict_path
if not self._torch_script_model_path else self._torch_script_model_path)
return self._inference_fn(
batch, model, self._device, inference_args, self._state_dict_path)
batch, model, self._device, inference_args, model_id)

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
Expand Down
Loading