diff --git a/image/data.py b/image/data.py index 59297bf..8c54459 100644 --- a/image/data.py +++ b/image/data.py @@ -198,7 +198,7 @@ def flatten_input(*features): if len(datasets) > 1: dataset = tf.data.Dataset.zip(tuple(datasets)) - dataset = dataset.map(flatten_input) + dataset = dataset.map(flatten_input,num_parallel_calls=tf.data.experimental.AUTOTUNE) else: dataset = datasets[0] diff --git a/text/utils/proc_data_utils.py b/text/utils/proc_data_utils.py index 31dca52..e668db0 100644 --- a/text/utils/proc_data_utils.py +++ b/text/utils/proc_data_utils.py @@ -251,7 +251,7 @@ def flatten_input(*features): if len(dataset_list) > 1: d = tf.data.Dataset.zip(tuple(dataset_list)) - d = d.map(flatten_input) + d = d.map(flatten_input,num_parallel_calls=tf.data.experimental.AUTOTUNE) else: d = dataset_list[0]