diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py
index f29fcda2bb19..23f97d5a542a 100644
--- a/deepspeed/runtime/zero/linear.py
+++ b/deepspeed/runtime/zero/linear.py
@@ -77,10 +77,10 @@ def backward(ctx, grad_output):
         #print("Computing grad weight")
         dim = grad_output.dim()
         if dim > 2:
-            grad_weight = grad_output.view(-1,
-                                           grad_output.shape[-1]).t().matmul(
-                                               input.view(-1,
-                                                          input.shape[-1]))
+            grad_weight = grad_output.reshape(-1,
+                                              grad_output.shape[-1]).t().matmul(
+                                                  input.reshape(-1,
+                                                                input.shape[-1]))
         else:
             grad_weight = grad_output.t().matmul(input)
         #print(f"Computed grad weight grad_weight {grad_weight.shape}")
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 4465adfd7c16..5a1a40460e16 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -190,6 +190,9 @@ def _init_subclass(cls, **kwargs):
         torch.empty = empty_cuda_tensor

         if self.mem_efficient_linear:
+            print_rank_0(
+                "Your linear layers are being patched with a more memory-efficient version. This will persist unless manually reset.",
+                force=True)
             self.linear_bk = torch.nn.functional.linear
             torch.nn.functional.linear = LinearFunctionForZeroStage3.apply

@@ -210,8 +213,9 @@ def _disable_class(cls):
         torch.Tensor.__new__ = torch.Tensor.__old_new__
         torch.empty = _orig_torch_empty

-        if self.mem_efficient_linear:
-            torch.nn.functional.linear = self.linear_bk
+        # Undoing the patch here would undo it during training.
+        #if self.mem_efficient_linear:
+        #    torch.nn.functional.linear = self.linear_bk

         # Now that we cleaned up the metaclass injection, raise the exception.
         if exc_type is not None:
@@ -357,6 +361,13 @@ def get_model():
             self._convert_to_deepspeed_param(param)
             param.partition()

+        if mem_efficient_linear:
+            print_rank_0(
+                "Your linear layers are being patched with a more memory-efficient version. This will persist unless manually reset.",
+                force=True)
+            self.linear_bk = torch.nn.functional.linear
+            torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
+
     def _post_init_method(self, module):
         #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
         print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
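
The first hunk swaps `.view` for `.reshape` in the backward pass. The reason, as a minimal sketch (not part of the PR itself): `.view` requires the tensor to be contiguous and raises a `RuntimeError` otherwise, while `.reshape` falls back to copying when needed. `grad_output` arriving in `backward` is not guaranteed to be contiguous, e.g. after a transpose upstream:

```python
import torch

# A transposed tensor is a common example of a non-contiguous tensor.
grad_output = torch.randn(4, 8, 16).transpose(0, 1)  # shape (8, 4, 16)
print(grad_output.is_contiguous())  # False

try:
    grad_output.view(-1, grad_output.shape[-1])  # raises RuntimeError
except RuntimeError as e:
    print(f"view failed: {e}")

# reshape succeeds, copying the data only if it has to.
flat = grad_output.reshape(-1, grad_output.shape[-1])
print(flat.shape)  # torch.Size([32, 16])
```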
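
For reference, the flattened matmul in that hunk computes the weight gradient for inputs with more than two dimensions. A quick sanity check of our own (not from the PR) against autograd:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5)                   # 3-D input, as in the dim > 2 branch
w = torch.randn(4, 5, requires_grad=True)
y = F.linear(x, w)                         # shape (2, 3, 4)
g = torch.randn_like(y)                    # stand-in for grad_output
y.backward(g)

# Same computation as the patched backward: flatten leading dims, then matmul.
manual = g.reshape(-1, g.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
print(torch.allclose(w.grad, manual))      # True
```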
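
Note that the second file intentionally stops restoring `torch.nn.functional.linear` on context exit, so the patch stays active during training, as the new log message warns. A hedged sketch of the manual reset the message refers to; `original_linear` is our own name here, not part of the DeepSpeed API, and must be captured before the patch is applied:

```python
import torch
import torch.nn.functional as F

# Capture the stock implementation before deepspeed.zero.Init patches it.
original_linear = F.linear

# ... deepspeed.zero.Init(...) runs here and replaces F.linear ...

# Manual reset, once the memory-efficient version is no longer wanted.
torch.nn.functional.linear = original_linear
```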