Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
fix remote check in download_data
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Aug 9, 2023
1 parent 269e852 commit ae6370f
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions src/flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
}


def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
def download_data(url: str, path: str = "data/", verbose: bool = False, chunk_size: int = 1024) -> None:
"""Download file with progressbar.
# Code adapted from: https://gist.github.com/ruxi/5d6803c116ec1130d484a4ab8c00c603
Expand All @@ -78,39 +78,42 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
[...]
"""
local_filename = os.path.join(path, url.split("/")[-1])
if os.path.exists(local_filename):
if verbose:
print(f"local file already exists: '{local_filename}'")
return

os.makedirs(path, exist_ok=True)
# Disable warning about making an insecure request
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

if not os.path.exists(path):
os.makedirs(path)
local_filename = os.path.join(path, url.split("/")[-1])
r = requests.get(url, stream=True, verify=False)
file_size = int(r.headers["Content-Length"]) if "Content-Length" in r.headers else 0
chunk_size = 1024
num_bars = int(file_size / chunk_size)
if verbose:
print({"file_size": file_size})
print({"num_bars": num_bars})

if not os.path.exists(local_filename):
with open(local_filename, "wb") as fp:
for chunk in tq(
r.iter_content(chunk_size=chunk_size),
total=num_bars,
unit="KB",
desc=local_filename,
leave=True, # progressbar stays
):
fp.write(chunk) # type: ignore

def extract_tarfile(file_path: str, extract_path: str, mode: str):
if os.path.exists(file_path):
with tarfile.open(file_path, mode=mode) as tar_ref:
for member in tar_ref.getmembers():
try:
tar_ref.extract(member, path=extract_path, set_attrs=False)
except PermissionError:
raise PermissionError(f"Could not extract tar file {file_path}")
print(f"file size: {file_size}")
print(f"num bars: {num_bars}")

with open(local_filename, "wb") as fp:
for chunk in tq(
r.iter_content(chunk_size=chunk_size),
total=num_bars,
unit="KB",
desc=local_filename,
leave=True, # progressbar stays
):
fp.write(chunk) # type: ignore

def extract_tarfile(file_path: str, extract_path: str, mode: str) -> None:
if not os.path.exists(file_path):
return
with tarfile.open(file_path, mode=mode) as tar_ref:
for member in tar_ref.getmembers():
try:
tar_ref.extract(member, path=extract_path, set_attrs=False)
except PermissionError:
raise PermissionError(f"Could not extract tar file {file_path}")

if ".zip" in local_filename:
if os.path.exists(local_filename):
Expand Down

0 comments on commit ae6370f

Please sign in to comment.