-
Notifications
You must be signed in to change notification settings - Fork 12
/
utils.py
34 lines (31 loc) · 1.25 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
import requests
import json
import time
import logging
import hmac
import urllib
import base64
import hashlib
def parse_args(argv=None):
    """Parse the command-line options for an experiment run.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse falls back to ``sys.argv[1:]``, so existing
            no-argument callers behave exactly as before.

    Returns:
        An ``argparse.Namespace`` holding the parsed options, plus a derived
        ``path`` attribute of the form ``'./metrics_<data>.log'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='elec', help='dataset names')
    parser.add_argument('--name', type=str, default='ddpm_d', help='method for sde')
    parser.add_argument('--beta1', type=float, default=0.01, help='beta min')
    # Float default so ``args.beta2`` has the same type whether or not the flag
    # is given: argparse does not run ``type`` on non-string defaults.
    parser.add_argument('--beta2', type=float, default=10.0, help='beta max')
    parser.add_argument('--scale', type=int, default=100, help='num scales')
    parser.add_argument('--epochs', type=int, default=10, help='num of epochs')
    parser.add_argument('--batch', type=int, default=32, help='batch size')
    args = parser.parse_args(argv)
    # Derived per-dataset log-file path in the current working directory.
    args.path = './metrics_' + args.data + '.log'
    return args
def write_to_file(args, config, metrics, path):
    """Append one run record to the log file at ``path``.

    The record is written as:
      - ``str(args)`` wrapped between two runs of 20 ``=`` characters, then a newline;
      - ``str(config)``, then a newline;
      - one ``(key, value)`` tuple per line for each key in ``metrics``, followed
        by a final newline.

    Args:
        args: Object whose ``str()`` forms the record header.
        config: Object whose ``str()`` forms the second line.
        metrics: Mapping iterated for its keys and indexed to get each value.
        path: File path, opened in append mode with UTF-8 encoding.
    """
    with open(path, 'a+', encoding='utf8') as handle:
        divider = '=' * 20
        metric_rows = '\n'.join(str((key, metrics[key])) for key in metrics)
        pieces = [divider, str(args), divider, '\n', str(config), '\n', metric_rows, '\n']
        handle.write(''.join(pieces))