diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_global_input.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_global_input.py index 0525edf01b198b..093a07f187986d 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_global_input.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_global_input.py @@ -195,7 +195,7 @@ def test_basic(self): dist_model = dist.to_static(model, dist_dataloader, loss_fn, opt) dist_model.train() - for step, (input, label) in enumerate(dist_dataloader): + for input, label in dist_dataloader: loss = dist_model(input, label) if cur_rank in [5, 7]: @@ -204,7 +204,7 @@ def test_basic(self): dist.all_reduce(loss, group=group) else: dist_opt = dist.shard_optimizer(opt) - for step, (input, label) in enumerate(dist_dataloader()): + for input, label in dist_dataloader: logits = model(input) loss = loss_fn(logits, label) loss.backward()