
torch.jit.trace fails with a very simple conv layer on TPU but succeeds on both MPS and CPU #7606

Open
BitPhinix opened this issue Jul 2, 2024 · 0 comments


BitPhinix commented Jul 2, 2024

Minimal reproduction example:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, sample):
        return self.conv_in(sample)


def trace_model(device):
    model = SimpleModel().to(device)
    model.eval()  # Set the model to evaluation mode

    # Create a sample input
    sample = torch.randn(1, 3, 32, 32, device=device)

    # Attempt to trace the model
    try:
        with torch.no_grad():  # Disable gradient computation
            traced_model = torch.jit.trace(model, sample)
        print(f"Tracing successful on {device}")

        # Test the traced model
        test_output = traced_model(sample)
        print(f"Test output shape: {test_output.shape}")

    except Exception as e:
        print(f"Tracing failed on {device}: {str(e)}")


def main():
    # Test on CPU
    print("Testing on CPU:")
    trace_model(torch.device("cpu"))

    # Test on XLA device
    print("\nTesting on XLA device:")
    xla_device = xm.xla_device()
    trace_model(xla_device)


if __name__ == "__main__":
    main()
```

This works fine on CPU, but on the XLA device tracing fails with `[ XLAFloatType{1,64,32,32} ]) of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.`

Tested on a TPU v4 pod.
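For context on what the tracer is complaining about: the same message can be produced on CPU whenever a traced function returns a tensor that the tracer never saw being computed from its inputs. The sketch below is a minimal illustration, not a claim about what goes wrong inside torch_xla; the names `captured`, `returns_captured`, and `trace_error_message` are hypothetical. It returns a tensor created before tracing starts, so `torch.jit.trace` cannot connect the output to the input.

```python
import torch

# Hypothetical example: this tensor is created *before* tracing starts,
# so the tracer never records an op that produced it.
captured = torch.ones(3)


def returns_captured(x):
    # The output ignores `x` entirely, so it has no observable data
    # dependence on the trace input.
    return captured


def trace_error_message():
    """Return the error text torch.jit.trace raises, or None if tracing succeeds."""
    try:
        torch.jit.trace(returns_captured, torch.randn(3))
    except RuntimeError as e:
        return str(e)
    return None


if __name__ == "__main__":
    print(trace_error_message())
```

In the conv repro the output is clearly computed from `sample`, so seeing this message on XLA presumably means the tracer lost track of the op that produced the XLA output tensor, rather than the model genuinely ignoring its input.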

This is especially annoying because tracing on CPU and then loading the traced model onto the XLA device isn't possible either, due to pytorch/pytorch#96448.
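For reference, the trace-on-CPU-then-load route mentioned above looks roughly like this. It is a minimal sketch: `trace_on_cpu_then_load` is a hypothetical helper that round-trips the traced module through `torch.jit.save` / `torch.jit.load` with `map_location`. Passing `torch.device("cpu")` runs end to end; passing `xm.xla_device()` is the case this issue reports as blocked by pytorch/pytorch#96448.

```python
import io

import torch


class SimpleModel(torch.nn.Module):
    # Same single-conv model as in the reproduction above.
    def __init__(self):
        super().__init__()
        self.conv_in = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, sample):
        return self.conv_in(sample)


def trace_on_cpu_then_load(device):
    """Hypothetical helper: trace on CPU, serialize, reload onto `device`, run once."""
    model = SimpleModel().eval()
    sample = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        traced = torch.jit.trace(model, sample)  # tracing happens on CPU

    # Round-trip through an in-memory buffer instead of a file on disk.
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    buffer.seek(0)  # rewind before reading the serialized module back
    loaded = torch.jit.load(buffer, map_location=device)

    with torch.no_grad():
        return loaded(sample.to(device))


if __name__ == "__main__":
    print(trace_on_cpu_then_load(torch.device("cpu")).shape)
```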
