
Unrolling tensor subclasses in fwd/bwd split #1489

Conversation

@crcrpar (Collaborator) commented Nov 28, 2024

What does this PR do?

In #1415 and #1394, tensor subclasses and their __torch_dispatch__ are unrolled before the forward-backward split.
It turned out that we want to postpone the unrolling until the split itself, as unrolling beforehand seems to be harmful to backward trace generation.

TODO

  • support no-autograd cases

Note: pytorch/ao#1339 is used.
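
For context on the unrolling mentioned above, here is a minimal, illustrative tensor subclass with __torch_dispatch__. This is a generic sketch, not Thunder's or torchao's implementation: every op on such a subclass routes through __torch_dispatch__, and "unrolling" means inlining that dispatch body as plain-tensor operations in the trace.

import torch

class LoggingTensor(torch.Tensor):
    # Every torch op on an instance routes through __torch_dispatch__.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap subclass instances so the op runs on plain tensors; a
        # tracer that "unrolls" the subclass inlines exactly this body.
        unwrapped = [
            t.as_subclass(torch.Tensor) if isinstance(t, LoggingTensor) else t
            for t in args
        ]
        print(f"dispatching {func}")
        return func(*unwrapped, **kwargs)

x = torch.randn(3).as_subclass(LoggingTensor)
y = x * 2  # prints the dispatched aten op (e.g. aten.mul.Tensor)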

@@ -637,7 +637,7 @@ def _convert_pytorchfunc_to_thundertrace(
     trace = TraceCtx()
     trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
     func_result = unwrap(wrapped_func_result)
-    if shallow_copy_output:
+    if shallow_copy_output and not trace.bound_symbols:
crcrpar (Collaborator, Author) commented:

copy from #1485

Comment on lines +1317 to +1325
def _transpose_grad(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike:
    # Recompute the forward, then route the incoming gradient back through
    # the same transpose: transpose(dim0, dim1) is its own adjoint.
    fwd = transpose(a, dim0, dim1)
    g = get_grad(fwd)
    a_grad = transpose(g, dim0, dim1)
    put_grad(a, a_grad)
    return fwd


register_grad(transpose, _transpose_grad)
crcrpar (Collaborator, Author) commented:
rel: #1487

This grad rule is needed to avoid prims.permute.
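
As a sanity check outside Thunder (a minimal PyTorch-only sketch, not part of this PR), the identity the rule encodes is that the adjoint of transpose(dim0, dim1) is the same transpose applied to the incoming gradient:

import torch

a = torch.randn(2, 3, 4, requires_grad=True)
g = torch.randn(2, 4, 3)  # cotangent shaped like transpose(a, 1, 2)

a.transpose(1, 2).backward(g)
torch.testing.assert_close(a.grad, g.transpose(1, 2))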

@@ -269,3 +266,5 @@ def test_torchao_float8_linear(executor, device, _):
 
     jitted = executor.make_callable(fp8_model)
     actual = jitted(x)
+
+    torch.testing.assert_close(actual, expected)
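
For context, a hedged sketch of the setup this test hunk relies on. The names fp8_model, x, and expected appear in the diff, but their construction below (module shape, dtype, and the use of torchao's convert_to_float8_training) is an assumption about the omitted part of the file, and the real test obtains its callable through the parametrized executor.make_callable rather than thunder.jit.

import thunder
import torch
from torchao.float8 import convert_to_float8_training

# Assumed setup (not the file's actual contents): a small float8 linear model.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
).to(device="cuda", dtype=torch.bfloat16)
fp8_model = convert_to_float8_training(model)  # float8 training needs an SM89+ GPU

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
expected = fp8_model(x)  # eager reference output

jitted = thunder.jit(fp8_model)  # the test uses executor.make_callable instead
actual = jitted(x)
torch.testing.assert_close(actual, expected)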
@crcrpar crcrpar marked this pull request as ready for review November 28, 2024 11:33
@crcrpar crcrpar merged commit 06ee30e into crpa/subclass-torchao_float8tensor Nov 28, 2024
29 of 36 checks passed
@crcrpar crcrpar deleted the crpa/torchao-fp8tensor-flattening-in-fwdbwd-split branch November 28, 2024 12:12
crcrpar added a commit that referenced this pull request Nov 28, 2024
Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>