diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 042468aaa7..3f9356aa89 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -3568,8 +3568,13 @@ def unpacking_fn(saved_for_backward, cotangents):
     cotangents = backward_trace.args[1]
     saved_tensors, saved_other = _split_saved_for_backward_into_tensors_and_other(saved_for_backward)
+
+    # When thunder.executors.torch_autograd.ThunderFunction.backward calls backward_fn, it copies
+    # the collections into mutable ones so that the tensors can be deallocated once they are deleted.
+    # See ThunderFunction.backward's notes for details.
+    saved_tensors = list(saved_tensors)
     unpacking_trace = construct_trace(rename_proxies=False, use_dce=False)(
-        unpacking_fn, (saved_tensors, saved_other), cotangents
+        unpacking_fn, [saved_tensors, saved_other], cotangents
     )
     assert unpacking_trace.bound_symbols[-1].sym.id == prims.PrimIDs.RETURN
diff --git a/thunder/examine/memory_caculation.py b/thunder/examine/memory_caculation.py
index a5b64dbe6d..14c2123ed2 100644
--- a/thunder/examine/memory_caculation.py
+++ b/thunder/examine/memory_caculation.py
@@ -1,12 +1,16 @@
-from collections.abc import Callable
+from collections.abc import Callable, MutableSequence, MutableMapping, MutableSet
+from functools import partial
 
 from thunder.core.prims import PrimIDs
-from thunder.core.proxies import FutureTensorProxy, Proxy, TensorProxy
+from thunder.core.proxies import CollectionProxy, FutureTensorProxy, Proxy, TensorProxy
 from thunder.core.symbol import BoundSymbol, Symbol
 from thunder.core.trace import TraceCtx
 from thunder.core.utils import check_type, ProxyDict
+from thunder.core.pytree import tree_iter
+from thunder.executors import pythonex
 
-memory_calculate_skip_list = (PrimIDs.RETURN, "clear_collection")
+# Arguments are considered independently, so we ignore all unpacking operations on them.
+memory_calculate_skip_list = (PrimIDs.RETURN, PrimIDs.UNPACK_TRIVIAL, PrimIDs.UNPACK_SEQUENCE)
 
 # List of operators that considered no memory changes occurred in Thunder.
 # NOTE: for the operators have different input and output shape, such as expand,
@@ -20,6 +24,7 @@
     "torch_prims_reshape_impl",  # torchex implementation of prims.reshape.
"permute", "contiguous", + "split", "torch_wait_prim_impl", ) @@ -77,8 +82,6 @@ def default_alloc_memory( Returns: int: The size of memory change caused by the input bsym """ - # Skip CollectionProxy(output of input unpacking operators, such as unpack_sequence) - # and other negligible scalar types tensor_outs = [x for x in bsym.flat_proxy_outs if isinstance(x, (TensorProxy, FutureTensorProxy))] result = sum(t.numel * t.dtype.bytes for t in tensor_outs) for x in tensor_outs: @@ -96,10 +99,10 @@ def track_alias_op_memory( bsym: BoundSymbol, tensor_to_memory_data: ProxyDict, name_to_alloc_memory: dict[str, int] ) -> int: inp = bsym.flat_proxy_args[0] - out = bsym.flat_proxy_outs[0] assert inp in tensor_to_memory_data - tensor_to_memory_data[inp].incr_ref() - tensor_to_memory_data[out] = tensor_to_memory_data[inp] + for out in bsym.flat_proxy_outs: + tensor_to_memory_data[inp].incr_ref() + tensor_to_memory_data[out] = tensor_to_memory_data[inp] return 0 @@ -118,6 +121,32 @@ def del_op_memory(bsym: BoundSymbol, tensor_to_memory_data: ProxyDict, name_to_a return memory_size +@register_memory_calculate_function(pythonex.clear_mutable_collection.id) +def clear_mutable_collection_argument_memory( + bsym: BoundSymbol, tensor_to_memory_data: ProxyDict, name_to_alloc_memory: dict[str, int], is_argument: bool +) -> int: + # Clearing the collection forces the interpreter to release references to its elements, + # even if the collection was an argument. + # So we cancel the n += 1 (see get_alloc_memory) for tensors contained in such a collection + if not is_argument: + return 0 + + collection_proxy = bsym.flat_proxy_args[0] + if not isinstance(collection_proxy.collection(), (MutableSequence, MutableMapping, MutableSet)): + return 0 + + memory_size = 0 + for a in tree_iter(collection_proxy.collection()): + if not isinstance(a, (TensorProxy, FutureTensorProxy)): + continue + cnt_a = tensor_to_memory_data[a].decr_ref() + if cnt_a == 0: + size_a = tensor_to_memory_data[a].get_memory_size() + memory_size -= size_a + name_to_alloc_memory[f"clear_mutable_collection {a.name}"] = -size_a + return memory_size + + def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]: """ Calculate the memory usage based on the executable trace. 
@@ -140,13 +169,26 @@ def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]:
     max_allocated = 0
     allocated = 0
+    arg_names = {arginfo[0] for arginfo in trc.siginfo().args} | trc.siginfo().kwargs.keys()
     tensor_to_memory_data = ProxyDict()
+    for arg in tree_iter((trc.args, trc.kwargs)):
+        # In addition to the arguments themselves (n=1), the interpreter holds its own references to them,
+        # accounting for the extra n += 1.
+        tensor_to_memory_data[arg] = MemoryData(n=2, proxy=arg)
+        mem_size = arg.numel * arg.dtype.bytes
+        allocated += mem_size
+        name_to_alloc_memory[f"argument {arg.name}"] = mem_size
+
     for bsym in trc.bound_symbols:
         if bsym.sym.id in memory_calculate_skip_list:
             continue
+
         impl = memory_calculate_impls.get(bsym.sym.id, default_alloc_memory)
-        allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory)
+        if impl is clear_mutable_collection_argument_memory:
+            is_argument = bsym.flat_proxy_args[0].name in arg_names
+            impl = partial(impl, is_argument=is_argument)
+        allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory)
         max_allocated = max(max_allocated, allocated)
 
     return max_allocated, name_to_alloc_memory
diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py
index 9cb82f2a36..cd02b8ab5a 100644
--- a/thunder/tests/test_examine_memory.py
+++ b/thunder/tests/test_examine_memory.py
@@ -1,49 +1,66 @@
-from contextlib import contextmanager
-from functools import partial
-
 import pytest
 import torch
 
 import thunder
-import thunder.core.devices as devices
-import thunder.core.dtypes as dtypes
+from thunder.core.pytree import tree_map
 import thunder.torch as ltorch
-from thunder.core.proxies import TensorProxy
 from thunder.examine.memory_caculation import get_alloc_memory
-from thunder.tests.framework import instantiate, nvFuserTestExecutor, TorchTestExecutor
+from thunder.tests.framework import requiresCUDA, TorchExecutor
 from thunder.tests.make_tensor import make_tensor
 
 
-@contextmanager
-def runtime_allocated_memory(dev):
-    torch.cuda.reset_peak_memory_stats(dev)
-    try:
-        yield
-    finally:
-        memory_states = torch.cuda.memory_stats(dev)
-        alloc = memory_states["allocated_bytes.all.peak"]
-        req = memory_states["requested_bytes.all.peak"]
-        print(f"**peak allocated/required memory: {alloc}, {req}")
+def measure_memory_usage(trace):
+    torch.cuda.reset_peak_memory_stats()
+    before = torch.cuda.memory_stats().get("requested_bytes.all.current", 0)
+
+    def make_tensor_like_torch_dtype(p):
+        return make_tensor(p.shape, dtype=ltorch.to_torch_dtype(p.dtype), device=p.device)
+
+    args, kwargs = tree_map(make_tensor_like_torch_dtype, (trace.args, trace.kwargs))
+    output = trace.python_callable()(*args, **kwargs)
+
+    after = torch.cuda.memory_stats()["requested_bytes.all.current"]
+    peak = torch.cuda.memory_stats()["requested_bytes.all.peak"]
+
+    return {"peak": peak - before, "current": after - before, "output": output}
+
+
+def measure_fw_and_bw_memory_usage(fw_trace, bw_trace):
+    fw_results = measure_memory_usage(fw_trace)
+    bw_results = measure_memory_usage(bw_trace)
+    return {f"fw_{k}": v for k, v in fw_results.items()} | {f"bw_{k}": v for k, v in bw_results.items()}
 
 
-def get_return_memory(bsym):
-    assert bsym.sym is thunder.core.prims.python_return
-    return_tensors_name = set()
-    res = 0
-    for x in bsym.flat_proxy_args:
-        if isinstance(x, TensorProxy) and x.name not in return_tensors_name:
-            res += x.numel * x.dtype.bytes
-            return_tensors_name.add(x.name)
-    return res
+# TODO: Test for nvFuserExecutor
+# nvFuserExecutor is skipped for now because nvFuser and eager execution treat allocation and broadcast differently.
+# In the future, we need to update get_alloc_memory to support nvFuser and update tests accordingly.
+@requiresCUDA
+def test_view_ops():
+    def test(func, *shapes):
+        inputs = [make_tensor(shape, dtype=torch.float32, device="cuda", requires_grad=True) for shape in shapes]
+        cfunc = TorchExecutor.make_callable(func, disable_preprocessing=False)
+        cfunc(*inputs)
 
-@instantiate(dtypes=(thunder.float32,), devicetypes=(devices.DeviceType.CUDA,))
-def test_view_ops(executor, device: str, dtype: dtypes.dtype):
-    torch_dtype = ltorch.to_torch_dtype(dtype)
-    a = make_tensor((4,), device=device, dtype=torch_dtype, requires_grad=True)
-    b = make_tensor((2, 2), device=device, dtype=torch_dtype, requires_grad=True)
+        fw_trace = thunder.last_traces(cfunc)[-1]
+        bw_trace = thunder.last_backward_traces(cfunc)[-1]
+        max_mem_fw = get_alloc_memory(fw_trace)
+        max_mem_bw = get_alloc_memory(bw_trace)
+
+        result = measure_fw_and_bw_memory_usage(fw_trace, bw_trace)
+        assert max_mem_fw[0] == result["fw_peak"]
+        assert sum(max_mem_fw[1].values()) == result["fw_current"]
+        assert max_mem_bw[0] == result["bw_peak"]
+        assert sum(max_mem_bw[1].values()) == result["bw_current"]
+
+    def foo(a, b):  # [4] [4]
+        a_1 = torch.unsqueeze(a, 0)  # [1,4]
+        b_2 = torch.unsqueeze(b, 0)  # [1,4]
+        return (a_1 + b_2,)
+
+    test(foo, (4,), (4,))
 
     def bar(a, b):  # [4] [2,2]
         a_1 = torch.unsqueeze(a, 0)  # [1,4]
@@ -59,29 +76,7 @@ def bar(a, b):  # [4] [2,2]
         result2 = b_4 + a_3
         return result1, result2
 
-    # Bookending changes memory footprint and whether bookending is enabled
-    # depends on version. For several tests in this file that are sensitive to
-    # bookending, I forced bookending to be off, the default for the latest
-    # veresion of nvFuser. I could also test nv_enable_bookend=True, the old
-    # default, but I would have to use different "golden" values, making tests
-    # complicated.
-    cbar = executor.make_callable(bar, disable_preprocessing=False, nv_enable_bookend=False)
-    with runtime_allocated_memory(device):
-        cbar(a, b)
-
-    fw_traces = thunder.last_traces(cbar)
-    fwd_extrace = fw_traces[-1]
-    max_mem_fwd = get_alloc_memory(fwd_extrace)
-    assert max_mem_fwd[0] == 144
-    assert sum(max_mem_fwd[1].values()) == get_return_memory(fwd_extrace.bound_symbols[-1])  # 144
-    bw_traces = thunder.last_backward_traces(cbar)
-    bw_extrace = bw_traces[-1]
-    max_mem_bw = get_alloc_memory(bw_extrace)
-    # nvFuser should be able to avoid the allocation of result2 and produce it as alias to result1, reducing the allocated memory to 128. However,
-    # due to a limitation of `get_alloc_memory` (https://github.com/Lightning-AI/lightning-thunder/blob/6dfe7e939a19d1ef5ab259de8709a79f0104fa42/thunder/examine/memory_caculation.py#L123-L125), this saved memory is not taken into consideration.
-    assert max_mem_bw[0] == 144
-
-    assert sum(max_mem_bw[1].values()) == get_return_memory(bw_extrace.bound_symbols[-1])  # 32
+    test(bar, (4,), (2, 2))
 
     def bar1(a, b, c):  # [4], [1,4,4], [4,1,4]
         a_1 = torch.unsqueeze(a, 0)  # [1,4]
@@ -90,83 +85,34 @@ def bar1(a, b, c):  # [4], [1,4,4], [4,1,4]
         a_4 = a_2.expand(4, 1, 4)
         return b + a_3, c + a_4
 
-    a = make_tensor((4,), device=device, dtype=torch_dtype)
-    b = make_tensor((1, 4, 4), device=device, dtype=torch_dtype)
-    c = make_tensor((4, 1, 4), device=device, dtype=torch_dtype)
-    cbar = executor.make_callable(bar1, disable_preprocessing=False, nv_enable_bookend=False)
-    with runtime_allocated_memory(device):
-        cbar(a, b, c)
-
-    traces = thunder.last_traces(cbar)
-    extrace = traces[-1]
-    alloc_mem = get_alloc_memory(extrace)
-    if isinstance(executor, nvFuserTestExecutor):
-        assert alloc_mem[0] == 272
-        assert sum(alloc_mem[1].values()) == get_return_memory(extrace.bound_symbols[-1])  # 128
-    if isinstance(executor, TorchTestExecutor):
-        assert alloc_mem[0] == 208
-        assert sum(alloc_mem[1].values()) == get_return_memory(extrace.bound_symbols[-1])  # 128
+    test(bar1, (4,), (1, 4, 4), (4, 1, 4))
 
     def bar2(a, b):  # [5,2], [2,2]
         a_1, a_2, a_3 = torch.split(a, 2)
         c = a_1 + b
         d = a + a
-        return c, a_2, d
+        return c, d, a_2, a_3  # We have to use all the outputs of torch.split due to #1043
 
-    a = make_tensor((5, 2), device=device, dtype=torch_dtype)
-    b = make_tensor((2, 2), device=device, dtype=torch_dtype)
-    cbar = executor.make_callable(bar2, disable_preprocessing=False, nv_enable_bookend=False)
+    test(bar2, (5, 2), (2, 2))
 
-    with runtime_allocated_memory(device):
-        cbar(a, b)
-    traces = thunder.last_traces(cbar)
-    extrace = traces[-1]
-    alloc_mem = get_alloc_memory(extrace)
-    if isinstance(executor, nvFuserTestExecutor):
-        assert alloc_mem[0] == 128
-        assert sum(alloc_mem[1].values()) == get_return_memory(extrace.bound_symbols[-1])  # 72
-    if isinstance(executor, TorchTestExecutor):
-        assert alloc_mem[0] == 112
-        assert sum(alloc_mem[1].values()) == get_return_memory(extrace.bound_symbols[-1])  # 72
-
-
-@instantiate(dtypes=(thunder.float32,), devicetypes=(devices.DeviceType.CUDA,))
-def test_nanogpt_block(executor, device, dtype):
+@requiresCUDA
+def test_nanogpt_block():
     import thunder.tests.nanogpt_model as nanogpt_model
 
-    tdtype = ltorch.to_torch_dtype(dtype)
-    make = partial(make_tensor, dtype=tdtype, device=device)
-
     config = nanogpt_model.GPTConfig(dropout=0)
-    block = nanogpt_model.Block(config).to(device=device, dtype=tdtype)
-    cblock = executor.make_callable(block, nv_enable_bookend=False)
-
-    with runtime_allocated_memory(device):
-        inp = make((2, config.block_size, config.n_embd))
-        result = cblock(inp)
-    with runtime_allocated_memory(device):
-        result.backward(torch.ones_like(result))
-    fw_extrace = thunder.last_traces(cblock)[-1]
-    bw_extrace = thunder.last_backward_traces(cblock)[-1]
-    fw_alloc_mem = get_alloc_memory(fw_extrace)
-    bw_alloc_mem = get_alloc_memory(bw_extrace)
-
-    if isinstance(executor, nvFuserTestExecutor):
-        assert fw_alloc_mem[0] == 267426816
-        expected_return_calculated_mem = get_return_memory(fw_extrace.bound_symbols[-1])
-        assert expected_return_calculated_mem == sum(fw_alloc_mem[1].values())
-
-        assert bw_alloc_mem[0] == 412112896
-        assert sum(bw_alloc_mem[1].values()) == get_return_memory(bw_extrace.bound_symbols[-1])
-    if isinstance(executor, TorchTestExecutor):
-        assert fw_alloc_mem[0] == 362863616
-        # Expect the memory to -t38+t37-t65-t67.
-        # t67 is the expand result of ln_2_weight, and they are both return values in trace
-        # but for calculation we assume they share memory, so expect to subtract the size of t67.
-        expected_return_calculated_mem = (
-            get_return_memory(fw_extrace.bound_symbols[-1]) - 23 * 1024 * 1024 - 4 * 2 * 1024 * 768 * 2
-        )
-        assert expected_return_calculated_mem == sum(fw_alloc_mem[1].values())
-        assert bw_alloc_mem[0] == 412109824
-        assert sum(bw_alloc_mem[1].values()) == get_return_memory(bw_extrace.bound_symbols[-1])
+    block = nanogpt_model.Block(config).to(dtype=torch.float32, device="cuda")
+    cblock = TorchExecutor.make_callable(block)
+    inp = make_tensor((2, config.block_size, config.n_embd), dtype=torch.float32, device="cuda", requires_grad=True)
+    cblock(inp)
+
+    fw_trace = thunder.last_traces(cblock)[-1]
+    bw_trace = thunder.last_backward_traces(cblock)[-1]
+    max_mem_fw = get_alloc_memory(fw_trace)
+    max_mem_bw = get_alloc_memory(bw_trace)
+
+    result = measure_fw_and_bw_memory_usage(fw_trace, bw_trace)
+    assert max_mem_fw[0] == result["fw_peak"]
+    assert sum(max_mem_fw[1].values()) == result["fw_current"]
+    assert max_mem_bw[0] == result["bw_peak"]
+    assert sum(max_mem_bw[1].values()) == result["bw_current"]
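
A minimal usage sketch (not part of the patch): it shows how the `get_alloc_memory` estimate from the revised thunder/examine/memory_caculation.py can be obtained for the forward and backward execution traces of a Thunder-compiled function, mirroring what the new tests check against measured GPU memory. It assumes a CUDA-enabled PyTorch build; `fn` and its inputs are illustrative only.

# Sketch: estimate peak and still-allocated memory for Thunder execution traces.
# `fn`, `a`, and `b` are made-up examples; get_alloc_memory returns
# (estimated peak bytes, {producer name: bytes still allocated at return}).
import torch
import thunder
from thunder.examine.memory_caculation import get_alloc_memory


def fn(a, b):
    return torch.nn.functional.relu(a @ b)


a = torch.randn(1024, 1024, device="cuda", requires_grad=True)
b = torch.randn(1024, 1024, device="cuda", requires_grad=True)

jfn = thunder.jit(fn)
jfn(a, b)  # run once so the execution traces are recorded

fw_trace = thunder.last_traces(jfn)[-1]
bw_trace = thunder.last_backward_traces(jfn)[-1]

fw_peak, fw_breakdown = get_alloc_memory(fw_trace)
bw_peak, bw_breakdown = get_alloc_memory(bw_trace)

print("forward:  estimated peak", fw_peak, "bytes; still allocated", sum(fw_breakdown.values()))
print("backward: estimated peak", bw_peak, "bytes; still allocated", sum(bw_breakdown.values()))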