Skip to content

Commit

Permalink
Add smoke tests for torch models (#2495)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Dec 14, 2022
1 parent 3bf1617 commit 7327eab
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 10 deletions.
13 changes: 10 additions & 3 deletions src/gluonts/torch/model/mqf2/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from gluonts.core.component import validated
from gluonts.torch.model.deepar.module import DeepARModel
from gluonts.torch.distributions import DistributionOutput
from gluonts.torch.distributions import (
DistributionOutput,
MQF2DistributionOutput,
)

from cpflows.flows import ActNorm
from cpflows.icnn import PICNN
Expand All @@ -35,7 +38,7 @@ def __init__(
num_feat_static_real: int,
num_feat_static_cat: int,
cardinality: List[int],
distr_output: DistributionOutput,
distr_output: Optional[DistributionOutput] = None,
embedding_dimension: Optional[List[int]] = None,
num_layers: int = 2,
hidden_size: int = 40,
Expand Down Expand Up @@ -74,7 +77,11 @@ def __init__(
num_layers=num_layers,
hidden_size=hidden_size,
dropout_rate=dropout_rate,
distr_output=distr_output,
distr_output=(
distr_output
if distr_output is not None
else MQF2DistributionOutput(prediction_length)
),
lags_seq=lags_seq,
scaling=scaling,
num_parallel_samples=num_parallel_samples,
Expand Down
28 changes: 21 additions & 7 deletions src/gluonts/torch/model/simple_feedforward/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Tuple
from typing import Dict, List, Tuple, Optional

import torch
from torch import nn
Expand Down Expand Up @@ -55,23 +55,25 @@ def __init__(
self,
prediction_length: int,
context_length: int,
hidden_dimensions: List[int],
hidden_dimensions: Optional[List[int]] = None,
distr_output=StudentTOutput(),
batch_norm: bool = False,
) -> None:
super().__init__()

assert prediction_length > 0
assert context_length > 0
assert len(hidden_dimensions) > 0
assert hidden_dimensions is None or len(hidden_dimensions) > 0

self.prediction_length = prediction_length
self.context_length = context_length
self.hidden_dimensions = hidden_dimensions
self.hidden_dimensions = (
hidden_dimensions if hidden_dimensions is not None else [20, 20]
)
self.distr_output = distr_output
self.batch_norm = batch_norm

dimensions = [context_length] + hidden_dimensions[:-1]
dimensions = [context_length] + self.hidden_dimensions[:-1]

modules = []
for in_size, out_size in zip(dimensions[:-1], dimensions[1:]):
Expand All @@ -80,12 +82,24 @@ def __init__(
modules.append(nn.BatchNorm1d(out_size))
modules.append(
make_linear_layer(
dimensions[-1], prediction_length * hidden_dimensions[-1]
dimensions[-1], prediction_length * self.hidden_dimensions[-1]
)
)

self.nn = nn.Sequential(*modules)
self.args_proj = self.distr_output.get_args_proj(hidden_dimensions[-1])
self.args_proj = self.distr_output.get_args_proj(
self.hidden_dimensions[-1]
)

def input_shapes(self, batch_size=1) -> Dict[str, Tuple[int, ...]]:
return {
"context": (batch_size, self.context_length),
}

def input_types(self) -> Dict[str, torch.dtype]:
return {
"context": torch.float,
}

def forward(
self,
Expand Down
91 changes: 91 additions & 0 deletions test/torch/model/test_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytest
import torch

from gluonts.torch.model.deepar import DeepARModel
from gluonts.torch.model.mqf2 import MQF2MultiHorizonModel
from gluonts.torch.model.simple_feedforward import SimpleFeedForwardModel


def construct_batch(module, batch_size=1):
return tuple(
[
torch.zeros(shape, dtype=module.input_types()[name])
for (name, shape) in module.input_shapes(
batch_size=batch_size
).items()
]
)


def assert_shapes_and_dtypes(tensors, shapes, dtypes):
if isinstance(tensors, torch.Tensor):
assert tensors.shape == shapes
assert tensors.dtype == dtypes
else:
for tensor, shape, dtype in zip(tensors, shapes, dtypes):
assert_shapes_and_dtypes(tensor, shape, dtype)


@pytest.mark.parametrize(
"module, batch_size, expected_shapes, expected_dtypes",
[
(
DeepARModel(
freq="1H",
context_length=24,
prediction_length=12,
num_feat_dynamic_real=1,
num_feat_static_real=1,
num_feat_static_cat=1,
cardinality=[1],
),
4,
(4, 100, 12),
torch.float,
),
(
MQF2MultiHorizonModel(
freq="1H",
context_length=24,
prediction_length=12,
num_feat_dynamic_real=1,
num_feat_static_real=1,
num_feat_static_cat=1,
cardinality=[1],
),
4,
(4, 100, 12),
torch.float,
),
(
SimpleFeedForwardModel(
context_length=24,
prediction_length=12,
),
4,
[[(4, 12), (4, 12), (4, 12)], (4, 1), (4, 1)],
[
[torch.float, torch.float, torch.float],
torch.float,
torch.float,
],
),
],
)
def test_module_smoke(module, batch_size, expected_shapes, expected_dtypes):
batch = construct_batch(module, batch_size=batch_size)
outputs = module(*batch)
assert_shapes_and_dtypes(outputs, expected_shapes, expected_dtypes)

0 comments on commit 7327eab

Please sign in to comment.