diff --git a/test/auto_parallel/test_shard_tensor_api.py b/test/auto_parallel/test_shard_tensor_api.py index 14b0529b96bd5..5efa0c25031f1 100644 --- a/test/auto_parallel/test_shard_tensor_api.py +++ b/test/auto_parallel/test_shard_tensor_api.py @@ -93,9 +93,9 @@ def test_dynamic_mode_property_change(self): self.assertEqual(d_tensor.process_mesh, self.mesh) def test_stop_gradient(self): - x = paddle.ones([10, 10]) + x = paddle.ones([4, 1024, 512]) x.stop_gradient = False - x = dist.shard_tensor(x, self.mesh, [Shard(0)]) + x = dist.shard_tensor(x, self.mesh, [Shard(0), Replicate()]) assert not x.stop_gradient