Skip to content

Commit

Permalink
fix: the error that the dataset cannot be lazily loaded for CSAI;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Oct 27, 2024
1 parent 87a4168 commit b6a3280
Showing 1 changed file with 35 additions and 18 deletions.
53 changes: 35 additions & 18 deletions pypots/imputation/csai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# License: BSD-3-Clause

from typing import Union, Optional
from venv import logger

import numpy as np
import torch
Expand All @@ -14,6 +15,8 @@
from .core import _BCSAI
from .data import DatasetForCSAI
from ..base import BaseNNImputer
from ...data.checking import key_in_data_set
from ...data.saving.h5 import load_dict_from_h5
from ...optim.adam import Adam
from ...optim.base import Optimizer

Expand Down Expand Up @@ -164,6 +167,7 @@ def __init__(

# set up the optimizer
self.optimizer = optimizer
self.optimizer.init_optimizer(self.model.parameters())

def _assemble_input_for_training(self, data: list, training=True) -> dict:
# extract data
Expand Down Expand Up @@ -239,8 +243,21 @@ def fit(
file_type: str = "hdf5",
) -> None:

if isinstance(train_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole train set will be loaded into memory."
)
train_set = load_dict_from_h5(train_set)

training_set = DatasetForCSAI(
train_set, False, False, file_type, self.removal_percent, self.increase_factor, self.compute_intervals
train_set,
False,
False,
file_type,
self.removal_percent,
self.increase_factor,
self.compute_intervals,
)
self.intervals = training_set.intervals
self.replacement_probabilities = training_set.replacement_probabilities
Expand All @@ -254,7 +271,17 @@ def fit(
num_workers=self.num_workers,
# collate_fn=collate_fn_bidirectional
)
val_loader = None
if val_set is not None:
if isinstance(val_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole val set will be loaded into memory."
)
val_set = load_dict_from_h5(val_set)

if not key_in_data_set("X_ori", val_set):
raise ValueError("val_set must contain 'X_ori' for model validation.")
val_set = DatasetForCSAI(
val_set,
True,
Expand All @@ -276,23 +303,6 @@ def fit(
# collate_fn=collate_fn_bidirectional
)

# Reset the model
self.model = _BCSAI(
self.n_steps,
self.n_features,
self.rnn_hidden_size,
self.step_channels,
self.consistency_weight,
self.imputation_weight,
self.intervals,
)

self._send_model_to_given_device()
self._print_model_size()

# set up the optimizer
self.optimizer.init_optimizer(self.model.parameters())

# train the model
self._train_model(training_loader, val_loader)
self.model.load_state_dict(self.best_model_dict)
Expand All @@ -308,6 +318,13 @@ def predict(
) -> dict:

self.model.eval()

if isinstance(test_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole test set will be loaded into memory."
)
test_set = load_dict_from_h5(test_set)
test_set = DatasetForCSAI(
test_set,
True,
Expand Down

0 comments on commit b6a3280

Please sign in to comment.