Skip to content

Commit

Permalink
Made the tqdm progress_bar objects of static download methods a sta…
Browse files Browse the repository at this point in the history
…tic class variable (#3297)
  • Loading branch information
FlorianEagox authored Nov 24, 2023
1 parent b47d9c6 commit 64f391b
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
}



class ModelManager(object):
tqdm_progress = None
"""Manage TTS models defined in .models.json.
It provides an interface to list and download
models defines in '.model.json'
Expand Down Expand Up @@ -525,12 +527,12 @@ def _download_zip_file(file_url, output_folder, progress_bar):
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_zip_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)
with zipfile.ZipFile(temp_zip_name) as z:
z.extractall(output_folder)
Expand Down Expand Up @@ -560,12 +562,12 @@ def _download_tar_file(file_url, output_folder, progress_bar):
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_tar_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)
with tarfile.open(temp_tar_name) as t:
t.extractall(output_folder)
Expand Down Expand Up @@ -596,10 +598,10 @@ def _download_model_files(file_urls, output_folder, progress_bar):
block_size = 1024 # 1 Kibibyte
with open(temp_zip_name, "wb") as file:
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)

@staticmethod
Expand Down

0 comments on commit 64f391b

Please sign in to comment.