Skip to content

Commit f605ccc

Browse files
committed
FIXED : implementing multi-gpu + remove duplicates
1 parent e414e5e commit f605ccc

File tree

2 files changed

+33
-38
lines changed

2 files changed

+33
-38
lines changed

exp/exp_basic.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,27 +44,13 @@ def __init__(self, args):
4444
from models import Mamba
4545
self.model_dict['Mamba'] = Mamba
4646

47-
self.device = self._acquire_device()
47+
self.device = self.args.device
4848
self.model = self._build_model().to(self.device)
4949

5050
def _build_model(self):
5151
raise NotImplementedError
5252
return None
5353

54-
def _acquire_device(self):
55-
if self.args.use_gpu and self.args.gpu_type == 'cuda':
56-
os.environ["CUDA_VISIBLE_DEVICES"] = str(
57-
self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
58-
device = torch.device('cuda:{}'.format(self.args.gpu))
59-
print('Use GPU: cuda:{}'.format(self.args.gpu))
60-
elif self.args.use_gpu and self.args.gpu_type == 'mps':
61-
device = torch.device('mps')
62-
print('Use GPU: mps')
63-
else:
64-
device = torch.device('cpu')
65-
print('Use CPU')
66-
return device
67-
6854
def _get_data(self):
6955
pass
7056

run.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,10 @@
11
import argparse
22
import os
3-
import torch
4-
import torch.backends
5-
from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast
6-
from exp.exp_imputation import Exp_Imputation
7-
from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast
8-
from exp.exp_anomaly_detection import Exp_Anomaly_Detection
9-
from exp.exp_classification import Exp_Classification
103
from utils.print_args import print_args
114
import random
125
import numpy as np
136

147
if __name__ == '__main__':
15-
fix_seed = 2021
16-
random.seed(fix_seed)
17-
torch.manual_seed(fix_seed)
18-
np.random.seed(fix_seed)
198

209
parser = argparse.ArgumentParser(description='TimesNet')
2110

@@ -141,22 +130,42 @@
141130
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
142131

143132
args = parser.parse_args()
144-
if torch.cuda.is_available() and args.use_gpu:
145-
args.device = torch.device('cuda:{}'.format(args.gpu))
133+
# declare CUDA_VISIBLE_DEVICES before using torch.cuda
134+
if args.use_gpu and args.gpu_type == 'cuda':
135+
os.environ["CUDA_VISIBLE_DEVICES"] = str(
136+
args.gpu) if not args.use_multi_gpu else args.devices
137+
138+
import torch
139+
import torch.backends
140+
from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast
141+
from exp.exp_imputation import Exp_Imputation
142+
from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast
143+
from exp.exp_anomaly_detection import Exp_Anomaly_Detection
144+
from exp.exp_classification import Exp_Classification
145+
146+
fix_seed = 2021
147+
random.seed(fix_seed)
148+
torch.manual_seed(fix_seed)
149+
np.random.seed(fix_seed)
150+
151+
if torch.cuda.is_available() and args.use_gpu and args.gpu_type == 'cuda':
152+
if args.use_multi_gpu: # multi-gpu
153+
args.devices = args.devices.replace(' ', '')
154+
device_ids = args.devices.split(',')
155+
args.device_indices = [int(id_) for id_ in device_ids] # e.g. '1,2' -> [1, 2]
156+
args.device_ids = list(range(len(args.device_indices))) # e.g. [1, 2] -> [0, 1] because of visible devices
157+
args.gpu = args.device_indices[0]
158+
args.device = torch.device(f'cuda:0')
159+
else: # one gpu
160+
args.device = torch.device('cuda')
146161
print('Using GPU')
162+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() \
163+
and args.use_gpu and args.gpu_type == 'mps':
164+
args.device = torch.device("mps")
147165
else:
148-
if hasattr(torch.backends, "mps"):
149-
args.device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
150-
else:
151-
args.device = torch.device("cpu")
166+
args.device = torch.device("cpu")
152167
print('Using cpu or mps')
153168

154-
if args.use_gpu and args.use_multi_gpu:
155-
args.devices = args.devices.replace(' ', '')
156-
device_ids = args.devices.split(',')
157-
args.device_ids = [int(id_) for id_ in device_ids]
158-
args.gpu = args.device_ids[0]
159-
160169
print('Args in experiment:')
161170
print_args(args)
162171

0 commit comments

Comments
 (0)