You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
import torch
import torch_xla
import torch_xla.core.xla_model as xm
class SimpleModel(torch.nn.Module):
    """Minimal single-layer network used to reproduce the tracing failure.

    Holds one 3x3 convolution that maps 3 input channels to 64 output
    channels; with ``padding=1`` (and the default stride of 1) the spatial
    size of the input is preserved.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``conv_in`` determines the parameter names in
        # ``state_dict()`` (``conv_in.weight`` / ``conv_in.bias``), so keep it.
        self.conv_in = torch.nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            padding=1,
        )

    def forward(self, sample):
        """Apply the convolution to ``sample`` and return the feature map."""
        features = self.conv_in(sample)
        return features
def trace_model(device, model=None, input_shape=(1, 3, 32, 32)):
    """Trace a model with ``torch.jit.trace`` on ``device`` and smoke-test it.

    Failures are printed rather than raised, so a caller can try several
    devices in a row and see which ones fail. Tracing and running the traced
    module are reported separately, so an error in one is never mislabeled as
    an error in the other.

    Args:
        device: torch device that the model and the random sample input are
            placed on.
        model: module to trace. Defaults to a fresh ``SimpleModel()``. A
            module passed in is moved to ``device`` in place
            (``nn.Module.to``) and switched to eval mode.
        input_shape: shape of the random sample input used for tracing and
            for the test run.

    Returns:
        The traced module, or ``None`` if tracing or the test run failed.
    """
    if model is None:
        model = SimpleModel()
    model = model.to(device)
    model.eval()  # Set the model to evaluation mode

    # Create a sample input directly on the target device.
    sample = torch.randn(*input_shape, device=device)

    # Neither tracing nor the smoke test needs an autograd graph.
    with torch.no_grad():
        try:
            traced_model = torch.jit.trace(model, sample)
        except Exception as e:  # report-and-continue is intentional in this diagnostic helper
            print(f"Tracing failed on {device}: {e}")
            return None
        print(f"Tracing successful on {device}")

        # Test the traced model; failures here are not tracing failures.
        try:
            test_output = traced_model(sample)
        except Exception as e:  # report-and-continue is intentional in this diagnostic helper
            print(f"Running traced model failed on {device}: {e}")
            return None

    print(f"Test output shape: {test_output.shape}")
    return traced_model
def main():
    """Run the tracing reproduction on the CPU first, then on the XLA device.

    Each target prints a banner and then delegates to ``trace_model``. The XLA
    device is only acquired (via ``xm.xla_device()``) after the CPU run
    finishes and its banner has been printed.
    """
    # (banner, zero-arg factory returning the device) in execution order.
    targets = (
        ("Testing on CPU:", lambda: torch.device("cpu")),
        ("\nTesting on XLA device:", xm.xla_device),
    )
    for banner, make_device in targets:
        print(banner)
        trace_model(make_device())


if __name__ == "__main__":
    # Only run the reproduction when executed as a script, not on import.
    main()
Works fine on CPU but fails on the XLA device with the following error (message truncated):

```
... [ XLAFloatType{1,64,32,32} ]) of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.
```
Tested on a TPU v4 pod.
This is especially annoying because tracing on CPU and then loading the traced model onto the XLA device isn't possible either, due to pytorch/pytorch#96448.
The text was updated successfully, but these errors were encountered:
Minimal reproduction example:
Works fine on CPU but fails on the XLA device with the following error (message truncated):

```
... [ XLAFloatType{1,64,32,32} ]) of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.
```
Tested on a TPU v4 pod.
This is especially annoying because tracing on CPU and then loading the traced model onto the XLA device isn't possible either, due to pytorch/pytorch#96448.
The text was updated successfully, but these errors were encountered: