refactor: raise not implemented error for CSAI _fetch_data_from_file;
WenjieDu committed Oct 26, 2024
1 parent 0979d44 commit 87a4168
Showing 1 changed file with 14 additions and 71 deletions.
85 changes: 14 additions & 71 deletions pypots/imputation/csai/data.py
@@ -369,14 +369,19 @@ def __init__(
         increase_factor: float = 0.1,
         compute_intervals: bool = False,
         replacement_probabilities=None,
-        normalise_mean: list = [],
-        normalise_std: list = [],
+        normalise_mean=None,
+        normalise_std=None,
         training: bool = True,
     ):
         super().__init__(
             data=data, return_X_ori=return_X_ori, return_X_pred=False, return_y=return_y, file_type=file_type
         )
 
+        if normalise_std is None:
+            normalise_std = []
+        if normalise_mean is None:
+            normalise_mean = []
+
         self.removal_percent = removal_percent
         self.increase_factor = increase_factor
         self.compute_intervals = compute_intervals
@@ -385,6 +390,11 @@ def __init__(
         self.normalise_std = normalise_std
         self.training = training
 
+        self.normalized_data = None
+        self.mean_set = None
+        self.std_set = None
+        self.intervals = None
+
         if not isinstance(self.data, str):
             self.normalized_data, self.mean_set, self.std_set, self.intervals = normalize_csai(
                 self.data["X"],
@@ -465,73 +475,6 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
         }
 
     def _fetch_data_from_file(self, idx: int) -> Iterable:
-        """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples.
-
-        Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice.
-
-        Parameters
-        ----------
-        idx :
-            The index of the sample to be return.
-
-        Returns
-        -------
-        sample :
-            The collated data sample, a list including all necessary sample info.
-        """
-
-        if self.file_handle is None:
-            self.file_handle = self._open_file_handle()
-
-        X = torch.from_numpy(self.file_handle["X"][idx])
-        normalized_data, mean_set, std_set, intervals = normalize_csai(
-            X,
-            self.normalise_mean,
-            self.normalise_std,
-            self.compute_intervals,
-        )
-
-        processed_data, replacement_probabilities = non_uniform_sample(
-            normalized_data,
-            self.removal_percent,
-            self.replacement_probabilities,
-            self.increase_factor,
-        )
-        forward_X = processed_data["values"]
-        forward_missing_mask = processed_data["masks"]
-        backward_X = torch.flip(forward_X, dims=[1])
-        backward_missing_mask = torch.flip(forward_missing_mask, dims=[1])
-
-        X_ori = self.processed_data["evals"]
-        indicating_mask = self.processed_data["eval_masks"]
-
-        if self.return_y:
-            y = self.processed_data["labels"]
-
-        sample = [
-            torch.tensor(idx),
-            # for forward
-            forward_X,
-            forward_missing_mask,
-            processed_data["deltas_f"],
-            processed_data["last_obs_f"],
-            # for backward
-            backward_X,
-            backward_missing_mask,
-            processed_data["deltas_b"],
-            processed_data["last_obs_b"],
-        ]
-
-        if self.return_X_ori:
-            sample.extend([X_ori, indicating_mask])
-
-        # if the dataset has labels and is for training, then fetch it from the file
-        if self.return_y:
-            sample.append(y)
-
-        return {
-            "sample": sample,
-            "replacement_probabilities": replacement_probabilities,
-            "mean_set": mean_set,
-            "std_set": std_set,
-            "intervals": intervals,
-        }
+        raise NotImplementedError(
+            "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead."
+        )
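
Note on the default-argument change in the first hunk: replacing normalise_mean: list = [] and normalise_std: list = [] with None defaults plus in-body guards avoids Python's shared-mutable-default pitfall, where the same list object created at function definition time is reused across every call. Below is a minimal, standalone sketch of that pattern; the Broken and Fixed class names are illustrative only and not part of the PyPOTS API.

# Minimal illustration of the mutable-default pitfall addressed above.
# Broken/Fixed are hypothetical names used only for this example.

class Broken:
    def __init__(self, values: list = []):  # one list object shared by every call
        self.values = values


class Fixed:
    def __init__(self, values=None):  # None sentinel, fresh list per instance
        if values is None:
            values = []
        self.values = values


a, b = Broken(), Broken()
a.values.append(1)
print(b.values)  # [1] -- the default list leaked between instances

c, d = Fixed(), Fixed()
c.values.append(1)
print(d.values)  # [] -- each instance gets its own list

The commit applies this same None-sentinel pattern to normalise_mean and normalise_std.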
