From 14b84b60fa3fcedc71ad4cf61ea9b851e3ee4556 Mon Sep 17 00:00:00 2001
From: Lorenzo Stella
Date: Wed, 14 Dec 2022 10:35:16 +0100
Subject: [PATCH] add smoke tests for torch models

---
 src/gluonts/torch/model/mqf2/module.py       | 13 +++-
 .../torch/model/simple_feedforward/module.py | 28 +++++--
 test/torch/model/test_modules.py             | 74 +++++++++++++++++++
 3 files changed, 105 insertions(+), 10 deletions(-)
 create mode 100644 test/torch/model/test_modules.py

diff --git a/src/gluonts/torch/model/mqf2/module.py b/src/gluonts/torch/model/mqf2/module.py
index 69515e59e5..8adee9b739 100644
--- a/src/gluonts/torch/model/mqf2/module.py
+++ b/src/gluonts/torch/model/mqf2/module.py
@@ -17,7 +17,10 @@
 from gluonts.core.component import validated
 
 from gluonts.torch.model.deepar.module import DeepARModel
-from gluonts.torch.distributions import DistributionOutput
+from gluonts.torch.distributions import (
+    DistributionOutput,
+    MQF2DistributionOutput,
+)
 
 from cpflows.flows import ActNorm
 from cpflows.icnn import PICNN
@@ -35,7 +38,7 @@ def __init__(
         num_feat_static_real: int,
         num_feat_static_cat: int,
         cardinality: List[int],
-        distr_output: DistributionOutput,
+        distr_output: Optional[DistributionOutput] = None,
         embedding_dimension: Optional[List[int]] = None,
         num_layers: int = 2,
         hidden_size: int = 40,
@@ -74,7 +77,11 @@ def __init__(
             num_layers=num_layers,
             hidden_size=hidden_size,
             dropout_rate=dropout_rate,
-            distr_output=distr_output,
+            distr_output=(
+                distr_output
+                if distr_output is not None
+                else MQF2DistributionOutput(prediction_length)
+            ),
             lags_seq=lags_seq,
             scaling=scaling,
             num_parallel_samples=num_parallel_samples,
diff --git a/src/gluonts/torch/model/simple_feedforward/module.py b/src/gluonts/torch/model/simple_feedforward/module.py
index 2c3c8e1072..dfd9e51fc8 100644
--- a/src/gluonts/torch/model/simple_feedforward/module.py
+++ b/src/gluonts/torch/model/simple_feedforward/module.py
@@ -11,7 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from typing import List, Tuple
+from typing import Dict, List, Tuple, Optional
 
 import torch
 from torch import nn
@@ -55,7 +55,7 @@ def __init__(
         self,
         prediction_length: int,
         context_length: int,
-        hidden_dimensions: List[int],
+        hidden_dimensions: Optional[List[int]] = None,
         distr_output=StudentTOutput(),
         batch_norm: bool = False,
     ) -> None:
@@ -63,15 +63,17 @@
 
         assert prediction_length > 0
         assert context_length > 0
-        assert len(hidden_dimensions) > 0
+        assert hidden_dimensions is None or len(hidden_dimensions) > 0
 
         self.prediction_length = prediction_length
         self.context_length = context_length
-        self.hidden_dimensions = hidden_dimensions
+        self.hidden_dimensions = (
+            hidden_dimensions if hidden_dimensions is not None else [20, 20]
+        )
         self.distr_output = distr_output
         self.batch_norm = batch_norm
 
-        dimensions = [context_length] + hidden_dimensions[:-1]
+        dimensions = [context_length] + self.hidden_dimensions[:-1]
 
         modules = []
         for in_size, out_size in zip(dimensions[:-1], dimensions[1:]):
@@ -80,12 +82,24 @@
                 modules.append(nn.BatchNorm1d(out_size))
         modules.append(
             make_linear_layer(
-                dimensions[-1], prediction_length * hidden_dimensions[-1]
+                dimensions[-1], prediction_length * self.hidden_dimensions[-1]
             )
         )
 
         self.nn = nn.Sequential(*modules)
-        self.args_proj = self.distr_output.get_args_proj(hidden_dimensions[-1])
+        self.args_proj = self.distr_output.get_args_proj(
+            self.hidden_dimensions[-1]
+        )
+
+    def input_shapes(self, batch_size=1) -> Dict[str, Tuple[int, ...]]:
+        return {
+            "context": (batch_size, self.context_length),
+        }
+
+    def input_types(self) -> Dict[str, torch.dtype]:
+        return {
+            "context": torch.float,
+        }
 
     def forward(
         self,
diff --git a/test/torch/model/test_modules.py b/test/torch/model/test_modules.py
new file mode 100644
index 0000000000..9661ce7bf8
--- /dev/null
+++ b/test/torch/model/test_modules.py
@@ -0,0 +1,74 @@
+import pytest
+import torch
+
+from gluonts.torch.model.deepar import DeepARModel
+from gluonts.torch.model.mqf2 import MQF2MultiHorizonModel
+from gluonts.torch.model.simple_feedforward import SimpleFeedForwardModel
+
+
+def construct_batch(module, batch_size=1):
+    return tuple(
+        [
+            torch.zeros(shape, dtype=module.input_types()[name])
+            for (name, shape) in module.input_shapes(
+                batch_size=batch_size
+            ).items()
+        ]
+    )
+
+
+def assert_shapes_and_dtypes(tensors, shapes, dtypes):
+    if isinstance(tensors, torch.Tensor):
+        assert tensors.shape == shapes
+        assert tensors.dtype == dtypes
+    else:
+        for tensor, shape, dtype in zip(tensors, shapes, dtypes):
+            assert_shapes_and_dtypes(tensor, shape, dtype)
+
+
+@pytest.mark.parametrize(
+    "module, batch_size, expected_shapes, expected_dtypes",
+    [
+        (
+            DeepARModel(
+                freq="1H",
+                context_length=24,
+                prediction_length=12,
+                num_feat_dynamic_real=1,
+                num_feat_static_real=1,
+                num_feat_static_cat=1,
+                cardinality=[1],
+            ),
+            4,
+            (4, 100, 12),
+            torch.float,
+        ),
+        (
+            MQF2MultiHorizonModel(
+                freq="1H",
+                context_length=24,
+                prediction_length=12,
+                num_feat_dynamic_real=1,
+                num_feat_static_real=1,
+                num_feat_static_cat=1,
+                cardinality=[1],
+            ),
+            4,
+            (4, 100, 12),
+            torch.float,
+        ),
+        (
+            SimpleFeedForwardModel(
+                context_length=24,
+                prediction_length=12,
+            ),
+            4,
+            [[(4, 12), (4, 12), (4, 12)], (4, 1), (4, 1)],
+            [[torch.float, torch.float, torch.float], torch.float, torch.float],
+        ),
+    ],
+)
+def test_module_smoke(module, batch_size, expected_shapes, expected_dtypes):
+    batch = construct_batch(module, batch_size=batch_size)
+    outputs = module(*batch)
+    assert_shapes_and_dtypes(outputs, expected_shapes, expected_dtypes)
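
A minimal usage sketch, assuming the patch above is applied: the new input_shapes / input_types hooks on SimpleFeedForwardModel make it possible to build a dummy batch and run a forward pass by hand, which is the same mechanism construct_batch relies on in the smoke test. Hyperparameters mirror the SimpleFeedForwardModel entry in test_module_smoke; the output unpacking names below are illustrative, grounded only in the three top-level outputs the test expects.

    import torch

    from gluonts.torch.model.simple_feedforward import SimpleFeedForwardModel

    # Same hyperparameters as the smoke test; hidden_dimensions is omitted,
    # so the new default of [20, 20] applies.
    module = SimpleFeedForwardModel(context_length=24, prediction_length=12)

    # Build a zero-filled batch from the metadata exposed by the new methods.
    batch = {
        name: torch.zeros(shape, dtype=module.input_types()[name])
        for name, shape in module.input_shapes(batch_size=4).items()
    }

    # The forward pass returns three outputs, matching the nested shapes the
    # smoke test expects: three (4, 12) distribution-argument tensors plus two
    # (4, 1) tensors (unpacked here as loc and scale).
    distr_args, loc, scale = module(**batch)

Any module that exposes these two methods can be exercised by test_module_smoke through construct_batch without test-specific wiring, which is why the patch adds them to SimpleFeedForwardModel.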