From e371f37137f37bbc80e3aa5adfebdc52fcea4d54 Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Sun, 5 Feb 2023 23:12:25 -0500
Subject: [PATCH 01/13] Add support for loading torchscript models

---
 .../ml/inference/pytorch_inference.py | 75 +++++++++++++++----
 1 file changed, 61 insertions(+), 14 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 3366d523076f..87bf59956916 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -58,9 +58,11 @@ def _load_model(
-    model_class: torch.nn.Module, state_dict_path, device, **model_params):
-  model = model_class(**model_params)
-
+    model_class: torch.nn.Module,
+    state_dict_path,
+    device,
+    model_params,
+    use_torch_script_format=False):
   if device == torch.device('cuda') and not torch.cuda.is_available():
     logging.warning(
         "Model handler specified a 'GPU' device, but GPUs are not available. " \
         "Switching to CPU.")
     device = torch.device('cpu')
@@ -71,18 +73,26 @@ def _load_model(
   try:
     logging.info(
         "Loading state_dict_path %s onto a %s device", state_dict_path, device)
-    state_dict = torch.load(file, map_location=device)
+    if not use_torch_script_format:
+      model = model_class(**model_params)
+      state_dict = torch.load(file, map_location=device)
+      model.load_state_dict(state_dict)
+    else:
+      model = torch.jit.load(file, map_location=device)
   except RuntimeError as e:
     if device == torch.device('cuda'):
       message = "Loading the model onto a GPU device failed due to an " \
        f"exception:\n{e}\nAttempting to load onto a CPU device instead."
       logging.warning(message)
       return _load_model(
-          model_class, state_dict_path, torch.device('cpu'), **model_params)
+          model_class,
+          state_dict_path,
+          torch.device('cpu'),
+          model_params,
+          use_torch_script_format)
     else:
       raise e

-  model.load_state_dict(state_dict)
   model.to(device)
   model.eval()
   logging.info("Finished loading PyTorch model.")
@@ -149,11 +159,13 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
   def __init__(
       self,
       state_dict_path: str,
-      model_class: Callable[..., torch.nn.Module],
-      model_params: Dict[str, Any],
+      model_class: Optional[Callable[..., torch.nn.Module]] = None,
+      model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
-      inference_fn: TensorInferenceFn = default_tensor_inference_fn):
+      inference_fn: TensorInferenceFn = default_tensor_inference_fn,
+      use_torch_script_format=False,
+  ):
     """Implementation of the ModelHandler interface for PyTorch.

     Example Usage::
@@ -174,6 +186,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
+      use_torch_script_format: When `use_torch_script_format` is set to `True`,
+        the model will be loaded using `torch.jit.load()`.
+        `model_class` and `model_params` arguments will be disregarded.

     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
@@ -188,6 +203,18 @@ def __init__(
     self._model_class = model_class
     self._model_params = model_params
     self._inference_fn = inference_fn
+    self._use_torch_script_format = use_torch_script_format
+
+    self._validate_func_args()
+
+  def _validate_func_args(self):
+    if not self._use_torch_script_format and (self._model_class is None or
+                                              self._model_params is None):
+      raise RuntimeError(
+          "Please pass both `model_class` and `model_params` to the torch "
+          "model handler when using it with PyTorch. "
" + "If you opt to load the entire that was saved using TorchScript, " + "set `use_torch_script_format` to True.") def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" @@ -195,7 +222,9 @@ def load_model(self) -> torch.nn.Module: self._model_class, self._state_dict_path, self._device, - **self._model_params) + self._model_params, + self._use_torch_script_format + ) self._device = device return model @@ -323,11 +352,12 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor], def __init__( self, state_dict_path: str, - model_class: Callable[..., torch.nn.Module], - model_params: Dict[str, Any], + model_class: Optional[Callable[..., torch.nn.Module]] = None, + model_params: Optional[Dict[str, Any]] = None, device: str = 'CPU', *, - inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn): + inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn, + use_torch_script_format: bool = False): """Implementation of the ModelHandler interface for PyTorch. Example Usage:: @@ -352,6 +382,9 @@ def __init__( Otherwise, it will be CPU. inference_fn: the function to invoke on run_inference. default = default_keyed_tensor_inference_fn + use_torch_script_format: When `use_torch_script_format` is set to `True`, + the model will be loaded using `torch.jit.load()`. + `model_class` and `model_params` arguments will be disregarded. **Supported Versions:** RunInference APIs in Apache Beam have been tested on torch>=1.9.0,<1.14.0. @@ -366,6 +399,9 @@ def __init__( self._model_class = model_class self._model_params = model_params self._inference_fn = inference_fn + self._use_torch_script_format = use_torch_script_format + + self._validate_func_args() def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" @@ -373,7 +409,9 @@ def load_model(self) -> torch.nn.Module: self._model_class, self._state_dict_path, self._device, - **self._model_params) + self._model_params, + self._use_torch_script_format + ) self._device = device return model @@ -429,3 +467,12 @@ def get_metrics_namespace(self) -> str: def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): pass + + def _validate_func_args(self): + if not self._use_torch_script_format and (self._model_class is None or + self._model_params is None): + raise RuntimeError( + "Please pass both `model_class` and `model_params` to the torch " + "model handler when using it with PyTorch. " + "If you opt to load the entire that was saved using TorchScript, " + "set `use_torch_script_format` to True.") From 3bff28c1dc3dc5d3c3e45e7daa4fb6475164a9de Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Mon, 6 Feb 2023 10:47:02 -0500 Subject: [PATCH 02/13] Add tests --- .../ml/inference/pytorch_inference_test.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index c47b21d2f25a..c6ee809f7884 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -609,6 +609,57 @@ def test_gpu_auto_convert_to_cpu(self): "are not available. 
         log.output)

+  def test_load_torch_script_model(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_script_model = torch.jit.script(torch_model)
+
+    torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
+
+    torch.jit.save(torch_script_model, torch_script_path)
+
+    model_handler = PytorchModelHandlerTensor(
+        state_dict_path=torch_script_path, use_torch_script_format=True)
+
+    torch_script_model = model_handler.load_model()
+
+    self.assertTrue(isinstance(torch_script_model, torch.jit.ScriptModule))
+
+  def test_inference_torch_script_model(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_model.load_state_dict(
+        OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
+                     ('linear.bias', torch.Tensor([0.5]))]))
+
+    torch_script_model = torch.jit.script(torch_model)
+
+    torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
+
+    torch.jit.save(torch_script_model, torch_script_path)
+
+    model_handler = PytorchModelHandlerTensor(
+        state_dict_path=torch_script_path, use_torch_script_format=True)
+
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
+      predictions = pcoll | RunInference(model_handler)
+      assert_that(
+          predictions,
+          equal_to(
+              TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result))
+
+  def test_torch_model_class_none(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
+
+    torch.save(torch_model, torch_path)
+
+    with self.assertRaisesRegex(
+        RuntimeError,
+        "Please pass both `model_class` and `model_params` to the torch "
+        "model handler when using it with PyTorch. "
+        "If you opt to load the entire model that was saved using TorchScript"):
+      _ = PytorchModelHandlerTensor(state_dict_path=torch_path)
+

 if __name__ == '__main__':
   unittest.main()

From 42576248a0d7728f6b78dc1b3c14b66b778bcf82 Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Mon, 6 Feb 2023 15:35:04 -0500
Subject: [PATCH 03/13] Add example and benchmark

---
 .../inference/pytorch_image_classification.py | 10 +++++---
 ...pytorch_image_classification_benchmarks.py | 25 ++++++++++++++++---
 2 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
index 65f21ceaa318..78b1cd17c5bb 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -109,7 +109,8 @@ def run(
     model_params=None,
     save_main_session=True,
     device='CPU',
-    test_pipeline=None) -> PipelineResult:
+    test_pipeline=None,
+    use_torch_script_format=False) -> PipelineResult:
   """
   Args:
     argv: Command line arguments defined for this example.
@@ -120,12 +121,14 @@ def run(
     save_main_session: Used for internal testing.
     device: Device to be used on the Runner. Choices are (CPU, GPU).
     test_pipeline: Used for internal testing.
+    use_torch_script_format: Load the model which was saved using
+      Torchscript API.
   """
   known_args, pipeline_args = parse_known_args(argv)
   pipeline_options = PipelineOptions(pipeline_args)
   pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

-  if not model_class:
+  if not model_class and not use_torch_script_format:
     # default model class will be mobilenet with pretrained weights.
     model_class = models.mobilenet_v2
     model_params = {'num_classes': 1000}
@@ -141,7 +144,8 @@ def batch_elements_kwargs(self):
       state_dict_path=known_args.model_state_dict_path,
       model_class=model_class,
       model_params=model_params,
-      device=device))
+      device=device,
+      use_torch_script_format=use_torch_script_format))

   pipeline = test_pipeline
   if not test_pipeline:

diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py
index 514c9d672850..c2c05e3c2f3a 100644
--- a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py
+++ b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py
@@ -33,7 +33,19 @@ def __init__(self):
     self.metrics_namespace = 'BeamML_PyTorch'
     super().__init__(metrics_namespace=self.metrics_namespace)

-  def test(self):
+  def run_with_torch_script_model(self):
+    extra_opts = {}
+    extra_opts['input'] = self.pipeline.get_option('input_file')
+    device = self.pipeline.get_option('device')
+    self.result = pytorch_image_classification.run(
+        self.pipeline.get_full_options_as_args(**extra_opts),
+        test_pipeline=self.pipeline,
+        device=device,
+        use_torch_script_format=True)
+
+  def run_with_torch_model(self):
+    # model_params are same for all the models. But this may change if we add
+    # different models.
     pretrained_model_name = self.pipeline.get_option('pretrained_model_name')
     if not pretrained_model_name:
       raise RuntimeError(
@@ -47,9 +59,6 @@ def run_with_torch_model(self):
       model_class = models.resnet152
     else:
       raise NotImplementedError
-
-    # model_params are same for all the models. But this may change if we add
-    # different models.
     model_params = {'num_classes': 1000, 'pretrained': False}

     extra_opts = {}
@@ -62,6 +71,14 @@ def run_with_torch_model(self):
         test_pipeline=self.pipeline,
         device=device)

+  def test(self):
+    use_torch_script_format = self.pipeline.get_option(
+        'use_torch_script_format')
+    if use_torch_script_format:
+      self.run_with_torch_script_model()
+    else:
+      self.run_with_torch_model()
+

 if __name__ == '__main__':
   logging.basicConfig(level=logging.INFO)

From f4d4cb41c96f4c830aa308a77e30ceda0b1ef847 Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Wed, 8 Feb 2023 14:59:17 -0500
Subject: [PATCH 04/13] Add validate_constructor_args method

---
 sdks/python/apache_beam/ml/inference/base.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 842607f36ffd..27de9c68dccf 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -174,6 +174,12 @@ def update_model_path(self, model_path: Optional[str] = None):
     """Update the model paths produced by side inputs."""
     pass

+  def validate_constructor_args(self):
+    """
+    Validate arguments passed to the ModelHandler constructor.
+    """
+    raise NotImplementedError
+

 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],

From f6f579d4f17ff2f260ecf463ae118d207df38549 Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Wed, 8 Feb 2023 15:29:28 -0500
Subject: [PATCH 05/13] Addressing comments, fixing types

---
 .../ml/inference/pytorch_inference.py | 144 +++++++++++-------
 1 file changed, 89 insertions(+), 55 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 87bf59956916..201f4df4a274 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -58,26 +58,27 @@ def _load_model(
-    model_class: torch.nn.Module,
-    state_dict_path,
-    device,
-    model_params,
-    use_torch_script_format=False):
+    model_class: Optional[Callable[..., torch.nn.Module]],
+    state_dict_path: Optional[str],
+    device: torch.device,
+    model_params: Optional[Dict[str, Any]],
+    torch_script_model_path: Optional[str]):
   if device == torch.device('cuda') and not torch.cuda.is_available():
     logging.warning(
-        "Model handler specified a 'GPU' device, but GPUs are not available. " \
+        "Model handler specified a 'GPU' device, but GPUs are not available. "
        "Switching to CPU.")
     device = torch.device('cpu')

-  file = FileSystems.open(state_dict_path, 'rb')
   try:
     logging.info(
         "Loading state_dict_path %s onto a %s device", state_dict_path, device)
-    if not use_torch_script_format:
+    if not torch_script_model_path:
+      file = FileSystems.open(state_dict_path, 'rb')
       model = model_class(**model_params)
       state_dict = torch.load(file, map_location=device)
       model.load_state_dict(state_dict)
     else:
+      file = FileSystems.open(torch_script_model_path, 'rb')
       model = torch.jit.load(file, map_location=device)
   except RuntimeError as e:
     if device == torch.device('cuda'):
@@ -89,7 +90,7 @@ def _load_model(
           state_dict_path,
           torch.device('cpu'),
           model_params,
-          use_torch_script_format)
+          torch_script_model_path)
     else:
       raise e

@@ -158,19 +159,22 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
                                              torch.nn.Module]):
   def __init__(
       self,
-      state_dict_path: str,
+      state_dict_path: Optional[str] = None,
       model_class: Optional[Callable[..., torch.nn.Module]] = None,
       model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
       inference_fn: TensorInferenceFn = default_tensor_inference_fn,
-      use_torch_script_format=False,
+      torch_script_model_path: Optional[str] = None,
   ):
     """Implementation of the ModelHandler interface for PyTorch.

+    Example Usage for torch model::
+      pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri",
+                                                     model_class="my_class"))
+    Example Usage for torchscript model::
+      pcoll | RunInference(PytorchModelHandlerTensor(
+          torch_script_model_path="my_uri"))

     See https://pytorch.org/tutorials/beginner/saving_loading_models.html
     for details
@@ -186,9 +190,10 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
-      use_torch_script_format: When `use_torch_script_format` is set to `True`,
-        the model will be loaded using `torch.jit.load()`.
-        `model_class` and `model_params` arguments will be disregarded.
+      torch_script_model_path: Path to the torch script model.
+        The model will be loaded using `torch.jit.load()`.
+        `state_dict_path`, `model_class` and `model_params`
+        arguments will be disregarded.

     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
@@ -201,20 +206,25 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._use_torch_script_format = use_torch_script_format
+    self._torch_script_model_path = torch_script_model_path

-    self._validate_func_args()
+    self.validate_constructor_args()

-  def _validate_func_args(self):
-    if not self._use_torch_script_format and (self._model_class is None or
-                                              self._model_params is None):
-      raise RuntimeError(
-          "Please pass both `model_class` and `model_params` to the torch "
-          "model handler when using it with PyTorch. "
-          "If you opt to load the entire model that was saved using TorchScript, "
-          "set `use_torch_script_format` to True.")
+  def validate_constructor_args(self):
+    if self._state_dict_path and not self._model_class:
+      raise RuntimeError(
+          "A state_dict_path has been supplied to the model "
+          "handler, but the required model_class is missing. "
+          "Please provide the model_class in order to "
+          "successfully load the state_dict_path.")
+
+    if self._torch_script_model_path:
+      if self._state_dict_path and self._model_class:
+        raise RuntimeError(
+            "Please specify either torch_script_model_path or "
+            "(state_dict_path, model_class) to successfully load the model.")

   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
         self._model_class,
         self._state_dict_path,
         self._device,
         self._model_params,
-        self._use_torch_script_format
+        self._torch_script_model_path
     )
     self._device = device
     return model

   def update_model_path(self, model_path: Optional[str] = None):
-    self._state_dict_path = model_path if model_path else self._state_dict_path
+    if self._torch_script_model_path:
+      self._torch_script_model_path = (
+          model_path if model_path else self._torch_script_model_path)
+    else:
+      self._state_dict_path = (
+          model_path if model_path else self._state_dict_path)

   def run_inference(
       self,
       An Iterable of type PredictionResult.
     """
     inference_args = {} if not inference_args else inference_args
-
+    model_id = (
+        self._state_dict_path
+        if not self._torch_script_model_path else self._torch_script_model_path)
     return self._inference_fn(
-        batch, model, self._device, inference_args, self._state_dict_path)
+        batch, model, self._device, inference_args, model_id)

   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
 class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
                                                   torch.nn.Module]):
   def __init__(
       self,
-      state_dict_path: str,
+      state_dict_path: Optional[str] = None,
       model_class: Optional[Callable[..., torch.nn.Module]] = None,
       model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
-      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn):
+      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
+      torch_script_model_path: Optional[str] = None):
     """Implementation of the ModelHandler interface for PyTorch.
- Example Usage:: + Example Usage for torch model:: + pcoll | RunInference(PytorchModelHandlerKeyedTensor( + state_dict_path="my_uri", + model_class="my_class")) - pcoll | RunInference( - PytorchModelHandlerKeyedTensor(state_dict_path="my_uri")) + Example Usage for torchscript model:: + pcoll | RunInference(PytorchModelHandlerKeyedTensor( + torch_script_model_path="my_uri")) **NOTE:** This API and its implementation are under development and do not provide backward compatibility guarantees. @@ -382,9 +403,10 @@ def __init__( Otherwise, it will be CPU. inference_fn: the function to invoke on run_inference. default = default_keyed_tensor_inference_fn - use_torch_script_format: When `use_torch_script_format` is set to `True`, - the model will be loaded using `torch.jit.load()`. - `model_class` and `model_params` arguments will be disregarded. + torch_script_model_path: Path to the torch script model. + the model will be loaded using `torch.jit.load()`. + `state_dict_path`, `model_class` and `model_params` + arguments will be disregarded.. **Supported Versions:** RunInference APIs in Apache Beam have been tested on torch>=1.9.0,<1.14.0. @@ -397,11 +419,25 @@ def __init__( logging.info("Device is set to CPU") self._device = torch.device('cpu') self._model_class = model_class - self._model_params = model_params + self._model_params = model_params if model_params else {} self._inference_fn = inference_fn - self._use_torch_script_format = use_torch_script_format + self._torch_script_model_path = torch_script_model_path + + self.validate_constructor_args() - self._validate_func_args() + def validate_constructor_args(self): + if self._state_dict_path and not self._model_class: + raise RuntimeError( + "A state_dict_path has been supplied to the model " + "handler, but the required model_class is missing. " + "Please provide the model_class in order to " + "successfully load the state_dict_path.") + + if self._torch_script_model_path: + if self._state_dict_path and self._model_class: + raise RuntimeError( + "Please specify either torch_script_model_path or " + "(state_dict_path, model_class) to successfully load the model.") def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" @@ -410,13 +446,18 @@ def load_model(self) -> torch.nn.Module: self._state_dict_path, self._device, self._model_params, - self._use_torch_script_format + self._torch_script_model_path ) self._device = device return model def update_model_path(self, model_path: Optional[str] = None): - self._state_dict_path = model_path if model_path else self._state_dict_path + if self._torch_script_model_path: + self._torch_script_model_path = ( + model_path if model_path else self._torch_script_model_path) + else: + self._state_dict_path = ( + model_path if model_path else self._state_dict_path) def run_inference( self, @@ -445,9 +486,11 @@ def run_inference( An Iterable of type PredictionResult. 
""" inference_args = {} if not inference_args else inference_args - + model_id = ( + self._state_dict_path + if not self._torch_script_model_path else self._torch_script_model_path) return self._inference_fn( - batch, model, self._device, inference_args, self._state_dict_path) + batch, model, self._device, inference_args, model_id) def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int: """ @@ -467,12 +510,3 @@ def get_metrics_namespace(self) -> str: def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): pass - - def _validate_func_args(self): - if not self._use_torch_script_format and (self._model_class is None or - self._model_params is None): - raise RuntimeError( - "Please pass both `model_class` and `model_params` to the torch " - "model handler when using it with PyTorch. " - "If you opt to load the entire that was saved using TorchScript, " - "set `use_torch_script_format` to True.") From d24ea60810d7f297b71a193c1140f52075e53561 Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Wed, 8 Feb 2023 15:32:49 -0500 Subject: [PATCH 06/13] Fix/add tests --- .../ml/inference/pytorch_inference_test.py | 68 +++++++++++++++++-- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index c6ee809f7884..ed58de28feaf 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -123,6 +123,7 @@ def __init__(self, device, *, inference_fn=default_tensor_inference_fn): self._device = device self._inference_fn = inference_fn self._state_dict_path = None + self._torch_script_model_path = None class TestPytorchModelHandlerKeyedTensorForInferenceOnly( @@ -131,6 +132,7 @@ def __init__(self, device, *, inference_fn=default_keyed_tensor_inference_fn): self._device = device self._inference_fn = inference_fn self._state_dict_path = None + self._torch_script_model_path = None def _compare_prediction_result(x, y): @@ -618,7 +620,7 @@ def test_load_torch_script_model(self): torch.jit.save(torch_script_model, torch_script_path) model_handler = PytorchModelHandlerTensor( - state_dict_path=torch_script_path, use_torch_script_format=True) + torch_script_model_path=torch_script_path) torch_script_model = model_handler.load_model() @@ -637,7 +639,7 @@ def test_inference_torch_script_model(self): torch.jit.save(torch_script_model, torch_script_path) model_handler = PytorchModelHandlerTensor( - state_dict_path=torch_script_path, use_torch_script_format=True) + torch_script_model_path=torch_script_path) with TestPipeline() as pipeline: pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES) @@ -655,11 +657,67 @@ def test_torch_model_class_none(self): with self.assertRaisesRegex( RuntimeError, - "Please pass both `model_class` and `model_params` to the torch " - "model handler when using it with PyTorch. " - "If you opt to load the entire that was saved using TorchScript"): + "A state_dict_path has been supplied to the model " + "handler, but the required model_class is missing. 
" + "Please provide the model_class in order to"): _ = PytorchModelHandlerTensor(state_dict_path=torch_path) + def test_specify_torch_script_path_and_state_dict_path(self): + torch_model = PytorchLinearRegression(2, 1) + torch_path = os.path.join(self.tmpdir, 'torch_model.pt') + + torch.save(torch_model, torch_path) + torch_script_model = torch.jit.script(torch_model) + + torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt') + + torch.jit.save(torch_script_model, torch_script_path) + with self.assertRaisesRegex( + RuntimeError, "Please specify either torch_script_model_path or "): + _ = PytorchModelHandlerTensor( + state_dict_path=torch_path, + model_class=PytorchLinearRegression, + torch_script_model_path=torch_script_path) + + def test_prediction_result_model_id_with_torch_script_model(self): + torch_model = PytorchLinearRegression(2, 1) + torch_script_model = torch.jit.script(torch_model) + torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt') + torch.jit.save(torch_script_model, torch_script_path) + + model_handler = PytorchModelHandlerTensor( + torch_script_model_path=torch_script_path) + + def check_torch_script_model_id(element): + assert ('torch_script_model.pt' in element.model_id) is True + + with TestPipeline() as pipeline: + pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES) + predictions = pcoll | RunInference(model_handler) + _ = predictions | beam.Map(check_torch_script_model_id) + + def test_prediction_result_model_id_with_torch_model(self): + # weights associated with PytorchLinearRegression class + state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])), + ('linear.bias', torch.Tensor([0.5]))]) + torch_path = os.path.join(self.tmpdir, 'torch_model.pt') + torch.save(state_dict, torch_path) + + model_handler = PytorchModelHandlerTensor( + state_dict_path=torch_path, + model_class=PytorchLinearRegression, + model_params={ + 'input_dim': 2, 'output_dim': 1 + }) + + def check_torch_script_model_id(element): + assert ('torch_model.pt' in element.model_id) is True + + with TestPipeline() as pipeline: + pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES) + predictions = pcoll | RunInference(model_handler) + _ = predictions | beam.Map(check_torch_script_model_id) + if __name__ == '__main__': unittest.main() From fc7dae046cd26f9e8ebc456660240711a07ce5b8 Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Wed, 8 Feb 2023 15:38:02 -0500 Subject: [PATCH 07/13] Add change log --- CHANGES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index b26b5f886fdf..11d15c7dab2d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -65,6 +65,8 @@ * Add UDF metrics support for Samza portable mode. * Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)). This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`. +* Add support for loading TorchScript models with `PytorchModelHandler`. The TorchScript model path can be + passed to PytorchModelHandler using `torch_script_model_path=`. 
+  ([#25321](https://github.com/apache/beam/pull/25321))

 ## Breaking Changes

From 01a67954c2a6ff90c9dc3ce711e23c50498ebca2 Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Wed, 8 Feb 2023 15:40:15 -0500
Subject: [PATCH 08/13] revert changes

---
 .../inference/pytorch_image_classification.py | 10 +++-----
 ...pytorch_image_classification_benchmarks.py | 25 +++----------------
 2 files changed, 7 insertions(+), 28 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
index 78b1cd17c5bb..65f21ceaa318 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -109,8 +109,7 @@ def run(
     model_params=None,
     save_main_session=True,
     device='CPU',
-    test_pipeline=None,
-    use_torch_script_format=False) -> PipelineResult:
+    test_pipeline=None) -> PipelineResult:
   """
   Args:
     argv: Command line arguments defined for this example.
@@ -121,14 +120,12 @@ def run(
     save_main_session: Used for internal testing.
     device: Device to be used on the Runner. Choices are (CPU, GPU).
     test_pipeline: Used for internal testing.
-    use_torch_script_format: Load the model which was saved using
-      Torchscript API.
   """
   known_args, pipeline_args = parse_known_args(argv)
   pipeline_options = PipelineOptions(pipeline_args)
   pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

-  if not model_class and not use_torch_script_format:
+  if not model_class:
     # default model class will be mobilenet with pretrained weights.
     model_class = models.mobilenet_v2
     model_params = {'num_classes': 1000}
@@ -144,8 +141,7 @@ def batch_elements_kwargs(self):
       state_dict_path=known_args.model_state_dict_path,
       model_class=model_class,
       model_params=model_params,
-      device=device,
-      use_torch_script_format=use_torch_script_format))
+      device=device))

   pipeline = test_pipeline
   if not test_pipeline:

diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py
index c2c05e3c2f3a..514c9d672850 100644
--- a/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py
+++ b/sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py
@@ -33,19 +33,7 @@ def __init__(self):
     self.metrics_namespace = 'BeamML_PyTorch'
     super().__init__(metrics_namespace=self.metrics_namespace)

-  def run_with_torch_script_model(self):
-    extra_opts = {}
-    extra_opts['input'] = self.pipeline.get_option('input_file')
-    device = self.pipeline.get_option('device')
-    self.result = pytorch_image_classification.run(
-        self.pipeline.get_full_options_as_args(**extra_opts),
-        test_pipeline=self.pipeline,
-        device=device,
-        use_torch_script_format=True)
-
-  def run_with_torch_model(self):
-    # model_params are same for all the models. But this may change if we add
-    # different models.
+  def test(self):
     pretrained_model_name = self.pipeline.get_option('pretrained_model_name')
     if not pretrained_model_name:
       raise RuntimeError(
@@ -59,6 +47,9 @@ def test(self):
       model_class = models.resnet152
     else:
       raise NotImplementedError
+
+    # model_params are same for all the models. But this may change if we add
+    # different models.
     model_params = {'num_classes': 1000, 'pretrained': False}

     extra_opts = {}
@@ -71,14 +62,6 @@ def test(self):
         test_pipeline=self.pipeline,
         device=device)

-  def test(self):
-    use_torch_script_format = self.pipeline.get_option(
-        'use_torch_script_format')
-    if use_torch_script_format:
-      self.run_with_torch_script_model()
-    else:
-      self.run_with_torch_model()
-

 if __name__ == '__main__':
   logging.basicConfig(level=logging.INFO)

From dd348ec7f32b47e107e1aa9a1afc729e3f26eff5 Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Thu, 9 Feb 2023 15:00:10 -0500
Subject: [PATCH 09/13] Add few more checks and refactor

---
 .../ml/inference/pytorch_inference.py | 55 +++++++++++--------
 .../ml/inference/pytorch_inference_test.py | 22 ++++++++
 2 files changed, 53 insertions(+), 24 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 93c41dccfa06..4180a5aad13f 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -57,6 +57,29 @@
                               Iterable[PredictionResult]]


+def _validate_constructor_args(
+    state_dict_path, model_class, torch_script_model_path):
+  message = (
+      "A {param1} has been supplied to the model "
+      "handler, but the required {param2} is missing. "
+      "Please provide the {param2} in order to "
+      "successfully load the {param1}.")
+  # state_dict_path and model_class are coupled with each other
+  # raise RuntimeError if user forgets to pass any one of them.
+  if state_dict_path and not model_class:
+    raise RuntimeError(
+        message.format(param1="state_dict_path", param2="model_class"))
+
+  if not state_dict_path and model_class:
+    raise RuntimeError(
+        message.format(param1="model_class", param2="state_dict_path"))
+
+  if torch_script_model_path and state_dict_path:
+    raise RuntimeError(
+        "Please specify either torch_script_model_path or "
+        "(state_dict_path, model_class) to successfully load the model.")
+
+
 def _load_model(
     model_class: Optional[Callable[..., torch.nn.Module]],
     state_dict_path: Optional[str],
     device: torch.device,
     model_params: Optional[Dict[str, Any]],
     torch_script_model_path: Optional[str]):
@@ -219,18 +242,10 @@ def __init__(
     self.validate_constructor_args()

   def validate_constructor_args(self):
-    if self._state_dict_path and not self._model_class:
-      raise RuntimeError(
-          "A state_dict_path has been supplied to the model "
-          "handler, but the required model_class is missing. "
-          "Please provide the model_class in order to "
-          "successfully load the state_dict_path.")
-
-    if self._torch_script_model_path:
-      if self._state_dict_path and self._model_class:
-        raise RuntimeError(
-            "Please specify either torch_script_model_path or "
-            "(state_dict_path, model_class) to successfully load the model.")
+    _validate_constructor_args(
+        state_dict_path=self._state_dict_path,
+        model_class=self._model_class,
+        torch_script_model_path=self._torch_script_model_path)

   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -442,18 +457,10 @@ def __init__(
     self.validate_constructor_args()

   def validate_constructor_args(self):
-    if self._state_dict_path and not self._model_class:
-      raise RuntimeError(
-          "A state_dict_path has been supplied to the model "
-          "handler, but the required model_class is missing. "
" - "Please provide the model_class in order to " - "successfully load the state_dict_path.") - - if self._torch_script_model_path: - if self._state_dict_path and self._model_class: - raise RuntimeError( - "Please specify either torch_script_model_path or " - "(state_dict_path, model_class) to successfully load the model.") + _validate_constructor_args( + state_dict_path=self._state_dict_path, + model_class=self._model_class, + torch_script_model_path=self._torch_script_model_path) def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 91b9ef8e0301..947e06e20473 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -754,6 +754,28 @@ def test_torch_model_class_none(self): "Please provide the model_class in order to"): _ = PytorchModelHandlerTensor(state_dict_path=torch_path) + with self.assertRaisesRegex( + RuntimeError, + "A state_dict_path has been supplied to the model " + "handler, but the required model_class is missing. " + "Please provide the model_class in order to"): + _ = (PytorchModelHandlerKeyedTensor(state_dict_path=torch_path)) + + def test_torch_model_state_dict_none(self): + with self.assertRaisesRegex( + RuntimeError, + "A model_class has been supplied to the model " + "handler, but the required state_dict_path is missing. " + "Please provide the state_dict_path in order to"): + _ = PytorchModelHandlerTensor(model_class=PytorchLinearRegression) + + with self.assertRaisesRegex( + RuntimeError, + "A model_class has been supplied to the model " + "handler, but the required state_dict_path is missing. " + "Please provide the state_dict_path in order to"): + _ = PytorchModelHandlerKeyedTensor(model_class=PytorchLinearRegression) + def test_specify_torch_script_path_and_state_dict_path(self): torch_model = PytorchLinearRegression(2, 1) torch_path = os.path.join(self.tmpdir, 'torch_model.pt') From b18e10f5b6add921c2527dcec0b2c64ea385b082 Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Thu, 9 Feb 2023 16:14:25 -0500 Subject: [PATCH 10/13] Fixup lint --- sdks/python/apache_beam/ml/inference/pytorch_inference.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 4180a5aad13f..c555282af62b 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -81,11 +81,8 @@ def _validate_constructor_args( def _load_model( - model_class: Optional[Callable[..., torch.nn.Module]], - state_dict_path: Optional[str], - device: torch.device, - model_params: Optional[Dict[str, Any]], - torch_script_model_path: Optional[str]): + model_class, state_dict_path, device, model_params, + torch_script_model_path): if device == torch.device('cuda') and not torch.cuda.is_available(): logging.warning( "Model handler specified a 'GPU' device, but GPUs are not available. " From 64fbc4a4486b0ad3942a605d77a3e5d0fa5f8887 Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Thu, 9 Feb 2023 16:14:25 -0500 Subject: [PATCH 11/13] Revert "Fixup lint" This reverts commit b18e10f5b6add921c2527dcec0b2c64ea385b082. 
---
 sdks/python/apache_beam/ml/inference/pytorch_inference.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index c555282af62b..4180a5aad13f 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -81,8 +81,11 @@ def _validate_constructor_args(


 def _load_model(
-    model_class, state_dict_path, device, model_params,
-    torch_script_model_path):
+    model_class: Optional[Callable[..., torch.nn.Module]],
+    state_dict_path: Optional[str],
+    device: torch.device,
+    model_params: Optional[Dict[str, Any]],
+    torch_script_model_path: Optional[str]):
   if device == torch.device('cuda') and not torch.cuda.is_available():
     logging.warning(
         "Model handler specified a 'GPU' device, but GPUs are not available. "

From 5637f3966e607d33da7a01048d8f5a21abcf555c Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Fri, 10 Feb 2023 12:09:51 -0500
Subject: [PATCH 12/13] Add ignore for mypy

---
 sdks/python/apache_beam/ml/inference/pytorch_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 4180a5aad13f..e81436bea065 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -97,7 +97,7 @@ def _load_model(
         "Loading state_dict_path %s onto a %s device", state_dict_path, device)
     if not torch_script_model_path:
       file = FileSystems.open(state_dict_path, 'rb')
-      model = model_class(**model_params)
+      model = model_class(**model_params)  # type: ignore[misc]
       state_dict = torch.load(file, map_location=device)
       model.load_state_dict(state_dict)
     else:

From e5c9062aab57e010e2cfc5aa08fa8522edecf324 Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Fri, 10 Feb 2023 15:13:32 -0500
Subject: [PATCH 13/13] Make validate_constructor_args local to pytorch
 handler

---
 sdks/python/apache_beam/ml/inference/base.py | 6 ------
 sdks/python/apache_beam/ml/inference/pytorch_inference.py | 7 -------
 2 files changed, 13 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 60b6b5ba89e5..50056107702e 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -174,12 +174,6 @@ def update_model_path(self, model_path: Optional[str] = None):
     """Update the model paths produced by side inputs."""
     pass

-  def validate_constructor_args(self):
-    """
-    Validate arguments passed to the ModelHandler constructor.
-    """
-    raise NotImplementedError
-

 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 425820c58abb..71a4ccc63a27 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -243,9 +243,6 @@ def __init__(
       self._batching_kwargs['max_batch_size'] = max_batch_size
     self._torch_script_model_path = torch_script_model_path

-    self.validate_constructor_args()
-
-  def validate_constructor_args(self):
     _validate_constructor_args(
         state_dict_path=self._state_dict_path,
         model_class=self._model_class,
         torch_script_model_path=self._torch_script_model_path)
@@ -462,10 +459,6 @@ def __init__(
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
     self._torch_script_model_path = torch_script_model_path
-
-    self.validate_constructor_args()
-
-  def validate_constructor_args(self):
     _validate_constructor_args(
         state_dict_path=self._state_dict_path,
         model_class=self._model_class,
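
The series above replaces the draft `use_torch_script_format` flag with a
dedicated `torch_script_model_path` argument. A minimal end-to-end sketch of
that final API follows, assuming a Beam build that contains this series; the
`LinearRegression` module and the local `/tmp/linear.pt` path are illustrative
stand-ins, not part of the Beam codebase::

    # Sketch only: assumes torch and a Beam build with torch_script_model_path.
    import torch
    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.pytorch_inference import (
        PytorchModelHandlerTensor)


    class LinearRegression(torch.nn.Module):
      def __init__(self, input_dim=2, output_dim=1):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

      def forward(self, x):
        return self.linear(x)


    # Script and save the model ahead of pipeline submission. A TorchScript
    # archive serializes the architecture together with the weights.
    torch.jit.save(torch.jit.script(LinearRegression()), '/tmp/linear.pt')

    # No model_class/model_params here: the archive already carries both,
    # which is the pairing that _validate_constructor_args enforces.
    model_handler = PytorchModelHandlerTensor(
        torch_script_model_path='/tmp/linear.pt')

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([torch.Tensor([1.0, 2.0]), torch.Tensor([3.0, 4.0])])
          | RunInference(model_handler)
          | beam.Map(print))

Because the TorchScript archive is self-describing, `update_model_path` can
hot-swap model versions at runtime by path alone, with no class definition
available on the workers.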