Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unrolling tensor subclasses in fwd/bwd split #1489

3 changes: 0 additions & 3 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses

# NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
import torch as pytorch
Expand Down Expand Up @@ -587,8 +586,6 @@ def get_computation_and_inputs(*args, **kwargs):
if len(tensor_args_consumed_by_inplace_grouped_by_numel) > 1:
vanilla_tensor_args = set(tensor_indices)

computation_trc = flatten_tensor_subclasses(computation_trc)

if epilogue_trc is not None:
epilogue_traces = [epilogue_trc]
else:
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def _convert_pytorchfunc_to_thundertrace(
trace = TraceCtx()
trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
func_result = unwrap(wrapped_func_result)
if shallow_copy_output:
if shallow_copy_output and not trace.bound_symbols:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy from #1485

from thunder.core.baseutils import sequencify

out_to_shallow_copy: dict[Variable, TensorProxy] = {}
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/pytree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partial
from types import FunctionType
import dataclasses
from enum import Enum

import optree
import torch
Expand Down Expand Up @@ -64,6 +65,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE):
and not is_likely_from_collections_namedtuple(args)
and not dataclasses.is_dataclass(args)
and not type(args).__module__.startswith("torch.return_types")
and not issubclass(type(args), Enum)
):
raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace)
Expand Down
4 changes: 4 additions & 0 deletions thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
from thunder.distributed.transforms import FSDPCommBucketing
from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops
from thunder.executors.passes import del_last_used, transform_for_execution
from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses, DesugarTensorSubclass

utils.check(compile_data is not None, lambda: "`compile_data` is required")
# NOTE: This function is rather slow, so it's intended to be used
Expand All @@ -154,6 +155,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
# not any other container type. So we need to flatten the outputs of
# the forward trace and inputs of the backward trace.
fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
fw_trace, fw_tensor_subclass_desugar = flatten_tensor_subclasses(fw_trace)

fw_traces = [fw_trace]
bw_traces = [bw_trace]
Expand Down Expand Up @@ -245,6 +247,8 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
if getattr(compile_data.fn, "use_fsdp", False):
bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace)

bw_trace, bw_tensor_subclass_desugar = flatten_tensor_subclasses(bw_trace)

# Now we can run the optimization passes on the backward trace
# TODO Restore request for no rematerialization
bw_extrace = transform_for_execution(
Expand Down
3 changes: 0 additions & 3 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,9 +1408,6 @@ def _scaled_mm_transform(
if b.stride()[0] != 1 and b.stride()[1] > 1:
b = b.t().contiguous().t()

print(
f"{type(a)=}, {type(b)=}, {type(scale_a)=}, {type(scale_b)=}, {type(bias)=}, {type(scale_result)=}, {type(result_dtype)=}, {type(use_fast_accum)=}"
)
return _scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum)


Expand Down
4 changes: 4 additions & 0 deletions thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,7 @@ def test_torchao_float8_linear(executor, device, _):

jitted = executor.make_callable(fp8_model)
actual = jitted(x)

print(expected)
print(actual)
# torch.testing.assert_close(actual, expected)
15 changes: 14 additions & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,9 @@ def t(a: TensorLike, /) -> TensorLike:
lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D",
RuntimeError,
)
return transpose(a, 0, 1) if a.ndim == 2 else a
if a.ndim == 2:
return transpose(a, 0, 1)
return a


