-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
many dsp operator accumulate gradient update sequence length after trimming filtering in waveglow metrics for Bayesian model other loss functions for contrastive loss ...
- Loading branch information
Showing
32 changed files
with
2,310 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
import core_scripts.data_io.customize_collate_fn as nii_collate_fn | ||
import core_scripts.data_io.customize_sampler as nii_sampler_fn | ||
import core_scripts.data_io.conf as nii_dconf | ||
import core_scripts.data_io.seq_info as nii_seqinfo | ||
|
||
__author__ = "Xin Wang" | ||
__email__ = "[email protected]" | ||
|
@@ -35,7 +36,7 @@ | |
class merge_loader(): | ||
""" merge_loader | ||
Data loader for customized data with m_concate_set = None | ||
By defauly, draw equal number of samples from each subset | ||
By default, draw equal number of samples from each subset | ||
__iter__(): | ||
__next__(): load data and merge into minibatch | ||
|
@@ -162,6 +163,31 @@ def f_get_seq_len_list(self): | |
tmp += sub_dataset.f_get_seq_len_list() | ||
return tmp | ||
|
||
def f_get_updated_seq_len_for_sampler_list(self): | ||
""" Similar to f_get_seq_len_list | ||
but it returns the updated data sequence length only for | ||
length-based shuffling in sampler | ||
""" | ||
tmp = [] | ||
for sub_dataset in self.datasets: | ||
tmp += sub_dataset.f_get_updated_seq_len_for_sampler_list() | ||
return tmp | ||
|
||
def f_update_seq_len_for_sampler_list(self, data_info): | ||
""" | ||
""" | ||
for one_info in data_info: | ||
data_idx = nii_seqinfo.parse_idx(one_info) | ||
data_len = nii_seqinfo.parse_length(one_info) | ||
for idx_u, idx_d, subset in \ | ||
zip(self.len_top, self.len_bot, self.datasets): | ||
if data_idx < idx_u: | ||
subset.f_update_seq_len_for_sampler_list(data_idx, data_len) | ||
break | ||
else: | ||
pass | ||
return | ||
|
||
def f_manage_data(self, lst_data_idx, opt): | ||
""" f_manage_data(self, lst_data_idx, opt) | ||
""" | ||
|
@@ -420,13 +446,14 @@ def print_info(self): | |
dset.print_info() | ||
return | ||
|
||
def putitem(self, output_data, save_dir, data_infor_str): | ||
def putitem(self, output_data, save_dir, filename_prefix, data_infor_str): | ||
""" Decompose the output_data from network into | ||
separate files | ||
""" | ||
# Since all datasets have similar configuration on feat dim, | ||
# use anyone is OK | ||
self.m_datasets[0].putitem(output_data, save_dir, data_infor_str) | ||
self.m_datasets[0].putitem(output_data, save_dir, filename_prefix, | ||
data_infor_str) | ||
|
||
def get_in_dim(self): | ||
""" Return the dimension of input features | ||
|
@@ -457,6 +484,39 @@ def get_seq_list(self): | |
tmp += dataset.get_seq_list() | ||
return tmp | ||
|
||
def update_seq_len_in_sampler_sub(self, data_info): | ||
""" | ||
""" | ||
# assume data_info logs the new data length that can be used for | ||
# sampler shuffle_by_length | ||
if self.way_to_merge == 'concatenate': | ||
self.m_concate_set.f_update_seq_len_for_sampler_list(data_info) | ||
else: | ||
print("Not implemented") | ||
sys.exit(1) | ||
return | ||
|
||
def update_seq_len_in_sampler(self): | ||
""" update_seq_len() | ||
Update sequence length if sequence length has been changed | ||
(for example, during silence trim process) | ||
This is necessary when using shuffle_by_seq_length sampler | ||
and the sequences were trimmed in data augmentation function. | ||
""" | ||
# call each subdataset and update the sequence length | ||
for idx, _ in enumerate(self.m_datasets): | ||
self.m_datasets[idx].update_seq_len_in_sampler() | ||
|
||
# update loader of this database | ||
if self.way_to_merge == 'concatenate': | ||
if self.m_params['sampler'] == nii_sampler_fn.g_str_sampler_bsbl \ | ||
and hasattr(self.m_loader.sampler, 'update_seq_length'): | ||
self.m_loader.sampler.update_seq_length( | ||
self.m_concate_set.f_get_updated_seq_len_for_sampler_list()) | ||
return | ||
|
||
def manage_data(self, lst_data_idx, opt): | ||
""" manage_data(data_index_list, opt) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.