Slow transpose convolutions (both cpu and cuda backends) #23783
Labels
- bug: Something isn't working
- CPU: Issues related to the CPU compiler/runtime
- NVIDIA GPU: Issues specific to NVIDIA GPUs
Description
Transpose convolutions are orders of magnitude slower than the corresponding regular convolutions and than their PyTorch counterparts (at least for the sizes in the example below). The problem is consistent across both the CPU and CUDA backends, so I wouldn't point the finger at CUDA here.
Notebook with timings on Colab is here: https://colab.research.google.com/drive/19g_VmTrK0bScC6p5sqbuND7n0FVi4GqW?usp=sharing
I'm also attaching a .py version of the code at the end; its output on my M1 laptop is:
And on an Ubuntu machine with an RTX 4090:
Here is the standalone code. Change the `dev` parameter to either `'cpu'` or `'cuda'` accordingly.

System info (python version, jaxlib version, accelerator, etc.)
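The attached script is not reproduced in this copy of the issue. As a stand-in, here is a minimal sketch of such a benchmark, timing a strided `jax.lax.conv_general_dilated` against its `jax.lax.conv_transpose` counterpart on the same shapes; the shapes, strides, and iteration count are illustrative assumptions, not the author's originals.

```python
# Minimal benchmark sketch (not the original attached script): compare a
# regular strided convolution with a transpose convolution of the same size.
import time

import jax
import jax.numpy as jnp
from jax import lax


def bench(fn, *args, n=10):
    """Average wall-clock time per call, after a compile/warm-up run."""
    fn(*args).block_until_ready()  # trigger JIT compilation first
    t0 = time.perf_counter()
    for _ in range(n):
        out = fn(*args)
    out.block_until_ready()  # account for JAX's async dispatch
    return (time.perf_counter() - t0) / n


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1, 64, 64, 32))  # NHWC input (assumed shape)
w = jax.random.normal(key, (3, 3, 32, 32))   # HWIO kernel (assumed shape)

dn = lax.conv_dimension_numbers(x.shape, w.shape, ("NHWC", "HWIO", "NHWC"))

conv = jax.jit(lambda x, w: lax.conv_general_dilated(
    x, w, window_strides=(2, 2), padding="SAME", dimension_numbers=dn))
convT = jax.jit(lambda x, w: lax.conv_transpose(
    x, w, strides=(2, 2), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC")))

print(f"conv:           {bench(conv, x, w) * 1e3:.3f} ms")
print(f"conv_transpose: {bench(convT, x, w) * 1e3:.3f} ms")
```

Select the backend by running with `JAX_PLATFORM_NAME=cpu` or `JAX_PLATFORM_NAME=cuda` (the original script used a `dev` parameter for the same purpose).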