@run_once
Expand Down Expand Up @@ -1312,6 +1314,17 @@ def transpose(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike:
return clang.transpose(a, permutation)


def _transpose_grad(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike:
fwd = transpose(a, dim0, dim1)
g = get_grad(fwd)
a_grad = transpose(g, dim0, dim1)
put_grad(a, a_grad)
return fwd


register_grad(transpose, _transpose_grad)
Comment on lines +1317 to +1325
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rel: #1487

needed to avoid prims.permute



@torchsymbol(torch.unbind, is_method=True)
def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]:
utils.check(
Expand Down
46 changes: 32 additions & 14 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@
from torch.fx import GraphModule
from torch._ops import OpOverload
from thunder.core.symbol import Symbol, BoundSymbol
from torch._C import _TensorMeta


__all__ = [
"DesugarTensorSubclass",
"flatten_tensor_subclasses",
]

Expand Down Expand Up @@ -249,17 +249,18 @@ def translate_fx_graph_into_bsym(
import thunder.torch as ltorch

unwrapped_bsym_args: dict[int, ProxyInterface] = {}
list_of_unflatten_bsym: list[BoundSymbol] = []
list_of_flattening_bsyms: list[BoundSymbol] = []
for a in bsym.flat_args:
if isinstance(a, SubclassTensorProxy):
if variableify(a) in self.subclass_proxy_to_flatten:
self.computation_trace.push_scope([])
with tracectx(self.computation_trace):
prims.flatten_tensor_subclass(a)
unflatten_bsym = self.computation_trace.pop_scope()[0]
list_of_unflatten_bsym.append(unflatten_bsym)
flattening_bsym = self.computation_trace.pop_scope()[0]
list_of_flattening_bsyms.append(flattening_bsym)
tensor_attr_names = self._get_tensor_attr_names(a)
tensors = a._tensors

non_tensor_attr_names = self._get_non_tensor_attr_names(a)
non_tensors = a._non_tensors
metadata = dict(zip(non_tensor_attr_names, non_tensors))
Expand Down Expand Up @@ -307,8 +308,8 @@ def translate_fx_graph_into_bsym(
ltorch_ops_for_node_of_ops.append(getattr(ltorch, node.target._opname))

bsyms: list[BoundSymbol] = []
if list_of_unflatten_bsym:
bsyms.extend(list_of_unflatten_bsym)
if list_of_flattening_bsyms:
bsyms.extend(list_of_flattening_bsyms)
fxnode_output_name_to_tensor_proxy: dict[str, OpOverload] = {}
for node, ltorch_op in zip(list_of_function_call_node, ltorch_ops_for_node_of_ops):
args: list[Node] = node.args
Expand Down Expand Up @@ -379,10 +380,22 @@ def translate_fx_graph_into_bsym(
f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}"
),
)
if [variableify(t) for t in orig_output._tensors] != [variableify(t) for t in new_tensor_proxies]:
orig_output._tensors = new_tensor_proxies
for name, tensor in zip(orig_output._tensor_attr_names, new_tensor_proxies):
setattr(orig_output, name, tensor)
with tracectx(self.computation_trace):
new_subclass = orig_output.replace()
new_subclass._tensors = new_tensor_proxies
for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies):
setattr(new_subclass, name, value)
bsyms.append(
prims.unflatten_tensor_subclass.bind(
new_subclass._subclass_type,
dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)),
dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)),
output=new_subclass,
)
)

self.swap_map[variableify(orig_output)] = new_subclass
self.subclass_proxy_to_flatten.add(variableify(new_subclass))

else:
non_none_args = [n for n in node_of_output.args[0] if n is not None]
Expand Down Expand Up @@ -502,7 +515,12 @@ def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing,

def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map)
if updated_bsym.sym.id == prims.PrimIDs.RETURN:
if bsym.sym.id == prims.PrimIDs.RETURN:
new_swap_map = {}
for k, v in self.swap_map.items():
if isinstance(v, SubclassTensorProxy):
continue
new_swap_map[k] = v
if not self.subclass_proxy_to_flatten or True:
return [updated_bsym]

Expand Down Expand Up @@ -567,7 +585,7 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx)


def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx:
def flatten_tensor_subclasses(computation_trace: TraceCtx) -> tuple[TraceCtx, DesugarTensorSubclass]:
"""Flatten tensor subclasses in ``computation_trace``.

Two things are happening inside of this function:
Expand Down Expand Up @@ -601,9 +619,9 @@ def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx:
updated_bsyms.extend(maybe_desugared_bsyms)

if not desugar_tensor_subclass.subclass_proxy_to_flatten:
return computation_trace
return computation_trace, None

computation_trace_with_subclass_tensor_proxy_output = from_trace(computation_trace)
computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms)
computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared"))
return computation_trace_with_subclass_tensor_proxy_output
return computation_trace_with_subclass_tensor_proxy_output, desugar_tensor_subclass
Loading