diff --git a/test/auto_parallel/high_order_grad.py b/test/auto_parallel/high_order_grad.py index bac7d22be1e920..2f3df76c906e48 100644 --- a/test/auto_parallel/high_order_grad.py +++ b/test/auto_parallel/high_order_grad.py @@ -89,6 +89,7 @@ def __init__(self, num_sample): def __getitem__(self, index): x = np.linspace(0, 0.9, 10) y = np.linspace(0, 0.9, 10) + np.random.seed(index) # Optional: Ensure reproducibility bc_value = np.random.rand(36).reshape(36, 1).astype('float32') domain_space = [] @@ -100,8 +101,9 @@ def __getitem__(self, index): bc_index.append(i + 10 * j) domain_space = np.array(domain_space, dtype='float32') bc_index = np.array(bc_index, dtype='int64') - - return domain_space, bc_index, bc_value + # Return a single input point and its related information based on the index + idx = index % len(domain_space) + return domain_space[idx], bc_index, bc_value def __len__(self): return self.num_sample