ddp.py
# Copyright (c) 2023 42dot. All rights reserved.
import os
import random

import numpy as np
import torch
import torch.distributed as dist


def setup_ddp(rank, world_size, manual_seed=True):
    """
    This function sets up the distributed data parallel (DDP) module for multi-GPU training.
    """
    # Rendezvous address/port for the default process group (single-node setup).
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank)

    if manual_seed:
        # Offset the seed by rank so each process draws a different random stream.
        random_seed = 42 + rank
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        random.seed(random_seed)

    # Bind this process to its own GPU.
    torch.cuda.set_device(rank)


def clear_ddp():
    """
    This function clears the DDP process group after training.
    """
    dist.destroy_process_group()
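
A minimal usage sketch, assuming one GPU per rank and that the two helpers above are driven with torch.multiprocessing.spawn; the _train_worker function, the placeholder model, and the world_size variable are illustrative and not part of the original file.

import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def _train_worker(rank, world_size):
    # Hypothetical per-process entry point: initialize DDP, train, then tear down.
    setup_ddp(rank, world_size)
    model = torch.nn.Linear(10, 10).to(rank)       # placeholder model on this rank's GPU
    ddp_model = DDP(model, device_ids=[rank])      # gradients are synchronized across ranks
    # ... run the training loop with ddp_model here ...
    clear_ddp()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(_train_worker, args=(world_size,), nprocs=world_size, join=True)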