
Commit a77c56f

Authored by Anjali Sridhar (anj-s) and Anjali Sridhar
[fix] Revert change that removed the option to run OffloadModel without activation checkpointing. (#608)
* revert change made
* add tests and revert sync shard changes
* add tests
* remove file checked in by error
* inline var
* fix lint errors
* add checkpoint activation
* fix mypy
* use a bigger model
* modify tests for now
* resolve conflicts

Co-authored-by: Anjali Sridhar <[email protected]>
1 parent 5650695 commit a77c56f
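
For context, the restored code path is exercised as below; a minimal training-step sketch based on the updated test in this commit, assuming fairscale and torch are installed, a CUDA device is available, and that OffloadModel is importable from fairscale.experimental.nn.offload:

```python
import torch
from fairscale.experimental.nn.offload import OffloadModel

device = torch.device("cuda")
offload_device = torch.device("cpu")

# Small sequential model; OffloadModel slices it into shards that live on the
# offload device (CPU) and are moved to the GPU one at a time.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 20),
    *[torch.nn.Linear(20, 20) for _ in range(10)],
    torch.nn.Linear(20, 2),
)

# checkpoint_activation=False is the option this commit restores; pass True to
# also recompute activations during the backward pass.
offload_model = OffloadModel(
    model=model,
    device=device,
    offload_device=offload_device,
    num_slices=2,
    checkpoint_activation=False,
)
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

inputs = torch.ones(1000, 2).to(device)
labels = torch.ones(1000, 2).to(device)

offload_model.train()
loss = torch.nn.MSELoss(reduction="sum")(offload_model(inputs), labels)
loss.backward()
optimizer.step()
```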

File tree

5 files changed, +136 -30 lines changed:

* .circleci/config.yml
* benchmarks/experimental/offload.py
* benchmarks/golden_configs/lm_wikitext2.py
* fairscale/experimental/nn/offload.py
* tests/experimental/nn/test_offload.py


.circleci/config.yml (+1, -1)

@@ -170,7 +170,7 @@ run_offload_benchmark: &run_offload_benchmark
   - run:
       name: Run Offload Benchmark
       command: |
-        python benchmarks/experimental/offload.py
+        python benchmarks/experimental/offload.py --checkpoint_activation
 
 run_pipe_benchmark: &run_pipe_benchmark
   - run:

benchmarks/experimental/offload.py (+14, -14)

@@ -233,7 +233,7 @@ def get_batch(source):
 
 
 def verify_peak_memory(golden_config, std_dev):
-    print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
+
     current_device_usage = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]
     golden_ref = golden_config["peak_mem_usage"]
     if not current_device_usage < golden_ref * std_dev:

@@ -246,7 +246,6 @@ def verify_peak_memory(golden_config, std_dev):
 def verify_lm_throughput(wps, golden_config, args):
     """Verify that words per second for a given benchmark run matches the golden data."""
 
-    print("Throughput(wps) is {:.2f}.".format(wps))
     if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
         raise RuntimeError(
             "Throughput(wps):{:.2f} is below the golden threshold of an "

@@ -272,9 +271,12 @@ def benchmark_language_model(model_config, model, benchmark_config, model_specs,
         raise RuntimeError(
             f"Golden data verification is only supported for the Transformer(lm) model and not {args.model_name}"
         )
-    golden_config = get_golden_config(args.model_name, args)
-    verify_lm_throughput(wps, golden_config, args)
-    verify_peak_memory(golden_config, 1.1)
+    print("Throughput(wps) is {:.2f}.".format(wps))
+    print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
+    if not args.dry_run:
+        golden_config = get_golden_config(args.model_name, args)
+        verify_lm_throughput(wps, golden_config, args)
+        verify_peak_memory(golden_config, 1.1)
 
 
 def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):

@@ -343,11 +345,11 @@ def create_model_config(args, benchmark_config=None, model_specs=None):
         raise RuntimeError(f"Unrecognized args.model_mame {args.model_name}")
 
 
-def create_benchmark_config(model_name):
+def create_benchmark_config(args):
     """Return a dict with configurations required for benchmarking `model_name` model."""
 
     if args.model_name == "lm":
-        return lm_wikitext2.get_benchmark_config()
+        return lm_wikitext2.get_benchmark_config(checkpoint_activation=args.checkpoint_activation)
     elif args.model_name == "seq":
         return offload_seq.get_benchmark_config()
     else:

@@ -383,17 +385,15 @@ def run_benchmark(args):
     init_random_seed(0)
 
     if args.model_name == "lm":
-        benchmark_config = create_benchmark_config(args.model_name)
+        benchmark_config = create_benchmark_config(args)
         model_specs = get_model_specs(args.model_name)
         model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
         model = model_config["model"]
 
-        if args.dry_run:
-            train(model_config, model, benchmark_config, model_specs, args)
-        else:
-            benchmark_language_model(model_config, model, benchmark_config, model_specs, args)
+        benchmark_language_model(model_config, model, benchmark_config, model_specs, args)
+
     elif args.model_name == "seq":
-        benchmark_config = create_benchmark_config(args.model_name)
+        benchmark_config = create_benchmark_config(args)
         model_specs = get_model_specs(args.model_name)
         model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
         model = model_config["model"]

@@ -419,7 +419,7 @@ def run_benchmark(args):
         "--use_synthetic_data", default=True, action="store_true", help="Uses synthetic data for running benchmarks."
     )
     parser.add_argument("--use_fp16", action="store_true", default=False)
-    parser.add_argument("--checkpoint_activation", action="store_true", default=True)
+    parser.add_argument("--checkpoint_activation", action="store_true", default=False)
     parser.add_argument("--use_profiler", action="store_true", default=False)
 
 

benchmarks/golden_configs/lm_wikitext2.py (+3, -3)

@@ -20,14 +20,14 @@ def get_model_config():
         "seq_len": 32,
     }
 
-def get_benchmark_config():
+def get_benchmark_config(checkpoint_activation=True):
 
     return {
         "epochs": 1,
         "lr": 0.001,  # learning rate
         "batch_size": 8,
         "criterion": nn.CrossEntropyLoss(),
-        "checkpoint_activation": True,
+        "checkpoint_activation": checkpoint_activation,
         "num_microbatches": 1,
         "slices": 3,
     }

@@ -59,7 +59,7 @@ def get_benchmark_config():
         "criterion": nn.CrossEntropyLoss(),
         "slices": 3,
         "checkpoint_activation": True,
-        "num_microbatches": 4,
+        "num_microbatches": 1,
     }
 
 
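
Together with the create_benchmark_config change in the benchmark script above, the CLI flag now flows into this golden config instead of checkpointing being hard-coded. A dependency-free sketch of that plumbing (a standalone mimic of the two functions, not the repo modules themselves; the criterion is stubbed out as a string so the sketch runs without torch):

```python
def get_benchmark_config(checkpoint_activation=True):
    # Mimics benchmarks/golden_configs/lm_wikitext2.py: same fields as the
    # diff above, with the criterion replaced by a placeholder string.
    return {
        "epochs": 1,
        "lr": 0.001,
        "batch_size": 8,
        "criterion": "nn.CrossEntropyLoss()",  # placeholder, not the real object
        "checkpoint_activation": checkpoint_activation,
        "num_microbatches": 1,
        "slices": 3,
    }


def create_benchmark_config(args):
    # Mimics the benchmark change: thread the CLI flag through rather than
    # hard-coding checkpoint_activation=True.
    return get_benchmark_config(checkpoint_activation=args.checkpoint_activation)


class Args:
    checkpoint_activation = False  # new default: opt in via --checkpoint_activation


print(create_benchmark_config(Args())["checkpoint_activation"])  # False
```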

fairscale/experimental/nn/offload.py (+89, -1)

@@ -292,6 +292,75 @@ def backward(ctx, *grad_outputs):  # type: ignore
         return (None, None) + grads
 
 
+class ShardSyncLayer(torch.autograd.Function):
+    """
+    The shard sync layer is a synchronization point between model shards.
+    - In the forward pass, it drops parameters in the previous shard and
+    loads parameters for the next shard.
+    - In the backward pass, it does the reverse.
+    It does not change or create any outputs at all, instead it just
+    forwards the input as the output.
+    NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
+    """
+
+    @staticmethod
+    @_conditional_amp_fwd_decorator  # type: ignore
+    def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:
+        drop_index = index
+        load_index = index + 1
+        max_slices = len(model_slices)
+
+        if drop_index >= 0:
+            # Move shard from device to offload device.
+            model_slices[drop_index].forward_drop()
+
+        if load_index < max_slices:
+            # Load shard from offload device to device.
+            model_slices[load_index].forward_load()
+
+        ctx.index = index
+        ctx.model_slices = model_slices
+        ctx.model_instance = model_instance
+
+        return inputs if isinstance(inputs, tuple) else (inputs,)
+
+    @staticmethod
+    @_conditional_amp_bwd_decorator
+    def backward(ctx, *grad_outputs):  # type: ignore
+
+        load_index = ctx.index
+        drop_index = load_index + 1
+        model_slices = ctx.model_slices
+        model_instance = ctx.model_instance
+
+        # TODO(anj-s): Are these redundant in the backward pass?
+        if drop_index == len(model_slices):
+            # Drop the last activation since it is still on the CPU
+            # after the loss.backward() call.
+            model_instance._activations[-1] = tuple([a.cuda() for a in list(model_instance._activations[-1])])
+
+        if drop_index < len(model_slices):
+            # Move shard from device to offload device.
+            model_slices[drop_index].backward_drop()
+            model_instance._activations[drop_index] = tuple(
+                [a.cpu() for a in list(model_instance._activations[drop_index])]
+            )
+
+        if load_index >= 0:
+            # Load shard from offload device to device.
+            model_slices[load_index].backward_load()
+            model_instance._activations[load_index] = tuple(
+                [a.cuda() for a in list(model_instance._activations[load_index])]
+            )
+
+        # The returned variables need to mirror the forward inputs
+        # TODO(anj-s): Why do we need to do this?
+        if isinstance(grad_outputs, tuple):
+            return grad_outputs[0], None, None, None
+
+        return grad_outputs, None, None, None
+
+
 class OffloadModel(nn.Module):
     """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
     to train by offloading majority of the model parameters to the CPU.

@@ -405,4 +474,23 @@ def forward(self, *inputs: Any, **_: Any) -> Any:
 
         # We need the second param to be a dummy input to enable the
         # backward pass to be triggered for integer inputs.
-        return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self)
+        if self._checkpoint_activation:
+            return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self)
+
+        self._activations = []
+        for index in range(-1, len(self.model_slices)):
+            if index >= 0:
+                # TODO(anj-s): This might be a redundant call since we have the previous
+                # activation on the device already.
+                self._activations[index] = tuple([a.cuda() for a in list(self._activations[index])])
+                inputs = self._activations[index]
+                inputs = self.model_slices[index](*inputs)
+            # Call the custom autograd hooks (discard/load slices FW and BW)
+            inputs = ShardSyncLayer.apply(inputs, index, self.model_slices, self)
+            self._activations.append(inputs)
+            if index >= 0:
+                self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])])
+
+        result = self._activations[-1]
+        result = tuple([r.cuda() for r in result])
+        return result[0] if len(result) == 1 else result
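
ShardSyncLayer is what keeps only one shard resident on the GPU at a time along the restored non-checkpointed path. A rough standalone illustration of the forward-side pattern only (this is not the fairscale implementation; it ignores the autograd hooks and the CPU offload of activations, and assumes a CUDA device is available):

```python
import torch
import torch.nn as nn


def forward_with_offload(shards, x, device="cuda", offload_device="cpu"):
    """Run a list of model slices kept on `offload_device`, loading one slice
    onto `device` at a time and dropping it back afterwards, roughly what
    forward_load()/forward_drop() do around each shard."""
    x = x.to(device)
    for shard in shards:
        shard.to(device)          # load this shard's parameters onto the GPU
        x = shard(x)              # compute the slice's activations
        shard.to(offload_device)  # drop the shard back to the offload device
    return x


shards = [
    nn.Sequential(nn.Linear(2, 20), nn.ReLU()),
    nn.Sequential(nn.Linear(20, 2)),
]
out = forward_with_offload(shards, torch.ones(8, 2))
print(out.shape)  # torch.Size([8, 2])
```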

tests/experimental/nn/test_offload.py (+29, -11)

@@ -32,20 +32,38 @@ def test_single_run():
     device, offload_device = _init()
     model = _get_model()
 
-    offload_model = OffloadModel(model=model, device=device, offload_device=offload_device, num_slices=2,)
-    offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
+    peak_mem = {}
+    for checkpoint_activation in [True, False]:
+        offload_model = OffloadModel(
+            model=model,
+            device=device,
+            offload_device=offload_device,
+            num_slices=2,
+            checkpoint_activation=checkpoint_activation,
+        )
+        offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
+
+        input = torch.ones(1000, 2).to(device)
+        labels = torch.ones(1000, 2).to(device)
+        offload_model.train()
+        pred = offload_model(input)
+        loss_fn = torch.nn.MSELoss(reduction="sum")
+        loss = loss_fn(pred, labels)
+        loss.backward()
+        offload_optimizer.step()
+        key = "ca_" + str(checkpoint_activation)
+        peak_mem[key] = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]
+        print(
+            "Peak allocated bytes on cuda:0 for checkpoint_activation "
+            + str(checkpoint_activation)
+            + ": {:2f}".format(peak_mem[key])
+        )
 
-    input = torch.ones(2, 2).to(device)
-    labels = torch.ones(2, 2).to(device)
-    offload_model.train()
-    pred = offload_model(input)
-    loss_fn = torch.nn.MSELoss(reduction="sum")
-    loss = loss_fn(pred, labels)
-    loss.backward()
-    offload_optimizer.step()
+    # TODO(anj-s): We need a better requirement since this fails on CircleCI right now.
+    assert peak_mem["ca_True"] <= peak_mem["ca_False"]
 
 
-def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2):
+def _get_model(num_inputs=2, num_hidden=20, num_layers=10, num_outputs=2):
     model = torch.nn.Sequential(
         torch.nn.Linear(num_inputs, num_hidden),
         *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
