-
Notifications
You must be signed in to change notification settings - Fork 284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fix] Revert change that removed the option to run OffloadModel without activation checkpointing. #608
[fix] Revert change that removed the option to run OffloadModel without activation checkpointing. #608
Changes from all commits
36820dc
3adfe9b
d869d91
4b25a7c
ec10bf0
701099a
8784d1d
886cf99
181aef2
00fcfd5
86a3da2
52fcbe7
7205ad1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -292,6 +292,75 @@ def backward(ctx, *grad_outputs): # type: ignore | |
return (None, None) + grads | ||
|
||
|
||
class ShardSyncLayer(torch.autograd.Function):
    """Synchronization point between consecutive model shards.

    Forward pass: offloads the shard that just executed and loads the
    parameters of the next shard onto the compute device. Backward pass:
    performs the mirror image, additionally shuttling the cached
    activations between CPU and GPU. The layer itself is an identity on
    its inputs — it creates no new outputs.

    NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
    """

    @staticmethod
    @_conditional_amp_fwd_decorator  # type: ignore
    def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:
        finished_slice = index
        upcoming_slice = index + 1
        num_slices = len(model_slices)

        # The shard that just ran is moved off the device...
        if finished_slice >= 0:
            model_slices[finished_slice].forward_drop()

        # ...and the next shard in line is brought onto it.
        if upcoming_slice < num_slices:
            model_slices[upcoming_slice].forward_load()

        ctx.index = index
        ctx.model_slices = model_slices
        ctx.model_instance = model_instance

        # Identity: hand the inputs straight through, normalized to a tuple.
        return inputs if isinstance(inputs, tuple) else (inputs,)

    @staticmethod
    @_conditional_amp_bwd_decorator
    def backward(ctx, *grad_outputs):  # type: ignore

        load_index = ctx.index
        drop_index = load_index + 1
        model_slices = ctx.model_slices
        model_instance = ctx.model_instance

        # TODO(anj-s): Are these redundant in the backward pass?
        if drop_index == len(model_slices):
            # The final activation is still on the CPU after the
            # loss.backward() call — bring it back to the device.
            model_instance._activations[-1] = tuple(a.cuda() for a in model_instance._activations[-1])

        if drop_index < len(model_slices):
            # Offload the shard (and its activations) we are done with.
            model_slices[drop_index].backward_drop()
            model_instance._activations[drop_index] = tuple(
                a.cpu() for a in model_instance._activations[drop_index]
            )

        if load_index >= 0:
            # Load the shard (and its activations) needed next.
            model_slices[load_index].backward_load()
            model_instance._activations[load_index] = tuple(
                a.cuda() for a in model_instance._activations[load_index]
            )

        # The returned variables need to mirror the forward inputs
        # TODO(anj-s): Why do we need to do this?
        if isinstance(grad_outputs, tuple):
            return grad_outputs[0], None, None, None

        return grad_outputs, None, None, None
|
||
|
||
class OffloadModel(nn.Module): | ||
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module | ||
to train by offloading majority of the model parameters to the CPU. | ||
|
@@ -405,4 +474,23 @@ def forward(self, *inputs: Any, **_: Any) -> Any: | |
|
||
# We need the second param to be a dummy input to enable the | ||
# backward pass to be triggered for integer inputs. | ||
return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self) | ||
if self._checkpoint_activation: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh, I must have reviewed the offending PR and missed that, sorry about that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No worries! I realized that the tests weren't really catching this so glad I realized it. |
||
return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self) | ||
|
||
self._activations = [] | ||
for index in range(-1, len(self.model_slices)): | ||
if index >= 0: | ||
# TODO(anj-s): This might be a redundant call since we have the previous | ||
# activation on the device already. | ||
self._activations[index] = tuple([a.cuda() for a in list(self._activations[index])]) | ||
inputs = self._activations[index] | ||
inputs = self.model_slices[index](*inputs) | ||
# Call the custom autograd hooks (discard/load slices FW and BW) | ||
inputs = ShardSyncLayer.apply(inputs, index, self.model_slices, self) | ||
self._activations.append(inputs) | ||
if index >= 0: | ||
self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])]) | ||
|
||
result = self._activations[-1] | ||
result = tuple([r.cuda() for r in result]) | ||
return result[0] if len(result) == 1 else result |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,20 +32,38 @@ def test_single_run(): | |
device, offload_device = _init() | ||
model = _get_model() | ||
|
||
offload_model = OffloadModel(model=model, device=device, offload_device=offload_device, num_slices=2,) | ||
offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) | ||
peak_mem = {} | ||
for checkpoint_activation in [True, False]: | ||
offload_model = OffloadModel( | ||
model=model, | ||
device=device, | ||
offload_device=offload_device, | ||
num_slices=2, | ||
checkpoint_activation=checkpoint_activation, | ||
) | ||
offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) | ||
|
||
input = torch.ones(1000, 2).to(device) | ||
labels = torch.ones(1000, 2).to(device) | ||
offload_model.train() | ||
pred = offload_model(input) | ||
loss_fn = torch.nn.MSELoss(reduction="sum") | ||
loss = loss_fn(pred, labels) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. checking elsewhere for some form of parity ? wondering just in case |
||
loss.backward() | ||
offload_optimizer.step() | ||
key = "ca_" + str(checkpoint_activation) | ||
peak_mem[key] = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] | ||
print( | ||
"Peak allocated bytes on cuda:0 for checkpoint_activation " | ||
+ str(checkpoint_activation) | ||
+ ": {:2f}".format(peak_mem[key]) | ||
) | ||
|
||
input = torch.ones(2, 2).to(device) | ||
labels = torch.ones(2, 2).to(device) | ||
offload_model.train() | ||
pred = offload_model(input) | ||
loss_fn = torch.nn.MSELoss(reduction="sum") | ||
loss = loss_fn(pred, labels) | ||
loss.backward() | ||
offload_optimizer.step() | ||
# TODO(anj-s): We need a better requirement since this fails on CircleCI right now. | ||
assert peak_mem["ca_True"] <= peak_mem["ca_False"] | ||
|
||
|
||
def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2): | ||
def _get_model(num_inputs=2, num_hidden=20, num_layers=10, num_outputs=2): | ||
model = torch.nn.Sequential( | ||
torch.nn.Linear(num_inputs, num_hidden), | ||
*([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
at least this part I'm a bit familiar with :)