check `scale.ndim` before applying `t`/`transpose` #1339
Conversation
This check is needed because (a) `scale` could be 0D/1D and (b) the args and kwargs of `torch.ops.aten.transpose.int` would supply `dim0` and `dim1`, causing dim canonicalization to fail, e.g. in [`torch._prims_common.canonicalize_dims`](https://github.com/pytorch/pytorch/blob/07906f2/torch/_prims_common/__init__.py#L704). Signed-off-by: Masaki Kozuki <[email protected]>
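For illustration, a minimal sketch of the failure mode this guards against (the 0D scale below is made up for the demo): `aten.transpose.int` canonicalizes `dim0`/`dim1` against the tensor's rank, and a 0D or 1D tensor has no dim 1.

```python
import torch

scale = torch.tensor(1.0)  # hypothetical 0D per-tensor scale
try:
    # transpose.int canonicalizes dim0/dim1 against scale.ndim;
    # a 0D tensor has no dim 1, so this raises IndexError.
    torch.ops.aten.transpose.int(scale, 0, 1)
except IndexError as e:
    print(e)  # Dimension out of range (expected to be in range of [-1, 0], but got 1)
```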
just curious, what's the functionality you are trying to enable?
I'm doing something funny in lightning-thunder to support torchao float8 in Lightning-AI/lightning-thunder#1415, where I create a Python function from thunder IR and evaluate that function using torch.fx.
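(As a rough sketch of what "evaluate using torch.fx" means, not thunder's actual machinery — `generated_fn` below is a hypothetical stand-in for a function built from thunder IR:)

```python
import torch
import torch.fx

# Hypothetical stand-in for a Python function generated from thunder IR.
def generated_fn(x):
    return torch.ops.aten.transpose.int(x, 0, 1)

# Tracing records a call_function node for the aten op; executing the
# GraphModule replays it on real tensors.
gm = torch.fx.symbolic_trace(generated_fn)
print(gm.graph)
print(gm(torch.randn(2, 3)).shape)  # torch.Size([3, 2])
```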
sgtm. Can we just add a corresponding test to
Signed-off-by: Masaki Kozuki <[email protected]>
torchao/float8/float8_ops.py
```diff
@@ -85,7 +85,10 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None):
 )
 def float8_transpose(aten_op, args, kwargs=None):
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
-    new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
+    if args[0]._scale.ndim == 2:
```
should this be >= 2 to handle tensors of rank 3+?
I thought that for FP8 the scale would mostly be 0D/1D, unless it's MX FP8. Anyway, the cases I'm interested in are handled by an if-else, so I updated the condition.
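For reference, a standalone sketch of the updated guard (the helper name `transpose_scale` is made up for illustration; in the PR the logic lives inside `float8_transpose` in torchao/float8/float8_ops.py):

```python
import torch

def transpose_scale(scale: torch.Tensor, dim0: int, dim1: int) -> torch.Tensor:
    # Only a 2D scale has dims for transpose to act on; 0D (tensorwise)
    # and 1D scales are returned unchanged.
    if scale.ndim == 2:
        return torch.ops.aten.transpose.int(scale, dim0, dim1)
    return scale

print(transpose_scale(torch.tensor(1.0), 0, 1).shape)  # torch.Size([]) - passthrough
print(transpose_scale(torch.randn(4, 8), 0, 1).shape)  # torch.Size([8, 4])
```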
Signed-off-by: Masaki Kozuki <[email protected]>
This check is for `torch.ops.aten.transpose.int`. This is because (a) `scale` could be 0D/1D and (b) the args and kwargs of `torch.ops.aten.transpose.int` would supply `dim0` and `dim1`, which is not appropriate for a <2D tensor, so dim canonicalization would fail, e.g. in `torch._prims_common.canonicalize_dims`.