diff --git a/figure_2_iteration_benchmark.py b/figure_2_iteration_benchmark.py index d60c624..37d4171 100644 --- a/figure_2_iteration_benchmark.py +++ b/figure_2_iteration_benchmark.py @@ -55,7 +55,7 @@ def _iterate(dataset, h5labels, random: bool = False, need_sort: bool = False): for batch_idx in index_iter(dataset.shape[0], BATCH_SIZE, shuffle=random): if random and need_sort: batch_idx.sort() - batch_X = dataset[batch_idx, :] + batch_X = dataset[batch_idx, :].compute() batch_labels = h5labels[batch_idx]