Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

fix remote check in download_data #1666

Merged
merged 1 commit into from
Aug 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 30 additions & 27 deletions src/flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
}


def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
def download_data(url: str, path: str = "data/", verbose: bool = False, chunk_size: int = 1024) -> None:
"""Download file with progressbar.

# Code adapted from: https://gist.github.com/ruxi/5d6803c116ec1130d484a4ab8c00c603
Expand All @@ -78,39 +78,42 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
[...]

"""
local_filename = os.path.join(path, url.split("/")[-1])
if os.path.exists(local_filename):
if verbose:
print(f"local file already exists: '{local_filename}'")
return

os.makedirs(path, exist_ok=True)
# Disable warning about making an insecure request
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

if not os.path.exists(path):
os.makedirs(path)
local_filename = os.path.join(path, url.split("/")[-1])
r = requests.get(url, stream=True, verify=False)
file_size = int(r.headers["Content-Length"]) if "Content-Length" in r.headers else 0
chunk_size = 1024
num_bars = int(file_size / chunk_size)
if verbose:
print({"file_size": file_size})
print({"num_bars": num_bars})

if not os.path.exists(local_filename):
with open(local_filename, "wb") as fp:
for chunk in tq(
r.iter_content(chunk_size=chunk_size),
total=num_bars,
unit="KB",
desc=local_filename,
leave=True, # progressbar stays
):
fp.write(chunk) # type: ignore

def extract_tarfile(file_path: str, extract_path: str, mode: str):
if os.path.exists(file_path):
with tarfile.open(file_path, mode=mode) as tar_ref:
for member in tar_ref.getmembers():
try:
tar_ref.extract(member, path=extract_path, set_attrs=False)
except PermissionError:
raise PermissionError(f"Could not extract tar file {file_path}")
print(f"file size: {file_size}")
print(f"num bars: {num_bars}")

with open(local_filename, "wb") as fp:
for chunk in tq(
r.iter_content(chunk_size=chunk_size),
total=num_bars,
unit="KB",
desc=local_filename,
leave=True, # progressbar stays
):
fp.write(chunk) # type: ignore

def extract_tarfile(file_path: str, extract_path: str, mode: str) -> None:
if not os.path.exists(file_path):
return
with tarfile.open(file_path, mode=mode) as tar_ref:
for member in tar_ref.getmembers():
try:
tar_ref.extract(member, path=extract_path, set_attrs=False)
except PermissionError:
raise PermissionError(f"Could not extract tar file {file_path}")

if ".zip" in local_filename:
if os.path.exists(local_filename):
Expand Down