Commit

add smoke tests for torch models
Lorenzo Stella committed Dec 14, 2022
1 parent 3bf1617 commit 14b84b6
Showing 3 changed files with 105 additions and 10 deletions.
src/gluonts/torch/model/mqf2/module.py (13 changes: 10 additions & 3 deletions)
@@ -17,7 +17,10 @@
 
 from gluonts.core.component import validated
 from gluonts.torch.model.deepar.module import DeepARModel
-from gluonts.torch.distributions import DistributionOutput
+from gluonts.torch.distributions import (
+    DistributionOutput,
+    MQF2DistributionOutput,
+)
 
 from cpflows.flows import ActNorm
 from cpflows.icnn import PICNN
@@ -35,7 +38,7 @@ def __init__(
         num_feat_static_real: int,
         num_feat_static_cat: int,
         cardinality: List[int],
-        distr_output: DistributionOutput,
+        distr_output: Optional[DistributionOutput] = None,
         embedding_dimension: Optional[List[int]] = None,
         num_layers: int = 2,
         hidden_size: int = 40,
@@ -74,7 +77,11 @@ def __init__(
             num_layers=num_layers,
             hidden_size=hidden_size,
             dropout_rate=dropout_rate,
-            distr_output=distr_output,
+            distr_output=(
+                distr_output
+                if distr_output is not None
+                else MQF2DistributionOutput(prediction_length)
+            ),
             lags_seq=lags_seq,
             scaling=scaling,
             num_parallel_samples=num_parallel_samples,
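With this change, the MQF2 network no longer needs an explicit distr_output: when the argument is left as None, the module falls back to MQF2DistributionOutput(prediction_length). A minimal sketch of constructing it with that default, reusing the constructor arguments from the smoke test added below:

    from gluonts.torch.model.mqf2 import MQF2MultiHorizonModel

    # distr_output is deliberately omitted; the module now substitutes
    # MQF2DistributionOutput(prediction_length) on its own.
    model = MQF2MultiHorizonModel(
        freq="1H",
        context_length=24,
        prediction_length=12,
        num_feat_dynamic_real=1,
        num_feat_static_real=1,
        num_feat_static_cat=1,
        cardinality=[1],
    )
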
src/gluonts/torch/model/simple_feedforward/module.py (28 changes: 21 additions & 7 deletions)
@@ -11,7 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from typing import List, Tuple
+from typing import Dict, List, Tuple, Optional
 
 import torch
 from torch import nn
@@ -55,23 +55,25 @@ def __init__(
         self,
         prediction_length: int,
         context_length: int,
-        hidden_dimensions: List[int],
+        hidden_dimensions: Optional[List[int]] = None,
         distr_output=StudentTOutput(),
         batch_norm: bool = False,
     ) -> None:
         super().__init__()
 
         assert prediction_length > 0
         assert context_length > 0
-        assert len(hidden_dimensions) > 0
+        assert hidden_dimensions is None or len(hidden_dimensions) > 0
 
         self.prediction_length = prediction_length
         self.context_length = context_length
-        self.hidden_dimensions = hidden_dimensions
+        self.hidden_dimensions = (
+            hidden_dimensions if hidden_dimensions is not None else [20, 20]
+        )
         self.distr_output = distr_output
         self.batch_norm = batch_norm
 
-        dimensions = [context_length] + hidden_dimensions[:-1]
+        dimensions = [context_length] + self.hidden_dimensions[:-1]
 
         modules = []
         for in_size, out_size in zip(dimensions[:-1], dimensions[1:]):
@@ -80,12 +82,24 @@ def __init__(
                 modules.append(nn.BatchNorm1d(out_size))
         modules.append(
             make_linear_layer(
-                dimensions[-1], prediction_length * hidden_dimensions[-1]
+                dimensions[-1], prediction_length * self.hidden_dimensions[-1]
             )
         )
 
         self.nn = nn.Sequential(*modules)
-        self.args_proj = self.distr_output.get_args_proj(hidden_dimensions[-1])
+        self.args_proj = self.distr_output.get_args_proj(
+            self.hidden_dimensions[-1]
+        )
+
+    def input_shapes(self, batch_size=1) -> Dict[str, Tuple[int, ...]]:
+        return {
+            "context": (batch_size, self.context_length),
+        }
+
+    def input_types(self) -> Dict[str, torch.dtype]:
+        return {
+            "context": torch.float,
+        }
 
     def forward(
         self,
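In the simple feed-forward module, hidden_dimensions is now optional and defaults to [20, 20], and the new input_shapes / input_types helpers describe the batch the module expects. A small usage sketch, with constructor arguments taken from the smoke test below:

    import torch

    from gluonts.torch.model.simple_feedforward import SimpleFeedForwardModel

    # hidden_dimensions is omitted, so the module falls back to [20, 20].
    model = SimpleFeedForwardModel(context_length=24, prediction_length=12)

    assert model.input_shapes(batch_size=4) == {"context": (4, 24)}
    assert model.input_types() == {"context": torch.float}
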
test/torch/model/test_modules.py (74 changes: 74 additions & 0 deletions)
@@ -0,0 +1,74 @@
import pytest
import torch

from gluonts.torch.model.deepar import DeepARModel
from gluonts.torch.model.mqf2 import MQF2MultiHorizonModel
from gluonts.torch.model.simple_feedforward import SimpleFeedForwardModel


def construct_batch(module, batch_size=1):
    return tuple(
        [
            torch.zeros(shape, dtype=module.input_types()[name])
            for (name, shape) in module.input_shapes(
                batch_size=batch_size
            ).items()
        ]
    )


def assert_shapes_and_dtypes(tensors, shapes, dtypes):
    if isinstance(tensors, torch.Tensor):
        assert tensors.shape == shapes
        assert tensors.dtype == dtypes
    else:
        for tensor, shape, dtype in zip(tensors, shapes, dtypes):
            assert_shapes_and_dtypes(tensor, shape, dtype)


@pytest.mark.parametrize(
    "module, batch_size, expected_shapes, expected_dtypes",
    [
        (
            DeepARModel(
                freq="1H",
                context_length=24,
                prediction_length=12,
                num_feat_dynamic_real=1,
                num_feat_static_real=1,
                num_feat_static_cat=1,
                cardinality=[1],
            ),
            4,
            (4, 100, 12),
            torch.float,
        ),
        (
            MQF2MultiHorizonModel(
                freq="1H",
                context_length=24,
                prediction_length=12,
                num_feat_dynamic_real=1,
                num_feat_static_real=1,
                num_feat_static_cat=1,
                cardinality=[1],
            ),
            4,
            (4, 100, 12),
            torch.float,
        ),
        (
            SimpleFeedForwardModel(
                context_length=24,
                prediction_length=12,
            ),
            4,
            [[(4, 12), (4, 12), (4, 12)], (4, 1), (4, 1)],
            [[torch.float, torch.float, torch.float], torch.float, torch.float],
        ),
    ],
)
def test_module_smoke(module, batch_size, expected_shapes, expected_dtypes):
    batch = construct_batch(module, batch_size=batch_size)
    outputs = module(*batch)
    assert_shapes_and_dtypes(outputs, expected_shapes, expected_dtypes)
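
construct_batch turns each module's input_shapes / input_types description into a tuple of all-zeros tensors, so the smoke test only verifies that a forward pass runs and yields the expected shapes and dtypes. A rough illustration of that mechanism for the feed-forward case, assuming the helpers above are in scope:

    import torch

    from gluonts.torch.model.simple_feedforward import SimpleFeedForwardModel

    module = SimpleFeedForwardModel(context_length=24, prediction_length=12)

    batch = construct_batch(module, batch_size=4)
    # A 1-tuple holding a single all-zeros "context" tensor of shape (4, 24).
    assert batch[0].shape == (4, 24) and batch[0].dtype == torch.float

    # Per the parametrization above, the forward pass returns three (4, 12)
    # distribution-parameter tensors plus (4, 1) loc and scale tensors.
    outputs = module(*batch)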
