
Pallas on CPU #7599

Open
johnsutor opened this issue Jul 1, 2024 · 11 comments

Comments

@johnsutor

🐛 Bug

I am attempting to develop custom Pallas kernels locally on a CPU for later use on a TPU. I'm following the official example here, with the minor modification that I run the script on a CPU using interpret mode. After investigating, it appears that the latest custom-kernel code on the main branch should fix the issues behind this error.

To Reproduce

Please use the colab here:

Steps to reproduce the behavior:

  1. Run the colab
  2. Observe errors in the last two cells
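For reference, here is a standalone repro reconstructed from the snippets later in this thread (the exact notebook contents are in the linked colab; the PJRT_DEVICE variable is set before the torch_xla import here, unlike in the notebook):

import os
os.environ["PJRT_DEVICE"] = "CPU"  # select the XLA CPU backend before importing torch_xla

import jax
import torch
from jax.experimental import pallas as pl
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

def add_vectors_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                        interpret=True)(x, y)

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")

pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)  # fails here on CPU; see the tracebacks below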

Expected behavior

The code should execute without any errors.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch_xla version: ~=2.3.0

Additional context

N/A

@JackCaoG
Collaborator

JackCaoG commented Jul 1, 2024

@alanwaketan can you take a look? I'm actually not entirely sure whether Pallas works on XLA:CPU, but the failure seems to happen in our Python code.

@alanwaketan
Collaborator

@johnsutor Can you try nightly?

@johnsutor
Author

@alanwaketan I installed nightly using

! pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl

but then I get the following error when I attempt to run the code block:

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import os 
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
import torch
os.environ["PJRT_DEVICE"] = "CPU"
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-3-193790483d72> in <cell line: 5>()
      3 import jax.numpy as jnp
      4 import os
----> 5 from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
      6 import torch
      7 os.environ["PJRT_DEVICE"] = "CPU"

/usr/local/lib/python3.10/dist-packages/torch_xla/__init__.py in <module>
      6 
      7 import torch
----> 8 import _XLAC
      9 from ._internal import tpu
     10 from .version import __version__

ImportError: /usr/local/lib/python3.10/dist-packages/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c105Error4whatEv


@bhavya01
Collaborator

bhavya01 commented Jul 2, 2024

@johnsutor This error usually happens when there's a mismatch between the installed torch and torch_xla versions. Maybe also update PyTorch to the nightly version.

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
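For completeness, pairing the PyTorch nightly with the torch_xla nightly wheel from the earlier comment might look like this (uninstalling first avoids mixing versions; exact nightly tags may differ):

pip3 uninstall -y torch torchvision torchaudio torch_xla
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl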

@johnsutor
Author

@bhavya01 That unfortunately does not work; I hit an error from the PyTorch index when attempting to install in Colab.

Looking in indexes: https://download.pytorch.org/whl/nightly/cpu
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.18.0+cu121)
Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  ERROR: HTTP error 403 while getting https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from https://download.pytorch.org/whl/nightly/cpu/nvidia-cuda-nvrtc-cu12/)
ERROR: Could not install requirement nvidia-cuda-nvrtc-cu12==12.1.105 from https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from torch) because of HTTP error 403 Client Error: Forbidden for url: https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl for URL https://download.pytorch.org/whl/nightly/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (from https://download.pytorch.org/whl/nightly/cpu/nvidia-cuda-nvrtc-cu12/)

However, the PyTorch nightly install works fine on my M2 Mac, but then I can't install torch_xla there.

@alanwaketan
Collaborator

alanwaketan commented Jul 2, 2024

@johnsutor Did you include torchvision and torchaudio? If so, we can remove them from the command.
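That is, reduced to torch only against the same nightly CPU index:

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu

This sidesteps the 403 on the CUDA dependency wheels, since the CPU-only torch nightly does not pull in the nvidia-* packages.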

@johnsutor
Author

@alanwaketan On Colab, I had to uninstall torch before installing the torch nightly to get it to work. However, I now get the following error when I run this code (you can see the results in the notebook here):

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                        interpret=True
                        )(x, y)

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# From the tutorial
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-d7393bd6c1ca> in <cell line: 3>()
      1 # From the tutorial
      2 pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
----> 3 output = pt_kernel(q, k)

/usr/local/lib/python3.10/dist-packages/torch_xla/experimental/custom_kernel.py in wrapped_kernel(kernel, output_shape_dtype_fn, static_argnums, static_argnames, *args, **kwargs)
    143     output_shapes = [shape for shape, _ in output_shape_dtype]
    144     output_dtypes = [dtype for _, dtype in output_shape_dtype]
--> 145     outputs = torch_xla._XLAC._xla_tpu_custom_call(tensor_args, payload,
    146                                                    output_shapes, output_dtypes)
    147 

TypeError: _xla_tpu_custom_call(): incompatible function arguments. The following argument types are supported:
    1. (arg0: list[torch.Tensor], arg1: str, arg2: list[list[int]], arg3: list[object]) -> list[torch.Tensor]

Invoked with: [tensor([... 3×2×128×4 values ...], device='xla:0'),
              tensor([... 3×2×128×4 values ...], device='xla:0')],
None, [torch.Size([3, 2, 128, 4])], [torch.float32]

@alanwaketan
Collaborator

Okay, the payload is None, which means no IR was traced from the kernel. I think we just don't support interpret mode.
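For reference, a rough way to see this from the JAX side (assuming torch_xla extracts the payload from the tpu_custom_call emitted during normal Mosaic lowering, which interpret mode skips):

# Assumes the add_vectors function from the snippet above (interpret=True).
import jax.numpy as jnp

x = jnp.zeros((3, 2, 128, 4), jnp.float32)
# In interpret mode the kernel lowers to plain HLO ops, so the lowered
# module contains no tpu_custom_call and hence no Mosaic payload for
# torch_xla to hand to _xla_tpu_custom_call.
print("tpu_custom_call" in add_vectors.lower(x, x).as_text())  # expect False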

@alanwaketan
Collaborator

Can you tell me more about why you're trying to use interpret mode? Maybe we can add that support in the future.

@johnsutor
Author

@alanwaketan I'm trying to use interpret mode so that I can develop kernels locally before testing them on TPUs, since it can be tough and expensive to write and debug code directly on TPUs.

@alanwaketan
Collaborator

Got it. That seems like a good feature to add. Unfortunately, the best way to develop kernels right now is to write them in JAX and then port the finished kernel to PyTorch/XLA. @johnsutor
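For anyone landing here, a minimal sketch of that suggested workflow: iterate on the kernel in pure JAX with interpret=True, which runs on CPU, and only port it to PyTorch/XLA once it behaves correctly.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  # interpret=True executes the kernel in the Pallas interpreter, so it
  # works on CPU for local development and debugging.
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                        interpret=True)(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add_vectors(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]

Once the kernel works in the interpreter, drop interpret=True and wrap it with make_kernel_from_pallas for use on TPU, as in the tutorial.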
