-
Notifications
You must be signed in to change notification settings - Fork 7
/
main.py
36 lines (27 loc) · 1.27 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import warnings
warnings.filterwarnings('ignore')
from args import args, Test_data, Train_data_all, Train_data
from dataset import Dataset
from model.TimeMAE import TimeMAE
from process import Trainer
import torch.utils.data as Data
def main(seed: int = 3407, num_threads: int = 12) -> None:
    """Build the datasets, loaders and model, then pre-train and fine-tune.

    Three ``Dataset`` objects are created (modes ``'pretrain'``,
    ``'supervise_train'`` and ``'test'``) from ``Train_data_all``,
    ``Train_data`` and ``Test_data`` respectively, each wrapped in a
    ``DataLoader``. The pre-training dataset's ``shape()`` is stored on
    ``args.data_shape`` before the model is constructed.

    Args:
        seed: Seed applied to PyTorch's random number generators.
        num_threads: Number of intra-op CPU threads PyTorch may use.
    """
    torch.set_num_threads(num_threads)
    # torch.manual_seed seeds the CPU generator *and* every CUDA device.
    # Seeding only CUDA (torch.cuda.manual_seed) leaves DataLoader shuffling
    # and any CPU-side random initialisation non-reproducible.
    torch.manual_seed(seed)

    # Pre-training split: shuffled every epoch.
    train_dataset = Dataset(
        device=args.device,
        mode='pretrain',
        data=Train_data_all,
        wave_len=args.wave_length,
    )
    train_loader = Data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
    )
    # Set before TimeMAE(args) is built; presumably the model reads it
    # from args -- confirm against model/TimeMAE.py.
    args.data_shape = train_dataset.shape()

    # Supervised (fine-tuning) split: shuffled, same batch size as pre-training.
    train_linear_dataset = Dataset(
        device=args.device,
        mode='supervise_train',
        data=Train_data,
        wave_len=args.wave_length,
    )
    train_linear_loader = Data.DataLoader(
        train_linear_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
    )

    # Test split: iterated in dataset order (no shuffling).
    test_dataset = Dataset(
        device=args.device,
        mode='test',
        data=Test_data,
        wave_len=args.wave_length,
    )
    test_loader = Data.DataLoader(test_dataset, batch_size=args.test_batch_size)

    print(args.data_shape)
    print('dataset initial ends')

    model = TimeMAE(args)
    print('model initial ends')

    trainer = Trainer(
        args,
        model,
        train_loader,
        train_linear_loader,
        test_loader,
        verbose=True,
    )
    # Pre-training runs first, then fine-tuning on the same trainer instance.
    trainer.pretrain()
    trainer.finetune()
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()