Skip to content

Commit

Permalink
fix: slice while fully loaded into memory
Browse files Browse the repository at this point in the history
  • Loading branch information
MistEO committed Apr 4, 2023
1 parent da42bad commit 02de07a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion configs_template/config_template.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
"use_sr": true,
"max_speclen": 512,
"port": "8001",
"keep_ckpts": 3
"keep_ckpts": 3,
"all_in_mem": false
},
"data": {
"training_files": "filelists/train.txt",
Expand Down
8 changes: 6 additions & 2 deletions data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def get_audio(self, filename):
assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
audio_norm = audio_norm[:, :lmin * self.hop_length]

return c, f0, spec, audio_norm, spk, uv

def random_slice(self, c, f0, spec, audio_norm, spk, uv):
# if spec.shape[1] < 30:
# print("skip too short audio:", filename)
# return None
Expand All @@ -92,9 +96,9 @@ def get_audio(self, filename):

def __getitem__(self, index):
if self.all_in_mem:
return self.cache[index]
return self.random_slice(*self.cache[index])
else:
return self.get_audio(self.audiopaths[index][0])
return self.random_slice(*self.get_audio(self.audiopaths[index][0]))

def __len__(self):
return len(self.audiopaths)
Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def run(rank, n_gpus, hps):
torch.manual_seed(hps.train.seed)
torch.cuda.set_device(rank)
collate_fn = TextAudioCollate()
all_in_mem = False # If you have enough memory, turn on this option to avoid disk IO and speed up training.
all_in_mem = hps.train.all_in_mem # If you have enough memory, turn on this option to avoid disk IO and speed up training.
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps, all_in_mem=all_in_mem)
num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count()
if all_in_mem:
num_workers = 0
train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True,
batch_size=hps.train.batch_size, collate_fn=collate_fn)
if rank == 0:
Expand Down

0 comments on commit 02de07a

Please sign in to comment.