Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchao float8tensor] #1415

Draft
wants to merge 43 commits into
base: crpa/subclass-tensor-ops
Choose a base branch
from

Conversation

crcrpar
Copy link
Collaborator

@crcrpar crcrpar commented Nov 8, 2024

What does this PR do?

Improve the tensor subclass support of #1394 for TorchAo float8.

note: pytorch/ao#1339 is needed

my environment

  • torch: 2.6.0a0+git62eea62
  • nvfuser: 0.2.23+gitbb05859
  • torchao: 0.7.0+gitb2e42ff6
  • CUDA device: RTX 6000 Ada Generation
  • Driver Version: 560.35.03
  • CUDA Version: 12.6

@crcrpar

This comment was marked as outdated.

@crcrpar

This comment was marked as outdated.

@crcrpar

This comment was marked as outdated.

@crcrpar crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch 2 times, most recently from 896b631 to 316327f Compare November 24, 2024 16:13
@t-vi
Copy link
Collaborator

t-vi commented Nov 25, 2024

@crcrpar if you merge main, the pt nightly distributed ci tests should be fixed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change should be in #1394

@@ -637,7 +637,7 @@ def _convert_pytorchfunc_to_thundertrace(
trace = TraceCtx()
trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
func_result = unwrap(wrapped_func_result)
if shallow_copy_output:
if shallow_copy_output and not trace.bound_symbols:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +774 to +791

added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
import_ctx.update(cur_import_ctx)
call_ctx.update(cur_call_ctx)
object_ctx.update(cur_object_ctx)

if import_ctx:
added_bsym._import_ctx.update(import_ctx)
if call_ctx:
if added_bsym._call_ctx is not None:
added_bsym._call_ctx.update(call_ctx)
else:
added_bsym._call_ctx = call_ctx
if object_ctx:
added_bsym._object_ctx.update(object_ctx)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be in #1394

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change should also be in #1394

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be in #1394

Signed-off-by: Masaki Kozuki <[email protected]>
next, function with tensor creation in it

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>

revert wrong patch

Signed-off-by: Masaki Kozuki <[email protected]>

supply unpacks with traces generated within the lookaside

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar and others added 28 commits November 28, 2024 21:31
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
as torchao float8 ops table does not include `permute` but `transpose`.

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
```
E               RuntimeError: While trying to flatten the following BoundSymbol:
E               t165 = manual_float8_matmul_with_args_in_float8_127532775308928_2(input_fp8, t164)  # t165: "cuda:0 f32[16, 64]"
E                 # t102 = ltorch.reshape(input_fp8, -1, 32)  # t102: "cuda:0 f32[16, 32]"
E                   # t102 = prims.reshape(input_fp8, (16, 32))  # t102: "cuda:0 f32[16, 32]"
E                 # t103 = ltorch.spmm(t102, t164)  # t103: "cuda:0 f32[16, 64]"
E                 # t165 = prims.shallow_copy(t103)  # t165: "cuda:0 f32[16, 64]"
E               Unsupported op of torch._scaled_mm found from
E               class <lambda>(torch.nn.Module):
E                   def forward(self, arg0, arg1, arg2, arg3, arg4, arg5):
E                       arg0_1: "f8e4m3fn[16, 32]"; arg1_1: "f32[]"; arg3_1: "f8e4m3fn[32, 64]"; arg4_1: "f32[]";
E
E                       arg0_1, arg1_1, arg2_1, arg2_2, arg2_3, arg2_4, arg2_5, arg2_6, arg2_7, arg2_8, arg2_9, arg2_10, arg2_11, arg2_12, arg2_13, arg2_14, arg2_15, arg3_1, arg4_1, arg5_1, arg5_2, arg5_3, arg5_4, arg5_5, arg5_6, arg5_7, arg5_8, arg5_9, arg5_10, arg5_11, arg5_12, arg5_13, arg5_14, arg5_15, = fx_pytree.tree_flatten_spec([arg0, arg1, arg2, arg3, arg4, arg5], self._in_spec)
E                       # No stacktrace found for following nodes
E                       view: "f8e4m3fn[16, 32]" = torch.ops.aten.view.default(arg0_1, [-1, 32]);  arg0_1 = None
E                       t: "f8e4m3fn[64, 32]" = torch.ops.aten.t.default(arg3_1);  arg3_1 = None
E                       clone: "f8e4m3fn[64, 32]" = torch.ops.aten.clone.default(t, memory_format = torch.contiguous_format);  t = None
E                       t_1: "f8e4m3fn[32, 64]" = torch.ops.aten.t.default(clone);  clone = None
E                       reciprocal: "f32[]" = torch.ops.aten.reciprocal.default(arg1_1);  arg1_1 = None
E                       reciprocal_1: "f32[]" = torch.ops.aten.reciprocal.default(arg4_1);  arg4_1 = None
E                       _scaled_mm: "f32[16, 64]" = torch.ops.aten._scaled_mm.default(view, t_1, reciprocal, reciprocal_1, None, None, torch.float32, True);  view = t_1 = reciprocal = reciprocal_1 = None
E                       return pytree.tree_unflatten([_scaled_mm, None], self._out_spec)

thunder/transforms/tensor_subclasses.py:299: RuntimeError
```

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
still failing as `_scaled_mm` requires the secomd matrix to be column
major:

```
E               NotImplementedError: Failing to map `torch._scaled_mm` to `thunder.torch` op of [Symbol name=_scaled_mm] with args of [<TensorProxy(name="t166", dtype=thunder.dtypes.float8_e4m3fn, shape=(16, 32))>, <TensorProxy(name="t169", dtype=thunder.dtypes.float8_e4m3fn, shape=(32, 64))>, <TensorProxy(name="t170", dtype=thunder.dtypes.float32, shape=())>, <TensorProxy(name="t171", dtype=thunder.dtypes.float32, shape=())>, None, None, torch.float32, True]
E               BoundSymbol in question is
E               ```python
E               t165 = manual_float8_matmul_with_args_in_float8_127377658692416_2(input_fp8, t164)  # t165: "cuda:0 f32[16, 64]"
E                 # t102 = ltorch.reshape(input_fp8, -1, 32)  # t102: "cuda:0 f32[16, 32]"
E                   # t102 = prims.reshape(input_fp8, (16, 32))  # t102: "cuda:0 f32[16, 32]"
E                 # t103 = ltorch.spmm(t102, t164)  # t103: "cuda:0 f32[16, 64]"
E                 # t165 = prims.shallow_copy(t103)  # t165: "cuda:0 f32[16, 64]"
E               ```
E               Corresponding torch.fx Graph is
E               ```python
E               class <lambda>(torch.nn.Module):
E                   def forward(self, arg0, arg1, arg2, arg3, arg4, arg5):
E                       arg0_1: "f8e4m3fn[16, 32]"; arg1_1: "f32[]"; arg3_1: "f8e4m3fn[32, 64]"; arg4_1: "f32[]";
E
E                       arg0_1, arg1_1, arg2_1, arg2_2, arg2_3, arg2_4, arg2_5, arg2_6, arg2_7, arg2_8, arg2_9, arg2_10, arg2_11, arg2_12, arg2_13, arg2_14, arg2_15, arg3_1, arg4_1, arg5_1, arg5_2, arg5_3, arg5_4, arg5_5, arg5_6, arg5_7, arg5_8, arg5_9, arg5_10, arg5_11, arg5_12, arg5_13, arg5_14, arg5_15, = fx_pytree.tree_flatten_spec([arg0, arg1, arg2, arg3, arg4, arg5], self._in_spec)
E                       # No stacktrace found for following nodes
E                       view: "f8e4m3fn[16, 32]" = torch.ops.aten.view.default(arg0_1, [-1, 32]);  arg0_1 = None
E                       t: "f8e4m3fn[64, 32]" = torch.ops.aten.t.default(arg3_1);  arg3_1 = None
E                       clone: "f8e4m3fn[64, 32]" = torch.ops.aten.clone.default(t, memory_format = torch.contiguous_format);  t = None
E                       t_1: "f8e4m3fn[32, 64]" = torch.ops.aten.t.default(clone);  clone = None
E                       reciprocal: "f32[]" = torch.ops.aten.reciprocal.default(arg1_1);  arg1_1 = None
E                       reciprocal_1: "f32[]" = torch.ops.aten.reciprocal.default(arg4_1);  arg4_1 = None
E                       _scaled_mm: "f32[16, 64]" = torch.ops.aten._scaled_mm.default(view, t_1, reciprocal, reciprocal_1, None, None, torch.float32, True);  view = t_1 = reciprocal = reciprocal_1 = None
E                       return pytree.tree_unflatten([_scaled_mm, None], self._out_spec)
E
E               ```
E               Original error is Exception encountered when doing automatic registration for _scaled_mm, please use manual registration: RuntimeError('mat2 must be col_major')
```

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch from 04d528a to 804bc99 Compare November 28, 2024 12:32
Comment on lines +275 to +277
if executor == DynamoThunderExecutor:
with pytest.raises(AssertionError):
torch.testing.assert_close(actual, expected)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This failure doesn't feel easy to fix to me. So I made this into a script:

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from thunder.dynamo import ThunderCompiler
from thunder.dynamo.splitter import SubgraphInfo
from thunder.tests.make_tensor import make_tensor


def main():
    batch_size, in_features, out_features = 16, 32, 64

    device = torch.device("cuda")
    dtype = torch.float32

    model = nn.Linear(in_features, out_features, bias=False, device=device, dtype=dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=dtype)
    expected = fp8_model(x)

    backend = ThunderCompiler()
    jitted = torch.compile(fp8_model, backend=backend)
    actual = jitted(x)

    backend.save_reproducer_to_folder("./debug_torchao_with_thunderfx", use_pytest_benchmark=True)
    print(f"{len(backend.subgraph_infos) = }")
    subgraph: SubgraphInfo
    for subgraph in backend.subgraph_infos:
        print(f"# {len(subgraph.thunder_compiled_fns) = }")

    torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
    main()

note that pytorch/ao#1339 is needed at the moment.

Below, I put the console output of the script above:

% python debug_thunderfx_torchao_fp8.py
/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/dynamo/compiler.py:21: UserWarning: The ThunderCompiler is in active development and may not work as expected. Please report any issues you encounter to the Lightning Thunder team.
  warnings.warn(
len(backend.subgraph_infos) = 1
# len(subgraph.thunder_compiled_fns) = 0
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/debug_thunderfx_torchao_fp8.py", line 34, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/debug_thunderfx_torchao_fp8.py", line 30, in main
    torch.testing.assert_close(actual, expected)
  File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 388 / 1024 (37.9%)
Greatest absolute difference: 0.18639898300170898 at index (1, 61) (up to 1e-05 allowed)
Greatest relative difference: 1.9664803743362427 at index (10, 33) (up to 1.3e-06 allowed)

So it seems that thunder.jit isn't used for this program but the numeric is diverging.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check the result to see if they stay the same between different invocations. (Maybe due to low precision, the results could be different).

expected = fp8_model(x)
actual = fp8_model(x)
torch.testing.assert_close(actual, expected)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But please add a comment why expected and actual are both from calling the same model rather than one model and a reference.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants