Skip to content

Commit

Permalink
Doc update Quickstart Example (#497)
Browse files Browse the repository at this point in the history
* update example snippet
  • Loading branch information
eroell authored Aug 22, 2024
1 parent a6b515e commit ec508a6
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ You can also find a simple and quick-start tutorial notebook on Google Colab
# Data preprocessing. Tedious, but PyPOTS can help. 🤓
data = load_specific_dataset('physionet_2012') # PyPOTS will automatically download and extract it.
X = data['X']
num_samples = len(X['RecordID'].unique())
X = X.drop(['RecordID', 'Time'], axis = 1)
X = StandardScaler().fit_transform(X.to_numpy())
X = X.reshape(num_samples, 48, -1)
X = data['train_X']
num_samples = len(X)
X = StandardScaler().fit_transform(X.reshape(-1, X.shape[-1])).reshape(X.shape)
X_ori = X # keep X_ori for validation
X = mcar(X, 0.1) # randomly hold out 10% observed values as ground truth
dataset = {"X": X} # X for model input
print(X.shape) # (11988, 48, 37), 11988 samples, 48 time steps, 37 features
print(X.shape) # (7671, 48, 37), 7671 samples, 48 time steps, 37 features
# initialize the model
saits = SAITS(
Expand All @@ -55,7 +53,7 @@ You can also find a simple and quick-start tutorial notebook on Google Colab
model_saving_strategy="best", # only save the model with the best validation performance
)
# train the model. Here I use the whole dataset as the training set, because ground truth is not visible to the model.
# train the model. Here I use only the training set, and also evaluate on it, because the ground truth is not visible to the model.
saits.fit(dataset)
# impute the originally-missing values and artificially-missing values
imputation = saits.impute(dataset)
Expand All @@ -64,6 +62,6 @@ You can also find a simple and quick-start tutorial notebook on Google Colab
mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask) # calculate mean absolute error on the ground truth (artificially-missing values)
# the best model has already been saved, but you can still save it manually with save() (previously save_model()) as below
saits.save_model(saving_dir="examples/saits",file_name="manually_saved_saits_model")
saits.save(saving_path="examples/saits/manually_saved_saits_model")
# you can load the saved model into a newly initialized model
saits.load_model("examples/saits/manually_saved_saits_model")
saits.load("examples/saits/manually_saved_saits_model.pypots")

0 comments on commit ec508a6

Please sign in to comment.