Skip to content

Commit e001420

Browse files
authored
Update train.py
Add the random seed
1 parent cb62c3a commit e001420

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

train.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import os
2424
import shutil
2525
import time
26+
import random
2627

2728
import dgl
2829
import numpy as np
@@ -274,6 +275,7 @@ def main():
274275
parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.')
275276

276277
# Hyperparameters
278+
parser.add_argument('--seed', type=int, default=666, help='set the random seed [default: 666]')
277279
parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. [default: 0]')
278280
parser.add_argument('--cuda', action='store_true', default=False, help='GPU or CPU [default: False]')
279281
parser.add_argument('--vocab_size', type=int, default=50000,help='Size of vocabulary. [default: 50000]')
@@ -309,7 +311,12 @@ def main():
309311
parser.add_argument('-m', type=int, default=3, help='decode summary length')
310312

311313
args = parser.parse_args()
312-
314+
315+
# set the seed
316+
random.seed(args.seed)
317+
np.random.seed(args.seed)
318+
torch.manual_seed(args.seed)
319+
313320
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
314321
torch.set_printoptions(threshold=50000)
315322

0 commit comments

Comments
 (0)