2
2
3
3
import argparse
4
4
import os
5
+ import pathlib
5
6
6
7
from tensorboardX import SummaryWriter
7
8
from torchvision import datasets , transforms
@@ -93,7 +94,7 @@ def main():
93
94
help = 'random seed (default: 1)' )
94
95
parser .add_argument ('--log-interval' , type = int , default = 10 , metavar = 'N' ,
95
96
help = 'how many batches to wait before logging training status' )
96
- parser .add_argument ('--save-model' , action = 'store_true' , default = False ,
97
+ parser .add_argument ('--save-model' , action = 'store_true' , default = True ,
97
98
help = 'For Saving the current Model' )
98
99
parser .add_argument ('--dir' , default = 'logs' , metavar = 'L' ,
99
100
help = 'directory where summary logs are stored' )
@@ -106,7 +107,8 @@ def main():
106
107
if use_cuda :
107
108
print ('Using CUDA' )
108
109
109
- writer = SummaryWriter (args .dir )
110
+ pathlib .Path ("./tensorboard/logs/" ).mkdir (parents = True , exist_ok = True )
111
+ writer = SummaryWriter ("./tensorboard/logs/" )
110
112
111
113
torch .manual_seed (args .seed )
112
114
@@ -147,8 +149,9 @@ def main():
147
149
train (args , model , device , train_loader , optimizer , epoch , writer )
148
150
test (args , model , device , test_loader , writer , epoch )
149
151
152
+ pathlib .Path ("./tensorboard/models/" ).mkdir (parents = True , exist_ok = True )
150
153
if (args .save_model ):
151
- torch .save (model .state_dict (),"mnist_cnn.pt" )
154
+ torch .save (model .state_dict (),"./tensorboard/models/ mnist_cnn.pt" )
152
155
153
156
if __name__ == '__main__' :
154
157
main ()
0 commit comments