I use the `lax.scan` function. Inside the `scan_step`, among other things, I call a jitted function which (among other things) calls `lax.cond`.
After updating `jax` I observed a huge drop in speed (a factor of 9 in my case) when running on my GPU (I didn't try it on the CPU, as a run takes a while). I was able to trace it back to the update from version `0.4.35` to `0.4.36`.
The code is quite extensive and I am not allowed to publish it yet. My quick attempts to create a short minimal example did not reproduce the same huge drop.
This might be related to #26162. However, I have also checked the newer versions of `jax` up to `0.5.2` and still observe the same slowdown. I know that issue is still open, but changes have been made to the `xla` library; I was hoping they had already been released and the problem would be fixed in newer versions, but that is not the case. I am not sure how to incorporate the `xla` changes into my environment, which is why I opened a separate issue.
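For reference, the structure in question looks roughly like this (a minimal sketch with hypothetical names and shapes; as noted above, small examples of this kind did not reproduce the slowdown for me):

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical jitted inner function that branches via lax.cond.
@jax.jit
def inner(x):
    return lax.cond(
        x.sum() > 0,
        lambda v: v * 2.0,   # true branch
        lambda v: v - 1.0,   # false branch
        x,
    )


def scan_step(carry, _):
    # The scan body calls the jitted function (among other things).
    carry = inner(carry)
    return carry, carry.sum()


x0 = jnp.ones(8)
final, ys = lax.scan(scan_step, x0, None, length=100)
```

The real code does much more per step; this only illustrates the `scan` → jitted function → `cond` nesting being discussed.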
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.4.36
jaxlib: 0.4.36
numpy: 2.2.3
python: 3.12.5 (main, Aug 14 2024, 05:08:31) [Clang 18.1.8 ]
device info: NVIDIA RTX 2000 Ada Generation Laptop GPU-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='ThinkPad-P1-Gen-6', release='6.8.0-52-generic', version='#53~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jan 15 19:18:46 UTC 2', machine='x86_64')

$ nvidia-smi
Tue Mar 18 16:34:23 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX 2000 Ada Gene...    Off |   00000000:01:00.0 Off |                  N/A |
| N/A   53C    P3             12W /  35W  |     121MiB /   8188MiB |      7%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2814      G   /usr/lib/xorg/Xorg                              4MiB |
|    0   N/A  N/A    155902      C   python3                                        98MiB |
+-----------------------------------------------------------------------------------------+
```