Skip to content

Commit

Permalink
Take the interpreter holding references to the arguments into account (
Browse files Browse the repository at this point in the history
  • Loading branch information
shino16 authored Sep 4, 2024
1 parent 0c4f78b commit d05ebc6
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 133 deletions.
7 changes: 6 additions & 1 deletion thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3568,8 +3568,13 @@ def unpacking_fn(saved_for_backward, cotangents):

cotangents = backward_trace.args[1]
saved_tensors, saved_other = _split_saved_for_backward_into_tensors_and_other(saved_for_backward)

# When thunder.executors.torch_autograd.ThunderFunction.backward calls backward_fn, it copies
# collections into mutable ones, so that the tensors will be deallocated when deleted.
# See ThunderFunction.backward's notes for details
saved_tensors = list(saved_tensors)
unpacking_trace = construct_trace(rename_proxies=False, use_dce=False)(
unpacking_fn, (saved_tensors, saved_other), cotangents
unpacking_fn, [saved_tensors, saved_other], cotangents
)
assert unpacking_trace.bound_symbols[-1].sym.id == prims.PrimIDs.RETURN

Expand Down
60 changes: 51 additions & 9 deletions thunder/examine/memory_caculation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from collections.abc import Callable
from collections.abc import Callable, MutableSequence, MutableMapping, MutableSet
from functools import partial

from thunder.core.prims import PrimIDs
from thunder.core.proxies import FutureTensorProxy, Proxy, TensorProxy
from thunder.core.proxies import CollectionProxy, FutureTensorProxy, Proxy, TensorProxy
from thunder.core.symbol import BoundSymbol, Symbol
from thunder.core.trace import TraceCtx
from thunder.core.utils import check_type, ProxyDict
from thunder.core.pytree import tree_iter
from thunder.executors import pythonex

memory_calculate_skip_list = (PrimIDs.RETURN, "clear_collection")
# Arguments are considered independently, so we ignore all unpacking operations on them
memory_calculate_skip_list = (PrimIDs.RETURN, PrimIDs.UNPACK_TRIVIAL, PrimIDs.UNPACK_SEQUENCE)

# List of operators that considered no memory changes occurred in Thunder.
# NOTE: for the operators have different input and output shape, such as expand,
Expand All @@ -20,6 +24,7 @@
"torch_prims_reshape_impl", # torchex implementation of prims.reshape.
"permute",
"contiguous",
"split",
"torch_wait_prim_impl",
)

Expand Down Expand Up @@ -77,8 +82,6 @@ def default_alloc_memory(
Returns:
int: The size of memory change caused by the input bsym
"""
# Skip CollectionProxy(output of input unpacking operators, such as unpack_sequence)
# and other negligible scalar types
tensor_outs = [x for x in bsym.flat_proxy_outs if isinstance(x, (TensorProxy, FutureTensorProxy))]
result = sum(t.numel * t.dtype.bytes for t in tensor_outs)
for x in tensor_outs:
Expand All @@ -96,10 +99,10 @@ def track_alias_op_memory(
bsym: BoundSymbol, tensor_to_memory_data: ProxyDict, name_to_alloc_memory: dict[str, int]
) -> int:
inp = bsym.flat_proxy_args[0]
out = bsym.flat_proxy_outs[0]
assert inp in tensor_to_memory_data
tensor_to_memory_data[inp].incr_ref()
tensor_to_memory_data[out] = tensor_to_memory_data[inp]
for out in bsym.flat_proxy_outs:
tensor_to_memory_data[inp].incr_ref()
tensor_to_memory_data[out] = tensor_to_memory_data[inp]
return 0


Expand All @@ -118,6 +121,32 @@ def del_op_memory(bsym: BoundSymbol, tensor_to_memory_data: ProxyDict, name_to_a
return memory_size


@register_memory_calculate_function(pythonex.clear_mutable_collection.id)
def clear_mutable_collection_argument_memory(
bsym: BoundSymbol, tensor_to_memory_data: ProxyDict, name_to_alloc_memory: dict[str, int], is_argument: bool
) -> int:
# Clearing the collection forces the interpreter to release references to its elements,
# even if the collection was an argument.
# So we cancel the n += 1 (see get_alloc_memory) for tensors contained in such a collection
if not is_argument:
return 0

collection_proxy = bsym.flat_proxy_args[0]
if not isinstance(collection_proxy.collection(), (MutableSequence, MutableMapping, MutableSet)):
return 0

memory_size = 0
for a in tree_iter(collection_proxy.collection()):
if not isinstance(a, (TensorProxy, FutureTensorProxy)):
continue
cnt_a = tensor_to_memory_data[a].decr_ref()
if cnt_a == 0:
size_a = tensor_to_memory_data[a].get_memory_size()
memory_size -= size_a
name_to_alloc_memory[f"clear_mutable_collection {a.name}"] = -size_a
return memory_size


def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]:
"""
Calculate the memory usage based on the executable trace.
Expand All @@ -140,13 +169,26 @@ def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]:
max_allocated = 0
allocated = 0

