
codeutils.to_printable does not seem capable of handling NamedTuple #1442

Closed
crcrpar opened this issue Nov 14, 2024 · 0 comments

crcrpar (Collaborator) commented Nov 14, 2024

🐛 Bug

codeutils.to_printable raises a TypeError when given a NamedTuple, even though thunder.jit accepts a function that takes a NamedTuple as one of its arguments.
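One plausible reason for the mismatch, offered as a hypothesis rather than a reading of thunder's source: a namedtuple is a subclass of tuple, so a check on the exact type (type(x) is tuple, or type(x) being looked up in a set of supported types) rejects it, even though isinstance(x, tuple) is true. A minimal stdlib-only sketch; is_namedtuple is a hypothetical helper, not a thunder API:

```python
from collections import namedtuple

MyTuple = namedtuple("MyTuple", ["x", "y"])


def is_namedtuple(obj) -> bool:
    """Hypothetical heuristic: a tuple subclass exposing namedtuple's _fields/_make API."""
    cls = type(obj)
    return (
        isinstance(obj, tuple)
        and cls is not tuple
        and isinstance(getattr(cls, "_fields", None), tuple)
        and callable(getattr(cls, "_make", None))
    )


t = MyTuple("abc", "def")

# An exact-type check rejects the namedtuple even though it *is* a tuple.
print(type(t) is tuple)               # False
print(isinstance(t, tuple))           # True
print(is_namedtuple(t))               # True
print(is_namedtuple(("abc", "def")))  # False
```

The _fields/_make check also matches classes created with typing.NamedTuple, since those are built on collections.namedtuple.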

I faced this in #1415 (comment).

To Reproduce

Code sample

from collections import namedtuple

import torch

import thunder
from thunder.core import codeutils
from thunder.core.trace import TraceCtx


MyTuple = namedtuple('MyTuple', ['x', 'y'])


def f(x: torch.Tensor, y: torch.Tensor, my_tuple: MyTuple):
    if my_tuple.x and my_tuple.y:
        return x - y
    return x + y


def main():

    my_tuple = MyTuple("abc", "def")

    x, y = torch.randn((2, 2)), torch.randn((2, 2))

    print(f"{thunder.jit(f)(x, y, my_tuple) = }")
    trace = TraceCtx()
    codeutils.to_printable(trace, my_tuple)


if __name__ == "__main__":
    main()

Output:

thunder.jit(f)(x, y, my_tuple) = tensor([[0.3599, 0.9889],
        [0.1036, 0.6854]])
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/c.py", line 31, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/c.py", line 27, in main
    codeutils.to_printable(trace, my_tuple)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/codeutils.py", line 144, in to_printable
    flat, spec = tree_flatten(x, namespace="")
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/pytree.py", line 73, in tree_flatten
    raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
TypeError: tree_flatten of type <class '__main__.MyTuple'> is not supported.
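The error is raised by tree_flatten in thunder/core/pytree.py, which to_printable calls with namespace="". The toy flattener below is not thunder's implementation (flatten, unflatten, LEAF_TYPES and the support_namedtuple switch are all hypothetical). It shows that a type whitelist of this kind produces the same message for MyTuple, and that a namedtuple branch that records the class and rebuilds it with _make round-trips the value:

```python
from collections import namedtuple
from typing import Any, List, Tuple

MyTuple = namedtuple("MyTuple", ["x", "y"])

# Hypothetical leaf whitelist; anything that is neither a leaf nor a known container is rejected.
LEAF_TYPES = (int, float, str, bool, type(None))


def is_namedtuple(obj: Any) -> bool:
    """Heuristic: a tuple subclass exposing namedtuple's _fields/_make API."""
    cls = type(obj)
    return isinstance(obj, tuple) and cls is not tuple and hasattr(cls, "_fields") and hasattr(cls, "_make")


def flatten(obj: Any, support_namedtuple: bool = True) -> Tuple[List[Any], Any]:
    """Return (leaves, spec); spec is None for a leaf, else (container_type, child_specs)."""
    is_nt = support_namedtuple and is_namedtuple(obj)
    if is_nt or type(obj) in (tuple, list):
        leaves: List[Any] = []
        child_specs = []
        for child in obj:
            child_leaves, child_spec = flatten(child, support_namedtuple)
            leaves.extend(child_leaves)
            child_specs.append(child_spec)
        return leaves, (type(obj), child_specs)
    if type(obj) in LEAF_TYPES:
        return [obj], None
    # Exact-type dispatch: without the namedtuple branch, MyTuple ends up here.
    raise TypeError(f"tree_flatten of type {type(obj)} is not supported.")


def unflatten(leaves: List[Any], spec: Any) -> Any:
    """Rebuild the original structure from the flat leaves and the spec."""
    it = iter(leaves)

    def build(s: Any) -> Any:
        if s is None:
            return next(it)
        kind, child_specs = s
        children = [build(cs) for cs in child_specs]
        if kind is list:
            return children
        if kind is tuple:
            return tuple(children)
        return kind._make(children)  # namedtuple: rebuild with its own class

    return build(spec)


t = MyTuple("abc", "def")
leaves, spec = flatten(t)
rebuilt = unflatten(leaves, spec)
print(leaves)                    # ['abc', 'def']
print(rebuilt == t)              # True
print(type(rebuilt) is MyTuple)  # True

try:
    flatten(t, support_namedtuple=False)
except TypeError as e:
    print(e)  # the same "is not supported" message as in the traceback
```

Rebuilding through type(obj)._make keeps the MyTuple type on the way back; treating the value as a plain tuple would flatten it but lose the class and field names.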

Expected behavior

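codeutils.to_printable accepts a NamedTuple without raising, consistent with thunder.jit, which already accepts one as a function argument.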
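If the printable form of a NamedTuple is meant to end up as a Python expression (assuming, for example, that it is shown in a printed trace), it needs the class name and, ideally, the field names. A hypothetical to_source helper sketching that rendering for tuples, lists, namedtuples and simple literals; it is not thunder's to_printable:

```python
from collections import namedtuple
from typing import Any

MyTuple = namedtuple("MyTuple", ["x", "y"])


def is_namedtuple(obj: Any) -> bool:
    cls = type(obj)
    return isinstance(obj, tuple) and cls is not tuple and hasattr(cls, "_fields")


def to_source(obj: Any) -> str:
    """Hypothetical helper: render nested literals as a Python expression string."""
    if is_namedtuple(obj):
        # Keyword form keeps the field names visible in the printed output.
        args = ", ".join(f"{name}={to_source(val)}" for name, val in zip(obj._fields, obj))
        return f"{type(obj).__name__}({args})"
    if type(obj) is tuple:
        inner = ", ".join(to_source(v) for v in obj)
        return f"({inner},)" if len(obj) == 1 else f"({inner})"
    if type(obj) is list:
        return "[" + ", ".join(to_source(v) for v in obj) + "]"
    if type(obj) in (int, float, str, bool, type(None)):
        return repr(obj)
    raise TypeError(f"cannot render {type(obj)}")


print(to_source(MyTuple("abc", "def")))        # MyTuple(x='abc', y='def')
print(to_source([1, (2,), MyTuple(3, None)]))  # [1, (2,), MyTuple(x=3, y=None)]
```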
