From f309351250ebbb8b026ef767d7f92ec00d39ad53 Mon Sep 17 00:00:00 2001 From: Evgeny Saveliev Date: Wed, 15 Mar 2023 18:05:32 +0000 Subject: [PATCH] Fix length mismatch error when no static data --- src/synthcity/plugins/core/models/ts_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/synthcity/plugins/core/models/ts_model.py b/src/synthcity/plugins/core/models/ts_model.py index aa5ef333..a7026146 100644 --- a/src/synthcity/plugins/core/models/ts_model.py +++ b/src/synthcity/plugins/core/models/ts_model.py @@ -704,6 +704,7 @@ def __init__( self.device = device self.window_size = window_size + self.n_static_units_in = n_static_units_in self.model = MLP( task_type="regression", n_units_in=n_static_units_in + n_temporal_units_in * window_size, @@ -719,7 +720,7 @@ def __init__( def forward( self, static_data: torch.Tensor, temporal_data: torch.Tensor ) -> torch.Tensor: - if len(static_data) != len(temporal_data): + if self.n_static_units_in > 0 and len(static_data) != len(temporal_data): raise ValueError("Length mismatch between static and temporal data") batch_size, seq_len, n_feats = temporal_data.shape