[torchao float8tensor] #1415
base: crpa/subclass-tensor-ops
Conversation
@crcrpar if you merge main, the PT nightly distributed CI tests should be fixed.
thunder/__init__.py
Outdated
This change should be in #1394
```diff
@@ -637,7 +637,7 @@ def _convert_pytorchfunc_to_thundertrace(
     trace = TraceCtx()
     trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
     func_result = unwrap(wrapped_func_result)
-    if shallow_copy_output:
+    if shallow_copy_output and not trace.bound_symbols:
```
```python
added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
    cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
    import_ctx.update(cur_import_ctx)
    call_ctx.update(cur_call_ctx)
    object_ctx.update(cur_object_ctx)

if import_ctx:
    added_bsym._import_ctx.update(import_ctx)
if call_ctx:
    if added_bsym._call_ctx is not None:
        added_bsym._call_ctx.update(call_ctx)
    else:
        added_bsym._call_ctx = call_ctx
if object_ctx:
    added_bsym._object_ctx.update(object_ctx)
```
should be in #1394
thunder/core/prims.py
Outdated
This change should also be in #1394
thunder/executors/torch_autograd.py
Outdated
should be in #1394
next, function with tensor creation in it
revert wrong patch
supply unpacks with traces generated within the lookaside
as the torchao float8 ops table does not include `permute`, only `transpose`.
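If I read that commit message right, a 2-D `permute` has to be expressed as the `transpose` that torchao's float8 op table does cover. A quick plain-PyTorch illustration of the equivalence (not taken from this PR):

```python
import torch

# For a 2-D tensor, permute(1, 0) and transpose(0, 1) produce the same result,
# so a matrix permute can be rewritten as the transpose that the float8 op
# table already handles.
x = torch.randn(4, 8)
assert torch.equal(x.permute(1, 0), x.transpose(0, 1))
```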
```
E RuntimeError: While trying to flatten the following BoundSymbol:
E t165 = manual_float8_matmul_with_args_in_float8_127532775308928_2(input_fp8, t164)  # t165: "cuda:0 f32[16, 64]"
E   # t102 = ltorch.reshape(input_fp8, -1, 32)  # t102: "cuda:0 f32[16, 32]"
E     # t102 = prims.reshape(input_fp8, (16, 32))  # t102: "cuda:0 f32[16, 32]"
E   # t103 = ltorch.spmm(t102, t164)  # t103: "cuda:0 f32[16, 64]"
E   # t165 = prims.shallow_copy(t103)  # t165: "cuda:0 f32[16, 64]"
E Unsupported op of torch._scaled_mm found from
E class <lambda>(torch.nn.Module):
E     def forward(self, arg0, arg1, arg2, arg3, arg4, arg5):
E         arg0_1: "f8e4m3fn[16, 32]"; arg1_1: "f32[]"; arg3_1: "f8e4m3fn[32, 64]"; arg4_1: "f32[]";
E
E         arg0_1, arg1_1, arg2_1, arg2_2, arg2_3, arg2_4, arg2_5, arg2_6, arg2_7, arg2_8, arg2_9, arg2_10, arg2_11, arg2_12, arg2_13, arg2_14, arg2_15, arg3_1, arg4_1, arg5_1, arg5_2, arg5_3, arg5_4, arg5_5, arg5_6, arg5_7, arg5_8, arg5_9, arg5_10, arg5_11, arg5_12, arg5_13, arg5_14, arg5_15, = fx_pytree.tree_flatten_spec([arg0, arg1, arg2, arg3, arg4, arg5], self._in_spec)
E         # No stacktrace found for following nodes
E         view: "f8e4m3fn[16, 32]" = torch.ops.aten.view.default(arg0_1, [-1, 32]);  arg0_1 = None
E         t: "f8e4m3fn[64, 32]" = torch.ops.aten.t.default(arg3_1);  arg3_1 = None
E         clone: "f8e4m3fn[64, 32]" = torch.ops.aten.clone.default(t, memory_format = torch.contiguous_format);  t = None
E         t_1: "f8e4m3fn[32, 64]" = torch.ops.aten.t.default(clone);  clone = None
E         reciprocal: "f32[]" = torch.ops.aten.reciprocal.default(arg1_1);  arg1_1 = None
E         reciprocal_1: "f32[]" = torch.ops.aten.reciprocal.default(arg4_1);  arg4_1 = None
E         _scaled_mm: "f32[16, 64]" = torch.ops.aten._scaled_mm.default(view, t_1, reciprocal, reciprocal_1, None, None, torch.float32, True);  view = t_1 = reciprocal = reciprocal_1 = None
E         return pytree.tree_unflatten([_scaled_mm, None], self._out_spec)

thunder/transforms/tensor_subclasses.py:299: RuntimeError
```
still failing as `_scaled_mm` requires the second matrix to be column major:

````
E NotImplementedError: Failing to map `torch._scaled_mm` to `thunder.torch` op of [Symbol name=_scaled_mm] with args of [<TensorProxy(name="t166", dtype=thunder.dtypes.float8_e4m3fn, shape=(16, 32))>, <TensorProxy(name="t169", dtype=thunder.dtypes.float8_e4m3fn, shape=(32, 64))>, <TensorProxy(name="t170", dtype=thunder.dtypes.float32, shape=())>, <TensorProxy(name="t171", dtype=thunder.dtypes.float32, shape=())>, None, None, torch.float32, True]
E BoundSymbol in question is
E ```python
E t165 = manual_float8_matmul_with_args_in_float8_127377658692416_2(input_fp8, t164)  # t165: "cuda:0 f32[16, 64]"
E   # t102 = ltorch.reshape(input_fp8, -1, 32)  # t102: "cuda:0 f32[16, 32]"
E     # t102 = prims.reshape(input_fp8, (16, 32))  # t102: "cuda:0 f32[16, 32]"
E   # t103 = ltorch.spmm(t102, t164)  # t103: "cuda:0 f32[16, 64]"
E   # t165 = prims.shallow_copy(t103)  # t165: "cuda:0 f32[16, 64]"
E ```
E Corresponding torch.fx Graph is
E ```python
E class <lambda>(torch.nn.Module):
E     def forward(self, arg0, arg1, arg2, arg3, arg4, arg5):
E         arg0_1: "f8e4m3fn[16, 32]"; arg1_1: "f32[]"; arg3_1: "f8e4m3fn[32, 64]"; arg4_1: "f32[]";
E
E         arg0_1, arg1_1, arg2_1, arg2_2, arg2_3, arg2_4, arg2_5, arg2_6, arg2_7, arg2_8, arg2_9, arg2_10, arg2_11, arg2_12, arg2_13, arg2_14, arg2_15, arg3_1, arg4_1, arg5_1, arg5_2, arg5_3, arg5_4, arg5_5, arg5_6, arg5_7, arg5_8, arg5_9, arg5_10, arg5_11, arg5_12, arg5_13, arg5_14, arg5_15, = fx_pytree.tree_flatten_spec([arg0, arg1, arg2, arg3, arg4, arg5], self._in_spec)
E         # No stacktrace found for following nodes
E         view: "f8e4m3fn[16, 32]" = torch.ops.aten.view.default(arg0_1, [-1, 32]);  arg0_1 = None
E         t: "f8e4m3fn[64, 32]" = torch.ops.aten.t.default(arg3_1);  arg3_1 = None
E         clone: "f8e4m3fn[64, 32]" = torch.ops.aten.clone.default(t, memory_format = torch.contiguous_format);  t = None
E         t_1: "f8e4m3fn[32, 64]" = torch.ops.aten.t.default(clone);  clone = None
E         reciprocal: "f32[]" = torch.ops.aten.reciprocal.default(arg1_1);  arg1_1 = None
E         reciprocal_1: "f32[]" = torch.ops.aten.reciprocal.default(arg4_1);  arg4_1 = None
E         _scaled_mm: "f32[16, 64]" = torch.ops.aten._scaled_mm.default(view, t_1, reciprocal, reciprocal_1, None, None, torch.float32, True);  view = t_1 = reciprocal = reciprocal_1 = None
E         return pytree.tree_unflatten([_scaled_mm, None], self._out_spec)
E
E ```
E Original error is Exception encountered when doing automatic registration for _scaled_mm, please use manual registration: RuntimeError('mat2 must be col_major')
````
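For reference, a minimal sketch of the layout requirement reported above, in plain PyTorch rather than anything from this PR. It assumes a CUDA device with float8 support and the `torch._scaled_mm` signature shown in the log; `torch._scaled_mm` is a private API, so this is only illustrative:

```python
import torch

device = "cuda"
a = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)
b = torch.randn(32, 64, device=device).to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

# Passing the row-major `b` directly raises: RuntimeError: mat2 must be col_major
# out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float32)

# The t/clone/t sequence in the fx graph above produces a column-major copy:
b_col_major = b.t().contiguous().t()
out = torch._scaled_mm(a, b_col_major, scale_a, scale_b, out_dtype=torch.float32)
print(out.shape)  # torch.Size([16, 64])
```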
…latten__`
for more information, see https://pre-commit.ci
```python
if executor == DynamoThunderExecutor:
    with pytest.raises(AssertionError):
        torch.testing.assert_close(actual, expected)
```
This failure doesn't feel easy to fix to me. So I made this into a script:
```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

from thunder.dynamo import ThunderCompiler
from thunder.dynamo.splitter import SubgraphInfo
from thunder.tests.make_tensor import make_tensor


def main():
    batch_size, in_features, out_features = 16, 32, 64
    device = torch.device("cuda")
    dtype = torch.float32

    model = nn.Linear(in_features, out_features, bias=False, device=device, dtype=dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=dtype)

    expected = fp8_model(x)

    backend = ThunderCompiler()
    jitted = torch.compile(fp8_model, backend=backend)
    actual = jitted(x)
    backend.save_reproducer_to_folder("./debug_torchao_with_thunderfx", use_pytest_benchmark=True)

    print(f"{len(backend.subgraph_infos) = }")
    subgraph: SubgraphInfo
    for subgraph in backend.subgraph_infos:
        print(f"# {len(subgraph.thunder_compiled_fns) = }")

    torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
    main()
```
note that pytorch/ao#1339 is needed at the moment.
Below is the console output of the script above:
```
% python debug_thunderfx_torchao_fp8.py
/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/dynamo/compiler.py:21: UserWarning: The ThunderCompiler is in active development and may not work as expected. Please report any issues you encounter to the Lightning Thunder team.
  warnings.warn(
len(backend.subgraph_infos) = 1
# len(subgraph.thunder_compiled_fns) = 0
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/debug_thunderfx_torchao_fp8.py", line 34, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/debug_thunderfx_torchao_fp8.py", line 30, in main
    torch.testing.assert_close(actual, expected)
  File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 388 / 1024 (37.9%)
Greatest absolute difference: 0.18639898300170898 at index (1, 61) (up to 1e-05 allowed)
Greatest relative difference: 1.9664803743362427 at index (10, 33) (up to 1.3e-06 allowed)
```
So it seems that `thunder.jit` isn't used for this program, yet the numerics are diverging.
Can you check whether the results stay the same between different invocations? (Maybe due to the low precision, the results could be different.)
```python
expected = fp8_model(x)
actual = fp8_model(x)
torch.testing.assert_close(actual, expected)
```
But please add a comment explaining why `expected` and `actual` both come from calling the same model, rather than from one model and a reference.
What does this PR do?
Improve the tensor subclass support of #1394 for torchao float8.
Note: pytorch/ao#1339 is needed.
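If I understand the goal correctly, the target workflow looks roughly like the sketch below. This is not code from the PR; it assumes a CUDA device with float8 support and the tensor subclass transform from #1394:

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

import thunder

# Convert an ordinary module to torchao float8 training, then jit it with thunder;
# the tensor subclass support from #1394 is what this PR extends to cover
# torchao's float8 tensors.
model = nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.float32)
fp8_model = convert_to_float8_training(model)

jitted = thunder.jit(fp8_model)
x = torch.randn(16, 32, device="cuda")
out = jitted(x)
print(out.shape)  # torch.Size([16, 64])
```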
my environment