arg_names = {arginfo[0] for arginfo in trc.siginfo().args} | trc.siginfo().kwargs.keys()
tensor_to_memory_data = ProxyDict()
for arg in tree_iter((trc.args, trc.kwargs)):
# In addition to the arguments themselves (n=1), the interpreter holds references to the arguments,
# accounting for n += 1
tensor_to_memory_data[arg] = MemoryData(n=2, proxy=arg)
mem_size = arg.numel * arg.dtype.bytes
allocated += mem_size
name_to_alloc_memory[f"argument {arg.name}"] = mem_size

for bsym in trc.bound_symbols:
if bsym.sym.id in memory_calculate_skip_list:
continue

impl = memory_calculate_impls.get(bsym.sym.id, default_alloc_memory)
allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory)
if impl is clear_mutable_collection_argument_memory:
is_argument = bsym.flat_proxy_args[0].name in arg_names
impl = partial(impl, is_argument=is_argument)

allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory)
max_allocated = max(max_allocated, allocated)

return max_allocated, name_to_alloc_memory
192 changes: 69 additions & 123 deletions thunder/tests/test_examine_memory.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,66 @@
from contextlib import contextmanager
from functools import partial

import pytest

import torch

import thunder
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
from thunder.core.pytree import tree_map
import thunder.torch as ltorch
from thunder.core.proxies import TensorProxy
from thunder.examine.memory_caculation import get_alloc_memory

from thunder.tests.framework import instantiate, nvFuserTestExecutor, TorchTestExecutor
from thunder.tests.framework import requiresCUDA, TorchExecutor
from thunder.tests.make_tensor import make_tensor


@contextmanager
def runtime_allocated_memory(dev):
torch.cuda.reset_peak_memory_stats(dev)
try:
yield
finally:
memory_states = torch.cuda.memory_stats(dev)
alloc = memory_states["allocated_bytes.all.peak"]
req = memory_states["requested_bytes.all.peak"]
print(f"**peak allocated/required memory: {alloc}, {req}")
def measure_memory_usage(trace):
torch.cuda.reset_peak_memory_stats()
before = torch.cuda.memory_stats().get("requested_bytes.all.current", 0)

def make_tensor_like_torch_dtype(p):
return make_tensor(p.shape, dtype=ltorch.to_torch_dtype(p.dtype), device=p.device)

args, kwargs = tree_map(make_tensor_like_torch_dtype, (trace.args, trace.kwargs))
output = trace.python_callable()(*args, **kwargs)

after = torch.cuda.memory_stats()["requested_bytes.all.current"]
peak = torch.cuda.memory_stats()["requested_bytes.all.peak"]

return {"peak": peak - before, "current": after - before, "output": output}


def measure_fw_and_bw_memory_usage(fw_trace, bw_trace):
fw_results = measure_memory_usage(fw_trace)
bw_results = measure_memory_usage(bw_trace)

return {f"fw_{k}": v for k, v in fw_results.items()} | {f"bw_{k}": v for k, v in bw_results.items()}

def get_return_memory(bsym):
assert bsym.sym is thunder.core.prims.python_return
return_tensors_name = set()
res = 0
for x in bsym.flat_proxy_args:
if isinstance(x, TensorProxy) and x.name not in return_tensors_name:
res += x.numel * x.dtype.bytes
return_tensors_name.add(x.name)
return res

# TODO: Test for nvFuserExecutor
# nvFuserExecutor is skipped for now, because nvFuser and eager execution treat allocation and broadcast differently.
# In the future, we need to update get_alloc_memory to support nvFuser and update tests accordingly.
@requiresCUDA
def test_view_ops():
def test(func, *shapes):
inputs = [make_tensor(shape, dtype=torch.float32, device="cuda", requires_grad=True) for shape in shapes]
cfunc = TorchExecutor.make_callable(func, disable_preprocessing=False)
cfunc(*inputs)

@instantiate(dtypes=(thunder.float32,), devicetypes=(devices.DeviceType.CUDA,))
def test_view_ops(executor, device: str, dtype: dtypes.dtype):
torch_dtype = ltorch.to_torch_dtype(dtype)
a = make_tensor((4,), device=device, dtype=torch_dtype, requires_grad=True)
b = make_tensor((2, 2), device=device, dtype=torch_dtype, requires_grad=True)
fw_trace = thunder.last_traces(cfunc)[-1]
bw_trace = thunder.last_backward_traces(cfunc)[-1]
max_mem_fw = get_alloc_memory(fw_trace)
max_mem_bw = get_alloc_memory(bw_trace)

result = measure_fw_and_bw_memory_usage(fw_trace, bw_trace)
assert max_mem_fw[0] == result["fw_peak"]
assert sum(max_mem_fw[1].values()) == result["fw_current"]
assert max_mem_bw[0] == result["bw_peak"]
assert sum(max_mem_bw[1].values()) == result["bw_current"]

def foo(a, b): # [4] [4]
a_1 = torch.unsqueeze(a, 0) # [1,4]
b_2 = torch.unsqueeze(b, 0) # [1,4]
return (a_1 + b_2,)

test(foo, (4,), (4,))

def bar(a, b): # [4] [2,2]
a_1 = torch.unsqueeze(a, 0) # [1,4]
Expand All @@ -59,29 +76,7 @@ def bar(a, b): # [4] [2,2]
result2 = b_4 + a_3
return result1, result2

# Bookending changes memory footprint and whether bookending is enabled
# depends on version. For several tests in this file that are sensitive to
# bookending, I forced bookending to be off, the default for the latest
# veresion of nvFuser. I could also test nv_enable_bookend=True, the old
# default, but I would have to use different "golden" values, making tests
# complicated.
cbar = executor.make_callable(bar, disable_preprocessing=False, nv_enable_bookend=False)
with runtime_allocated_memory(device):
cbar(a, b)

fw_traces = thunder.last_traces(cbar)
fwd_extrace = fw_traces[-1]
max_mem_fwd = get_alloc_memory(fwd_extrace)
assert max_mem_fwd[0] == 144
assert sum(max_mem_fwd[1].values()) == get_return_memory(fwd_extrace.bound_symbols[-1]) # 144
bw_traces = thunder.last_backward_traces(cbar)
bw_extrace = bw_traces[-1]
max_mem_bw = get_alloc_memory(bw_extrace)
# nvFuser should be able to avoid the allocation of result2 and produce it as alias to result1, reducing the allocated memory to 128. However,
# due to a limitation of `get_alloc_memory` (https://github.com/Lightning-AI/lightning-thunder/blob/6dfe7e939a19d1ef5ab259de8709a79f0104fa42/thunder/examine/memory_caculation.py#L123-L125), this saved memory is not taken into consideration.
assert max_mem_bw[0] == 144

assert sum(max_mem_bw[1].values()) == get_return_memory(bw_extrace.bound_symbols[-1]) # 32
test(bar, (4,), (2, 2))

def bar1(a, b, c): # [4], [1,4,4], [4,1,4]
a_1 = torch.unsqueeze(a, 0) # [1,4]
Expand All @@ -90,83 +85,34 @@ def bar1(a, b, c): # [4], [1,4,4], [4,1,4]
a_4 = a_2.expand(4, 1, 4)
return b + a_3, c + a_4

a = make_tensor((4,), device=device, dtype=torch_dtype)
b = make_tensor((1, 4, 4), device=device, dtype=torch_dtype)
c = make_tensor((4, 1, 4), device=device, dtype=torch_dtype)
cbar = executor.make_callable(bar1, disable_preprocessing=False, nv_enable_bookend=False)
with runtime_allocated_memory(device):
cbar(a, b, c)

traces = thunder.last_traces(cbar)
extrace = traces[-1]
alloc_mem = get_alloc_memory(extrace)
if isinstance(executor, nvFuserTestExecutor):
assert alloc_mem[0] == 272
assert sum(alloc_mem[1].values()) == get_return_memory(extrace.bound_symbols[-1]) # 128
if isinstance(executor, TorchTestExecutor):
assert alloc_mem[0] == 208
assert sum(alloc_mem[1].values()) == get_return_memory(extrace.bound_symbols[-1]) # 128
test(bar1, (4,), (1, 4, 4), (4, 1, 4))

def bar2(a, b): # [5,2], [2,2]
a_1, a_2, a_3 = torch.split(a, 2)
c = a_1 + b
d = a + a
return c, a_2, d
return c, d, a_2, a_3 # We have to use all the outputs of torch.split due to #1043

a = make_tensor((5, 2), device=device, dtype=torch_dtype)
b = make_tensor((2, 2), device=device, dtype=torch_dtype)
cbar = executor.make_callable(bar2, disable_preprocessing=False, nv_enable_bookend=False)
test(bar2, (5, 2), (2, 2))

with runtime_allocated_memory(device):
cbar(a, b)

traces = thunder.last_traces(cbar)
extrace = traces[-1]
alloc_mem = get_alloc_memory(extrace)
if isinstance(executor, nvFuserTestExecutor):
assert alloc_mem[0] == 128
assert sum(alloc_mem[1].values()) == get_return_memory(extrace.bound_symbols[-1]) # 72
if isinstance(executor, TorchTestExecutor):
assert alloc_mem[0] == 112
assert sum(alloc_mem[1].values()) == get_return_memory(extrace.bound_symbols[-1]) # 72


@instantiate(dtypes=(thunder.float32,), devicetypes=(devices.DeviceType.CUDA,))
def test_nanogpt_block(executor, device, dtype):
@requiresCUDA
def test_nanogpt_block():
import thunder.tests.nanogpt_model as nanogpt_model

tdtype = ltorch.to_torch_dtype(dtype)
make = partial(make_tensor, dtype=tdtype, device=device)

config = nanogpt_model.GPTConfig(dropout=0)
block = nanogpt_model.Block(config).to(device=device, dtype=tdtype)
cblock = executor.make_callable(block, nv_enable_bookend=False)

with runtime_allocated_memory(device):
inp = make((2, config.block_size, config.n_embd))
result = cblock(inp)
with runtime_allocated_memory(device):
result.backward(torch.ones_like(result))
fw_extrace = thunder.last_traces(cblock)[-1]
bw_extrace = thunder.last_backward_traces(cblock)[-1]
fw_alloc_mem = get_alloc_memory(fw_extrace)
bw_alloc_mem = get_alloc_memory(bw_extrace)

if isinstance(executor, nvFuserTestExecutor):
assert fw_alloc_mem[0] == 267426816
expected_return_calculated_mem = get_return_memory(fw_extrace.bound_symbols[-1])
assert expected_return_calculated_mem == sum(fw_alloc_mem[1].values())

assert bw_alloc_mem[0] == 412112896
assert sum(bw_alloc_mem[1].values()) == get_return_memory(bw_extrace.bound_symbols[-1])
if isinstance(executor, TorchTestExecutor):
assert fw_alloc_mem[0] == 362863616
# Expect the memory to -t38+t37-t65-t67.
# t67 is the expand result of ln_2_weight, and they are both return values in trace
# but for calculation we assume they share memory, so expect to subtract the size of t67.
expected_return_calculated_mem = (
get_return_memory(fw_extrace.bound_symbols[-1]) - 23 * 1024 * 1024 - 4 * 2 * 1024 * 768 * 2
)
assert expected_return_calculated_mem == sum(fw_alloc_mem[1].values())
assert bw_alloc_mem[0] == 412109824
assert sum(bw_alloc_mem[1].values()) == get_return_memory(bw_extrace.bound_symbols[-1])
block = nanogpt_model.Block(config).to(dtype=torch.float32, device="cuda")
cblock = TorchExecutor.make_callable(block)
inp = make_tensor((2, config.block_size, config.n_embd), dtype=torch.float32, device="cuda", requires_grad=True)
cblock(inp)

fw_trace = thunder.last_traces(cblock)[-1]
bw_trace = thunder.last_backward_traces(cblock)[-1]
max_mem_fw = get_alloc_memory(fw_trace)
max_mem_bw = get_alloc_memory(bw_trace)

result = measure_fw_and_bw_memory_usage(fw_trace, bw_trace)
assert max_mem_fw[0] == result["fw_peak"]
assert sum(max_mem_fw[1].values()) == result["fw_current"]
assert max_mem_bw[0] == result["bw_peak"]
assert sum(max_mem_bw[1].values()) == result["bw_current"]

0 comments on commit d05ebc6

Please sign in to comment.