1 | 1 | import argparse
2 | 2 | import os
3 | | -import torch
4 | | -import torch.backends
5 | | -from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast
6 | | -from exp.exp_imputation import Exp_Imputation
7 | | -from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast
8 | | -from exp.exp_anomaly_detection import Exp_Anomaly_Detection
9 | | -from exp.exp_classification import Exp_Classification
10 | 3 | from utils.print_args import print_args
11 | 4 | import random
12 | 5 | import numpy as np
13 | 6 |
14 | 7 | if __name__ == '__main__':
15 | | -    fix_seed = 2021
16 | | -    random.seed(fix_seed)
17 | | -    torch.manual_seed(fix_seed)
18 | | -    np.random.seed(fix_seed)
19 | 8 |
20 | 9 |     parser = argparse.ArgumentParser(description='TimesNet')
21 | 10 |
…
141 | 130 | parser.add_argument('--patch_len', type=int, default=16, help='patch length')
142 | 131 |
143 | 132 |     args = parser.parse_args()
144 | | -    if torch.cuda.is_available() and args.use_gpu:
145 | | -        args.device = torch.device('cuda:{}'.format(args.gpu))
| 133 | + # declare CUDA_VISIBLE_DEVICES before using torch.cuda |
| 134 | + if args.use_gpu and args.gpu_type == 'cuda': |
| 135 | + os.environ["CUDA_VISIBLE_DEVICES"] = str( |
| 136 | + args.gpu) if not args.use_multi_gpu else args.devices |
| 137 | + |
| 138 | + import torch |
| 139 | + import torch.backends |
| 140 | + from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast |
| 141 | + from exp.exp_imputation import Exp_Imputation |
| 142 | + from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast |
| 143 | + from exp.exp_anomaly_detection import Exp_Anomaly_Detection |
| 144 | + from exp.exp_classification import Exp_Classification |
| 145 | + |
| 146 | + fix_seed = 2021 |
| 147 | + random.seed(fix_seed) |
| 148 | + torch.manual_seed(fix_seed) |
| 149 | + np.random.seed(fix_seed) |
| 150 | + |
| 151 | + if torch.cuda.is_available() and args.use_gpu and args.gpu_type == 'cuda': |
| 152 | + if args.use_multi_gpu: # multi-gpu |
| 153 | + args.devices = args.devices.replace(' ', '') |
| 154 | + device_ids = args.devices.split(',') |
| 155 | + args.device_indices = [int(id_) for id_ in device_ids] # e.g. '1,2' -> [1, 2] |
| 156 | + args.device_ids = list(range(len(args.device_indices))) # e.g. [1, 2] -> [0, 1] because of visible devices |
| 157 | + args.gpu = args.device_indices[0] |
| 158 | +            args.device = torch.device('cuda:0')
| 159 | + else: # one gpu |
| 160 | + args.device = torch.device('cuda') |
146 | 161 | print('Using GPU')
| 162 | + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() \ |
| 163 | + and args.use_gpu and args.gpu_type == 'mps': |
| 164 | + args.device = torch.device("mps") |
147 | 165 | else:
148 | | -        if hasattr(torch.backends, "mps"):
149 | | -            args.device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
150 | | -        else:
151 | | -            args.device = torch.device("cpu")
| 166 | + args.device = torch.device("cpu") |
152 | 167 | print('Using cpu or mps')
153 | 168 |
154 | | -    if args.use_gpu and args.use_multi_gpu:
155 | | -        args.devices = args.devices.replace(' ', '')
156 | | -        device_ids = args.devices.split(',')
157 | | -        args.device_ids = [int(id_) for id_ in device_ids]
158 | | -        args.gpu = args.device_ids[0]
159 | | -
160 | 169 | print('Args in experiment:')
161 | 170 | print_args(args)
162 | 171 |
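
The ordering in this commit is the point: CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA runtime is initialized, which is why the torch and exp.* imports now come after the environment variable is written. Once it is set, the visible GPUs are renumbered from zero, so --devices '1,2' yields device_indices = [1, 2] but device_ids = [0, 1] and a primary device of cuda:0. A minimal standalone sketch of that renumbering (hypothetical script, not part of the commit):

import os

# Hypothetical illustration: hide all physical GPUs except 1 and 2.
# This must run before the first `import torch` in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

import torch

if torch.cuda.is_available():
    print(torch.cuda.device_count())  # -> 2
    # The surviving GPUs are renumbered: cuda:0 is physical GPU 1,
    # cuda:1 is physical GPU 2.
    x = torch.ones(1, device="cuda:0")
    print(x.device)  # -> cuda:0

Assuming the argparse flags mirror the args attributes seen in the diff, a two-GPU run would then be launched as python run.py --use_multi_gpu --devices 1,2 ..., after which the experiment code addresses the cards as cuda:0 and cuda:1.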