From 250d0a286160504dae3c62185c63f207896c749b Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Tue, 20 Apr 2021 03:53:51 +0000
Subject: [PATCH 1/2] use weird-shaped tensor to avoid silent failures when
 not registering external params

---
 deepspeed/runtime/zero/partition_parameters.py | 6 ++++--
 tests/unit/test_zero_context.py                | 4 ++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index c8bde6390b3c..3332668f1469 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -23,6 +23,7 @@
 from ..config import DeepSpeedConfig
 
 param_count = 0
+partitioned_param_data_shape = [1]
 
 
 def print_rank_0(message, debug=False, force=False):
@@ -631,7 +632,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
                 f'Before partitioning param {param.ds_id} {param.shape}',
                 force=False)
             #param.data does not store anything meaningful in partitioned state
-            param.data = torch.ones(1).half().to(param.device)
+            param.data = torch.ones(partitioned_param_data_shape).half().to(
+                param.device)
 
             see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                              force=False)
@@ -712,7 +714,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
             see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                              force=False)
 
-            param.data = torch.ones(1).half().to(param.device)
+            param.data = torch.ones(partitioned_param_data_shape).half().to(param.device)
 
             see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                              force=False)
diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py
index 9c45b58abf66..8833d7374e97 100644
--- a/tests/unit/test_zero_context.py
+++ b/tests/unit/test_zero_context.py
@@ -6,7 +6,7 @@
 import pytest
 
 import deepspeed
-from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape
 
 from common import distributed_test
 
@@ -32,7 +32,7 @@ def test_scatter_gather():
     with deepspeed.zero.Init():
         l = torch.nn.Linear(6, 3)
     assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE
-    assert l.weight.numel() == 1
+    assert l.weight.numel() == torch.size(partitioned_param_data_shape)
 
     # Ensure there is no impact outside the context
     l2 = torch.nn.Linear(6, 3)

From 060285bd4e9b810865462dbfbf73631ae6219d7b Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Tue, 20 Apr 2021 04:38:30 +0000
Subject: [PATCH 2/2] fix typo

---
 tests/unit/test_zero_context.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py
index 8833d7374e97..5ccccb5c18a0 100644
--- a/tests/unit/test_zero_context.py
+++ b/tests/unit/test_zero_context.py
@@ -32,7 +32,7 @@ def test_scatter_gather():
     with deepspeed.zero.Init():
         l = torch.nn.Linear(6, 3)
     assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE
-    assert l.weight.numel() == torch.size(partitioned_param_data_shape)
+    assert l.weight.shape == torch.Size(partitioned_param_data_shape)
 
     # Ensure there is no impact outside the context
     l2 = torch.nn.Linear(6, 3)
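
Note (not part of the patches): a minimal standalone sketch of why PATCH 2/2
is needed. PATCH 1/2's assertion calls torch.size, which does not exist
(torch.Size, the tuple subclass used for tensor shapes, is the correct name),
so the test raised AttributeError rather than checking anything. The snippet
below uses only names the diffs introduce, plus a hypothetical local variable
"placeholder" for illustration:

    import torch

    # Placeholder shape a partitioned parameter's data is replaced with,
    # as introduced in PATCH 1/2.
    partitioned_param_data_shape = [1]

    # Stand-in for the dummy tensor assigned in _partition_param.
    placeholder = torch.ones(partitioned_param_data_shape).half()

    # torch.size(...) raises AttributeError; torch.Size is the class that
    # tensor .shape values compare equal to, hence the PATCH 2/2 fix:
    assert placeholder.shape == torch.Size(partitioned_param_data_shape)

    # Had the original numel() comparison been kept, the equivalent
    # element-count check would be:
    assert placeholder.numel() == torch.Size(partitioned_param_data_shape).numel()

Comparing shapes rather than element counts also catches a placeholder that
happens to have the right numel but a different shape.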