[fix] Revert change that removed the option to run OffloadModel without activation checkpointing. #608

Merged · 13 commits · Apr 15, 2021
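For context, here is a minimal sketch of the usage this PR restores. The constructor keywords mirror the updated test in this diff; the toy model, devices, and slice count are illustrative placeholders, and the import path is assumed from the file location fairscale/experimental/nn/offload.py.

    import torch
    from fairscale.experimental.nn.offload import OffloadModel

    # Any nn.Sequential works; this small stack is only for illustration.
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 20),
        torch.nn.Linear(20, 20),
        torch.nn.Linear(20, 2),
    )

    # checkpoint_activation=False is the option this PR brings back: shards are
    # swapped between the compute device and the offload device without
    # activation checkpointing.
    offload_model = OffloadModel(
        model=model,
        device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        num_slices=3,
        checkpoint_activation=False,
    )

    out = offload_model(torch.ones(8, 2, device="cuda"))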
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -170,7 +170,7 @@ run_offload_benchmark: &run_offload_benchmark
- run:
name: Run Offload Benchmark
command: |
python benchmarks/experimental/offload.py
python benchmarks/experimental/offload.py --checkpoint_activation

run_pipe_benchmark: &run_pipe_benchmark
- run:
28 changes: 14 additions & 14 deletions benchmarks/experimental/offload.py
@@ -233,7 +233,7 @@ def get_batch(source):


def verify_peak_memory(golden_config, std_dev):
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))

current_device_usage = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]
golden_ref = golden_config["peak_mem_usage"]
if not current_device_usage < golden_ref * std_dev:
@@ -246,7 +246,6 @@ def verify_peak_memory(golden_config, std_dev):
def verify_lm_throughput(wps, golden_config, args):
"""Verify that words per second for a given benchmark run matches the golden data."""

print("Throughput(wps) is {:.2f}.".format(wps))
if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
raise RuntimeError(
"Throughput(wps):{:.2f} is below the golden threshold of an "
@@ -272,9 +271,12 @@ def benchmark_language_model(model_config, model, benchmark_config, model_specs,
raise RuntimeError(
f"Golden data verification is only supported for the Transformer(lm) model and not {args.model_name}"
)
golden_config = get_golden_config(args.model_name, args)
verify_lm_throughput(wps, golden_config, args)
verify_peak_memory(golden_config, 1.1)
print("Throughput(wps) is {:.2f}.".format(wps))
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
if not args.dry_run:
golden_config = get_golden_config(args.model_name, args)
verify_lm_throughput(wps, golden_config, args)
verify_peak_memory(golden_config, 1.1)


def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
@@ -343,11 +345,11 @@ def create_model_config(args, benchmark_config=None, model_specs=None):
raise RuntimeError(f"Unrecognized args.model_mame {args.model_name}")


def create_benchmark_config(model_name):
def create_benchmark_config(args):
"""Return a dict with configurations required for benchmarking `model_name` model."""

if args.model_name == "lm":
return lm_wikitext2.get_benchmark_config()
return lm_wikitext2.get_benchmark_config(checkpoint_activation=args.checkpoint_activation)
elif args.model_name == "seq":
return offload_seq.get_benchmark_config()
else:
@@ -383,17 +385,15 @@ def run_benchmark(args):
init_random_seed(0)

if args.model_name == "lm":
benchmark_config = create_benchmark_config(args.model_name)
benchmark_config = create_benchmark_config(args)
model_specs = get_model_specs(args.model_name)
model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
model = model_config["model"]

if args.dry_run:
train(model_config, model, benchmark_config, model_specs, args)
else:
benchmark_language_model(model_config, model, benchmark_config, model_specs, args)
benchmark_language_model(model_config, model, benchmark_config, model_specs, args)

elif args.model_name == "seq":
benchmark_config = create_benchmark_config(args.model_name)
benchmark_config = create_benchmark_config(args)
model_specs = get_model_specs(args.model_name)
model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
model = model_config["model"]
@@ -419,7 +419,7 @@ def run_benchmark(args):
"--use_synthetic_data", default=True, action="store_true", help="Uses synthetic data for running benchmarks."
)
parser.add_argument("--use_fp16", action="store_true", default=False)
parser.add_argument("--checkpoint_activation", action="store_true", default=True)
parser.add_argument("--checkpoint_activation", action="store_true", default=False)
parser.add_argument("--use_profiler", action="store_true", default=False)


6 changes: 3 additions & 3 deletions benchmarks/golden_configs/lm_wikitext2.py
@@ -20,14 +20,14 @@ def get_model_config():
"seq_len": 32,
}

def get_benchmark_config():
def get_benchmark_config(checkpoint_activation=True):

return {
"epochs": 1,
"lr": 0.001, # learning rate
"batch_size": 8,
"criterion": nn.CrossEntropyLoss(),
"checkpoint_activation": True,
"checkpoint_activation": checkpoint_activation,
"num_microbatches": 1,
"slices": 3,
}
@@ -59,7 +59,7 @@ def get_benchmark_config():
"criterion": nn.CrossEntropyLoss(),
"slices": 3,
"checkpoint_activation": True,
"num_microbatches": 4,
"num_microbatches": 1,
}
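A minimal sketch of how the new flag reaches this config; the import path is an assumption inferred from the file location benchmarks/golden_configs/lm_wikitext2.py, and only the lm path is shown.

    import argparse

    from benchmarks.golden_configs import lm_wikitext2  # assumed import path

    parser = argparse.ArgumentParser()
    # Matches the new default in benchmarks/experimental/offload.py.
    parser.add_argument("--checkpoint_activation", action="store_true", default=False)
    args = parser.parse_args(["--checkpoint_activation"])

    # create_benchmark_config(args) now forwards the flag instead of hard-coding True,
    # so the returned dict reflects how the benchmark was launched.
    config = lm_wikitext2.get_benchmark_config(checkpoint_activation=args.checkpoint_activation)
    assert config["checkpoint_activation"] is True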


90 changes: 89 additions & 1 deletion fairscale/experimental/nn/offload.py
@@ -292,6 +292,75 @@ def backward(ctx, *grad_outputs): # type: ignore
return (None, None) + grads


class ShardSyncLayer(torch.autograd.Function):
[Review comment, Contributor]: at least this part I'm a bit familiar with :)

"""
The shard sync layer is a synchronization point between model shards.
- In the forward pass, it drops parameters in the previous shard and
loads parameters for the next shard.
- In the backward pass, it does the reverse.
It does not change or create any outputs at all, instead it just
forwards the input as the output.
NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
"""

@staticmethod
@_conditional_amp_fwd_decorator # type: ignore
def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:
drop_index = index
load_index = index + 1
max_slices = len(model_slices)

if drop_index >= 0:
# Move shard from device to offload device.
model_slices[drop_index].forward_drop()

if load_index < max_slices:
# Load shard from offload device to device.
model_slices[load_index].forward_load()

ctx.index = index
ctx.model_slices = model_slices
ctx.model_instance = model_instance

return inputs if isinstance(inputs, tuple) else (inputs,)

@staticmethod
@_conditional_amp_bwd_decorator
def backward(ctx, *grad_outputs): # type: ignore

load_index = ctx.index
drop_index = load_index + 1
model_slices = ctx.model_slices
model_instance = ctx.model_instance

# TODO(anj-s): Are these redundant in the backward pass?
if drop_index == len(model_slices):
# Drop the last activation since it is still on the CPU
# after the loss.backward() call.
model_instance._activations[-1] = tuple([a.cuda() for a in list(model_instance._activations[-1])])

if drop_index < len(model_slices):
# Move shard from device to offload device.
model_slices[drop_index].backward_drop()
model_instance._activations[drop_index] = tuple(
[a.cpu() for a in list(model_instance._activations[drop_index])]
)

if load_index >= 0:
# Load shard from offload device to device.
model_slices[load_index].backward_load()
model_instance._activations[load_index] = tuple(
[a.cuda() for a in list(model_instance._activations[load_index])]
)

# The returned variables need to mirror the forward inputs
# TODO(anj-s): Why do we need to do this?
if isinstance(grad_outputs, tuple):
return grad_outputs[0], None, None, None

return grad_outputs, None, None, None


class OffloadModel(nn.Module):
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train by offloading majority of the model parameters to the CPU.
@@ -405,4 +474,23 @@ def forward(self, *inputs: Any, **_: Any) -> Any:

# We need the second param to be a dummy input to enable the
# backward pass to be triggered for integer inputs.
return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self)
if self._checkpoint_activation:
[Review comment, Contributor]: oh, I must have reviewed the offending PR and missed that, sorry about that

[Reply, Contributor Author]: No worries! I realized that the tests weren't really catching this so glad I realized it.

return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self)

self._activations = []
for index in range(-1, len(self.model_slices)):
if index >= 0:
# TODO(anj-s): This might be a redundant call since we have the previous
# activation on the device already.
self._activations[index] = tuple([a.cuda() for a in list(self._activations[index])])
inputs = self._activations[index]
inputs = self.model_slices[index](*inputs)
# Call the custom autograd hooks (discard/load slices FW and BW)
inputs = ShardSyncLayer.apply(inputs, index, self.model_slices, self)
self._activations.append(inputs)
if index >= 0:
self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])])

result = self._activations[-1]
result = tuple([r.cuda() for r in result])
return result[0] if len(result) == 1 else result
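For intuition, a stripped-down version of the non-checkpointed forward loop added above: no autograd hooks, AMP decorators, or activation bookkeeping, and shard load/drop reduced to plain .to() calls. It assumes a CUDA device and a list of nn.Sequential slices.

    import torch

    def naive_offload_forward(model_slices, x, device="cuda"):
        # Run each slice with only one slice resident on the compute device at a time.
        activation = x.to(device)
        for shard in model_slices:
            shard.to(device)                 # roughly what forward_load() does
            activation = shard(activation.to(device))
            shard.to("cpu")                  # roughly what forward_drop() does
            activation = activation.cpu()    # park the intermediate activation off the GPU
        return activation.to(device)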
40 changes: 29 additions & 11 deletions tests/experimental/nn/test_offload.py
@@ -32,20 +32,38 @@ def test_single_run():
device, offload_device = _init()
model = _get_model()

offload_model = OffloadModel(model=model, device=device, offload_device=offload_device, num_slices=2,)
offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
peak_mem = {}
for checkpoint_activation in [True, False]:
offload_model = OffloadModel(
model=model,
device=device,
offload_device=offload_device,
num_slices=2,
checkpoint_activation=checkpoint_activation,
)
offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

input = torch.ones(1000, 2).to(device)
labels = torch.ones(1000, 2).to(device)
offload_model.train()
pred = offload_model(input)
loss_fn = torch.nn.MSELoss(reduction="sum")
loss = loss_fn(pred, labels)
[Review comment, Contributor]: checking elsewhere for some form of parity? wondering just in case

loss.backward()
offload_optimizer.step()
key = "ca_" + str(checkpoint_activation)
peak_mem[key] = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]
print(
"Peak allocated bytes on cuda:0 for checkpoint_activation "
+ str(checkpoint_activation)
+ ": {:2f}".format(peak_mem[key])
)

input = torch.ones(2, 2).to(device)
labels = torch.ones(2, 2).to(device)
offload_model.train()
pred = offload_model(input)
loss_fn = torch.nn.MSELoss(reduction="sum")
loss = loss_fn(pred, labels)
loss.backward()
offload_optimizer.step()
# TODO(anj-s): We need a better requirement since this fails on CircleCI right now.
assert peak_mem["ca_True"] <= peak_mem["ca_False"]


def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2):
def _get_model(num_inputs=2, num_hidden=20, num_layers=10, num_outputs=2):
model = torch.nn.Sequential(
torch.nn.Linear(num_inputs, num_hidden),
*([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),