diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 9ff5a7232a73..e831911efd62 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -23,6 +23,7 @@ from ..config import DeepSpeedConfig param_count = 0 +partitioned_param_data_shape = [1] def print_rank_0(message, debug=False, force=False): @@ -634,7 +635,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): f'Before partitioning param {param.ds_id} {param.shape}', force=False) #param.data does not store anything meaningful in partitioned state - param.data = torch.ones(1).half().to(param.device) + param.data = torch.ones(partitioned_param_data_shape).half().to( + param.device) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) @@ -715,7 +717,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) - param.data = torch.ones(1).half().to(param.device) + param.data = torch.ones(partitioned_param_data_shape).half().to(param.device) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py index 9c45b58abf66..5ccccb5c18a0 100644 --- a/tests/unit/test_zero_context.py +++ b/tests/unit/test_zero_context.py @@ -6,7 +6,7 @@ import pytest import deepspeed -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape from common import distributed_test @@ -32,7 +32,7 @@ def test_scatter_gather(): with deepspeed.zero.Init(): l = torch.nn.Linear(6, 3) assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE - assert l.weight.numel() == 1 + assert l.weight.shape == torch.Size(partitioned_param_data_shape) # Ensure there is no impact outside the context l2 = torch.nn.Linear(6, 3)