Skip to content

Commit

Permalink
Fix length mismatch error when no static data
Browse files Browse the repository at this point in the history
  • Loading branch information
DrShushen committed Mar 15, 2023
1 parent 1e3d6b4 commit f309351
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/synthcity/plugins/core/models/ts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,7 @@ def __init__(

self.device = device
self.window_size = window_size
self.n_static_units_in = n_static_units_in
self.model = MLP(
task_type="regression",
n_units_in=n_static_units_in + n_temporal_units_in * window_size,
Expand All @@ -719,7 +720,7 @@ def __init__(
def forward(
self, static_data: torch.Tensor, temporal_data: torch.Tensor
) -> torch.Tensor:
if len(static_data) != len(temporal_data):
if self.n_static_units_in > 0 and len(static_data) != len(temporal_data):
raise ValueError("Length mismatch between static and temporal data")

batch_size, seq_len, n_feats = temporal_data.shape
Expand Down

0 comments on commit f309351

Please sign in to comment.