Pallas on CPU #7599
Comments
@alanwaketan can you take a look? I am actually not entirely sure whether Pallas works on XLA:CPU, but the failure seems to happen in our Python code.
@johnsutor Can you try nightly?
@alanwaketan I installed the nightly build using

```
!pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
```

but then I get the following error when I attempt to run this code block:

```python
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import os
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
import torch
os.environ["PJRT_DEVICE"] = "CPU"
```

```
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
<ipython-input-3-193790483d72> in <cell line: 5>()
3 import jax.numpy as jnp
4 import os
----> 5 from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
6 import torch
7 os.environ["PJRT_DEVICE"] = "CPU"
/usr/local/lib/python3.10/dist-packages/torch_xla/__init__.py in <module>
6
7 import torch
----> 8 import _XLAC
9 from ._internal import tpu
10 from .version import __version__
ImportError: /usr/local/lib/python3.10/dist-packages/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c105Error4whatEv
```
@johnsutor This error usually happens when the installed torch and torch_xla versions don't match. Try also updating PyTorch to the nightly version.
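The undefined symbol in the traceback, `_ZNK3c105Error4whatEv`, demangles to `c10::Error::what() const`, a libtorch symbol, which fits the explanation that `_XLAC` was built against a different torch than the one installed. A minimal sketch for catching the most common form of this mismatch, assuming a torch_xla build expects the torch release with the same major.minor version; `major_minor` and `versions_compatible` are hypothetical helpers, not torch_xla APIs, and the check cannot detect two nightly builds from different dates, which also need to match.

```python
import re


def major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from strings like '2.3.0+cu121' or '2.5.0.dev20240701+cpu'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def versions_compatible(torch_version: str, torch_xla_version: str) -> bool:
    """Heuristic only: torch and torch_xla should share the same major.minor release."""
    return major_minor(torch_version) == major_minor(torch_xla_version)


# Illustrative literal versions; a stable torch paired with a newer torch_xla fails the check.
print(versions_compatible("2.3.0+cu121", "2.3.0"))
print(versions_compatible("2.3.0+cu121", "2.5.0"))
```

In a notebook, the inputs could come from `importlib.metadata.version("torch")` and `importlib.metadata.version("torch_xla")`, which read installed package metadata without importing the possibly broken `_XLAC` extension.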
@bhavya01 Unfortunately that does not work: I get an error from the PyTorch side when attempting to install in Colab.

```
Looking in indexes: https://download.pytorch.org/whl/nightly/cpu
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.18.0+cu121)
Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
ERROR: HTTP error 403 while getting https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from https://download.pytorch.org/whl/nightly/cpu/nvidia-cuda-nvrtc-cu12/)
ERROR: Could not install requirement nvidia-cuda-nvrtc-cu12==12.1.105 from https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from torch) because of HTTP error 403 Client Error: Forbidden for url: https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl for URL https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from https://download.pytorch.org/whl/nightly/cpu/nvidia-cuda-nvrtc-cu12/)
```

However, the PyTorch nightly install works fine on my M2 Mac, but then I can't install torch_xla there.
@johnsutor Did you include torchvision and torchaudio? If so, try removing them from the command.
@alanwaketan On Colab, I had to uninstall torch before installing the torch nightly to get it to work. However, I now get the following error when I run this code (you can see the results in the notebook here):

```python
def add_vectors_kernel(x_ref, y_ref, o_ref):
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(add_vectors_kernel,
                          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                          interpret=True
                          )(x, y)


q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# From the tutorial
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
```

```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-4-d7393bd6c1ca> in <cell line: 3>()
1 # From the tutorial
2 pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
----> 3 output = pt_kernel(q, k)
/usr/local/lib/python3.10/dist-packages/torch_xla/experimental/custom_kernel.py in wrapped_kernel(kernel, output_shape_dtype_fn, static_argnums, static_argnames, *args, **kwargs)
143 output_shapes = [shape for shape, _ in output_shape_dtype]
144 output_dtypes = [dtype for _, dtype in output_shape_dtype]
--> 145 outputs = torch_xla._XLAC._xla_tpu_custom_call(tensor_args, payload,
146 output_shapes, output_dtypes)
147
TypeError: _xla_tpu_custom_call(): incompatible function arguments. The following argument types are supported:
1. (arg0: list[torch.Tensor], arg1: str, arg2: list[list[int]], arg3: list[object]) -> list[torch.Tensor]
Invoked with: [tensor([[[[ 0.9387, -0.0999, 1.2256, 1.0983],
[ 0.1434, 0.6569, -1.0381, -1.5455],
[ 0.6816, 1.0613, -1.0494, -0.1487],
...,
[-0.7309, -0.7079, 0.9284, 0.8607],
[ 0.0899, 1.5027, -0.5658, -1.0447],
[ 1.3098, 0.5118, 0.0765, -0.2424]],
[[ 0.0088, 1.8775, 0.0542, -1.1126],
[-0.6061, -0.3355, -0.7491, -1.8286],
[ 0.7737, 0.0899, -0.2609, 0.4641],
...,
[ 0.7544, 1.3189, 0.9427, 0.6183],
[ 0.9660, -0.5893, -0.3516, 1.0709],
[-2.3094, -0.2950, -1.1045, -0.0845]]],
[[[-0.4123, 0.7556, -0.4119, 0.6650],
[-0.4744, 0.9115, 0.1186, -0.7852],
[-0.2252, -0.6484, 1.5036, 0.7215],
...,
[-0.1553, 1.5585, 1.1157, 0.8698],
[ 0.5997, 0.5789, 0.3054, -1.8421],
[-0.5578, -0.8656, 0.1356, 1.1475]],
[[ 2.0097, 0.7483, -0.5908, 0.0702],
[-1.0810, -0.6120, -0.7814, 0.5367],
[-0.9203, -0.9630, -1.7621, 1.3503],
...,
[ 0.2230, -0.2255, 1.2624, -0.7935],
[-1.0775, 0.5843, 0.5457, -0.1265],
[-2.4482, -1.0382, -0.9038, -0.9088]]],
[[[ 0.6542, 2.3457, 0.0888, -0.2082],
[-0.2973, -0.4685, 0.8633, 1.2241],
[ 0.1258, 0.1412, 0.9298, -1.0842],
...,
[-0.6876, -1.5594, -1.0357, 0.3485],
[ 1.1975, -0.1514, -1.2257, 0.9857],
[ 1.7342, -1.5681, 0.4157, 0.9439]],
[[ 1.1967, 0.2086, -0.5509, -1.1779],
[ 0.4936, -0.8626, -0.6094, -0.7941],
[ 0.0440, -0.5978, 1.2477, 1.2164],
...,
[-1.8150, 1.0365, 0.5270, -0.4706],
[ 0.5347, -1.1803, -0.2394, -0.1587],
[-0.9638, -1.0259, -1.2330, -0.2761]]]], device='xla:0'), tensor([[[[ 6.7709e-01, 5.0651e-01, -3.9015e-01, 1.1769e+00],
[-5.7469e-01, 6.5236e-01, -6.9628e-01, 3.3803e-02],
[-2.8044e-01, 7.1211e-01, 2.3748e-01, 3.7293e-01],
...,
[ 2.8838e+00, 8.6530e-01, -1.5567e-01, 2.1392e-01],
[-5.7115e-01, -2.6569e+00, 1.2452e+00, 1.0137e-01],
[ 8.7078e-01, -8.3965e-01, -9.3462e-01, 5.8777e-01]],
[[-3.5002e-01, -1.0575e+00, -1.4964e+00, 9.9756e-01],
[ 7.8972e-01, -4.1112e-02, -1.2023e+00, -5.3902e-01],
[-5.9894e-01, -8.5050e-01, -3.6425e-01, -9.7505e-01],
...,
[ 1.2945e-02, -3.0388e-01, -1.3666e+00, -8.1373e-01],
[-1.2614e+00, 1.3913e-01, -5.6531e-01, -4.5330e-01],
[-1.1217e+00, -9.0676e-01, -1.0731e+00, -1.9240e-01]]],
[[[ 4.0985e-01, 1.1629e+00, -5.1721e-01, 1.6515e-01],
[-3.1879e-01, 7.2867e-01, -1.5622e+00, -6.3426e-01],
[-6.6151e-01, -3.2032e-01, 2.1753e+00, -8.9741e-01],
...,
[-3.5038e-01, 2.1497e-01, -2.1903e-01, 3.8987e-01],
[ 5.4283e-01, -5.4239e-01, -8.3459e-01, 4.8928e-01],
[ 1.2570e+00, -1.4615e+00, 4.1475e-01, 1.5395e+00]],
[[-8.7543e-01, -3.6893e-01, -6.6030e-01, -4.0877e-01],
[-2.3046e-02, -6.1282e-01, 1.8114e-01, -5.9609e-01],
[-1.8128e-01, 1.1691e+00, -5.3699e-01, -1.2312e-01],
...,
[ 2.0114e-01, -7.7060e-01, 1.1129e+00, -2.0385e-01],
[ 1.0480e+00, 4.0939e-01, -5.2975e-01, -2.1745e-01],
[-9.5069e-01, -9.6135e-01, -1.1307e+00, 1.1766e+00]]],
[[[ 3.3030e+00, 6.5805e-01, -1.7184e+00, 3.5029e-01],
[-1.2511e-01, 4.9624e-01, -1.3216e+00, -6.8949e-01],
[ 1.5174e+00, -6.8108e-01, 1.5536e-01, -7.8465e-01],
...,
[-9.0042e-01, 6.8886e-01, 5.7894e-01, 1.2891e-01],
[ 1.1717e-01, -2.0201e-01, -3.0778e-01, -1.8447e+00],
[ 3.1883e-01, 6.5378e-01, 1.3329e+00, 2.5843e-01]],
[[-2.1050e-01, -5.7218e-01, -5.8443e-01, -1.4757e+00],
[-1.6935e+00, -3.2765e-02, 9.5702e-01, 8.3929e-01],
[ 1.6788e-01, -1.0459e+00, -2.0357e-01, 3.7145e-02],
...,
[-9.4363e-01, 9.8749e-01, -4.5407e-01, -9.5364e-01],
[-1.3861e-01, -2.1635e-01, 9.0047e-01, -2.7273e-02],
[ 2.2375e+00, 2.0899e-03, 1.3707e+00, -6.9060e-01]]]],
device='xla:0')], None, [torch.Size([3, 2, 128, 4])], [torch.float32]
```
Okay, the payload is None, which means no IR was traced from the kernel. I think we just don't support interpret mode.
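To see why the payload can end up as None, it helps to look at what interpret mode lowers to. Below is a minimal JAX-only sketch, assuming a recent JAX where `pl.pallas_call` accepts `interpret=True` on CPU: it lowers the same add kernel and searches the StableHLO text for a `tpu_custom_call`. Presumably `make_kernel_from_pallas` takes its payload from such a custom call in the lowered IR (an assumption, not confirmed in this thread), so when interpret mode expands the kernel into ordinary JAX ops there is nothing to extract. `make_add_vectors` and `lowered_text` are illustrative helpers.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Read both input blocks, add them, and write the sum to the output block.
    o_ref[...] = x_ref[...] + y_ref[...]


def make_add_vectors(interpret: bool):
    # Build a jitted wrapper around the Pallas kernel with the chosen mode.
    @jax.jit
    def add_vectors(x, y):
        return pl.pallas_call(
            add_vectors_kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
            interpret=interpret,
        )(x, y)

    return add_vectors


def lowered_text(fn, *args) -> str:
    # Lower the jitted function for the given arguments and return StableHLO text.
    return fn.lower(*args).as_text()


x = jnp.ones((8, 128), jnp.float32)
y = jnp.ones((8, 128), jnp.float32)
text = lowered_text(make_add_vectors(interpret=True), x, y)
# Checks whether the interpret-mode lowering contains a TPU custom call.
print("tpu_custom_call" in text)
```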
Can you tell me more about why you are trying to use interpret mode? Maybe we can add that support in the future.
@alanwaketan I'm trying to use interpret mode so that I can develop kernels locally before I test them on TPUs, since it can be tough and expensive to write and debug code on TPUs.
Got it. It seems like a good feature to add. Unfortunately, for now the best way to develop the kernel is to use JAX and then port the finished kernel to PyTorch/XLA. @johnsutor
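A minimal sketch of that workflow, assuming JAX with Pallas is installed locally: the kernel is developed and checked in plain JAX, interpret mode is switched on automatically whenever the default backend is not a TPU, and the kernel is only ported to PyTorch/XLA once it matches a reference computation. `INTERPRET` and `check_against_reference` are hypothetical names; the default shape mirrors the `(3, 2, 128, 4)` tensors from the repro.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

# Assumption: interpret mode on CPU/GPU for local development, compiled kernel on TPU.
INTERPRET = jax.default_backend() != "tpu"


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Element-wise add of the full input blocks (no grid, one block per array).
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add_vectors(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=INTERPRET,
    )(x, y)


def check_against_reference(shape=(3, 2, 128, 4), seed=0):
    # Compare the Pallas kernel against plain jnp addition on random inputs.
    kx, ky = jax.random.split(jax.random.PRNGKey(seed))
    x = jax.random.normal(kx, shape, jnp.float32)
    y = jax.random.normal(ky, shape, jnp.float32)
    np.testing.assert_allclose(np.asarray(add_vectors(x, y)), np.asarray(x + y), rtol=1e-6)
    return True


print(check_against_reference())
```

Keeping the interpret flag outside the kernel body means the same `add_vectors` definition runs unchanged on a TPU VM, where it would presumably be the function handed to `make_kernel_from_pallas`.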
🐛 Bug
I am attempting to develop custom Pallas kernels locally on a CPU for later use on a TPU. I'm following the official example here, with one minor modification: I run the script on a CPU using interpret mode. After investigating, it appears that the latest custom-kernel code on the main branch should fix this error.
To Reproduce
Please use the Colab notebook here:
Steps to reproduce the behavior:
Expected behavior
It should execute the code without any errors.
Environment
Additional context
N/A