-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloader.py
executable file
·147 lines (123 loc) · 5 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import requests
import hashlib
import zipfile
import tarfile
import datetime
from tqdm import tqdm
import pandas as pd
import argparse
class DownloadableFile:
def __init__(self, url, filename, expected_hash, version="1.0", zipped=True):
self.url = url
self.filename = filename
self.expected_hash = expected_hash
self.zipped = zipped
self.version = version
FANTOM = DownloadableFile(
url='https://storage.googleapis.com/ai2-mosaic-public/projects/fantom/fantom.tar.gz',
filename='fantom.tar.gz',
expected_hash='1d08dfa0ea474c7f83b9bc7e3a7b466eab25194043489dd618b4c5223e1253a4',
version="1.0",
zipped=True
)
# =============================================================================================================
def unzip_file(file_path, directory='.'):
if file_path.endswith(".zip"):
target_location = os.path.join(directory, os.path.splitext(os.path.basename(file_path))[0])
with zipfile.ZipFile(file_path, 'r') as zip_ref:
zip_ref.extractall(target_location)
elif file_path.endswith(".tar.gz"):
target_location = os.path.join(directory, os.path.basename(file_path).split(".")[0])
with tarfile.open(file_path) as tar:
tar.extractall(target_location)
return target_location
def check_built(path, version_string=None):
"""
Check if '.built' flag has been set for that task.
If a version_string is provided, this has to match, or the version is regarded as not built.
"""
fname = os.path.join(path, '.built')
if not os.path.isfile(fname):
return False
else:
with open(fname, 'r') as read:
text = read.read().split('\n')
return len(text) > 1 and text[1] == version_string
def mark_built(path, version_string="1.0"):
"""
Mark this path as prebuilt.
Marks the path as done by adding a '.built' file with the current timestamp plus a version description string.
"""
with open(os.path.join(path, '.built'), 'w') as write:
write.write(str(datetime.datetime.today()))
if version_string:
write.write('\n' + version_string)
def download_and_check_hash(url, filename, expected_hash, version, directory='data', chunk_size=1024*1024*10):
# Download the file
response = requests.get(url, stream=True)
try:
total_size = int(response.headers.get('content-length', 0))
except:
print("Couldn't get content-length from response headers, using chunk_size instead")
total_size = chunk_size
progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
data = b''
for chunk in response.iter_content(chunk_size=chunk_size):
progress_bar.update(len(chunk))
data += chunk
progress_bar.close()
# Create the directory if it doesn't exist
if not os.path.exists(directory):
os.makedirs(directory)
# Save the file to disk
file_path = os.path.join(directory, filename)
with open(file_path, 'wb') as f:
f.write(data)
# Calculate the hash of the downloaded data
sha256_hash = hashlib.sha256(data).hexdigest()
# Compare the calculated hash to the expected hash
if sha256_hash != expected_hash:
print('@@@ Downloaded file hash does not match expected hash!')
raise RuntimeError
return file_path
def build_data(resource, directory='data'):
# check whether the file already exists
if resource.filename.endswith('.tar.gz'):
resource_dir = os.path.splitext(os.path.splitext(os.path.basename(resource.filename))[0])[0]
else:
resource_dir = os.path.splitext(os.path.basename(resource.filename))[0]
file_path = os.path.join(directory, resource_dir)
built = check_built(file_path, resource.version)
if not built:
# Download the file
file_path = download_and_check_hash(resource.url, resource.filename, resource.expected_hash, resource.version, directory)
# Unzip the file
if resource.zipped:
built_location = unzip_file(file_path, directory)
# Delete the zip file
os.remove(file_path)
else:
built_location = file_path
mark_built(built_location, resource.version)
print("Successfully built dataset at {}".format(built_location))
else:
print("Already built at {}. version {}".format(file_path, resource.version))
built_location = file_path
return built_location
def load():
dpath = build_data(FANTOM)
file = os.path.join(dpath, "fantom_v1.json")
df = pd.read_json(file)
return df
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='arguments for dataset loaders')
parser.add_argument('--dataset',
type=str,
help="Specify the dataset name.")
parser.add_argument('--split',
type=str,
default="train",
help="Specify the split name. train, valid, test, etc.")
args = parser.parse_args()
load()