Slow transpose convolutions (both cpu and cuda backends) #23783

Open
psmaragdis opened this issue Sep 20, 2024 · 1 comment
Labels: bug (Something isn't working), CPU (Issues related to the CPU compiler/runtime), NVIDIA GPU (Issues specific to NVIDIA GPUs)

Comments

@psmaragdis

Description

Transpose convolutions are orders of magnitude slower than the corresponding regular convolutions and than their counterparts in torch (at least for the sizes in the example below). The problem is consistent across both the cpu and cuda backends, so I wouldn't point a finger at CUDA here.

Notebook with timings on Colab is here: https://colab.research.google.com/drive/19g_VmTrK0bScC6p5sqbuND7n0FVi4GqW?usp=sharing

I'm also attaching a .py version of the code at the end; its output on my M1 laptop is:

Using jax 0.4.33 on cpu
Jax 1d conv :	 1,632 iterations in 5.09 seconds
Jax 1d convt:	    17 iterations in 229.14 seconds
Using torch 2.4.1 on cpu
Torch 1d conv :	 1,548 iterations in 5.00 seconds
Torch 1d convt:	 2,211 iterations in 5.00 seconds

And on an Ubuntu machine with an RTX4090:

Using jax 0.4.33 on cuda
Jax 1d conv :	73,774 iterations in 5.00 seconds
Jax 1d convt:	   388 iterations in 5.22 seconds
Using torch 2.4.1+cu124 on cuda
Torch 1d conv :	121,147 iterations in 5.02 seconds
Torch 1d convt:	168,121 iterations in 5.01 seconds

Here is the standalone code. Change the dev parameter to either 'cpu' or 'cuda' accordingly.

# -*- coding: utf-8 -*-
"""Slow Jax ConvT.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/19g_VmTrK0bScC6p5sqbuND7n0FVi4GqW

# Setup and benchmarking routine
"""

import jax
import torch
torch.backends.cudnn.benchmark = True

sz = 256
hp = sz//4
dev = 'cuda' # can also change to 'cpu', same thing holds

# Block until CUDA is done
def block( y):
    if isinstance( y, jax.Array):
        y.block_until_ready()
    elif dev == 'cuda':
        torch.cuda.synchronize()

# Timing routine
def time_it( f):
    from time import time

    # Warmup
    for _ in range( 3):
        y = f()
    block( y)

    # Count how many passes we can queue in 5 sec
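    # Accumulating into y keeps a data dependency, so the final block( y) below
    # waits on all the queued Jax work (for torch, block() synchronizes the device)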
    c = 0
    t0 = time()
    while time()-t0 < 5:
        y += f()
        c += 1
    block( y)
    print( f'\t{c:6,d} iterations in {time()-t0:.2f} seconds')

"""# Jax convolutions

Note how the transpose convolution is orders of magnitude slower
"""

from functools import partial

print( 'Using jax', jax.__version__, 'on', dev)

# Jax regular 1d conv
@partial( jax.jit, backend=dev)
def jconvf( x, F):
    return jax.lax.conv_general_dilated( lhs=x, rhs=F,
        window_strides=(hp,), padding=((sz,sz),),
        dimension_numbers=('NCT','OIT','NCT'))

# Jax transpose 1d conv
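# (lhs_dilation=(hp,) inserts hp-1 zeros between adjacent input samples, which is
#  how a stride-hp transpose convolution is expressed with conv_general_dilated;
#  'IOT' swaps the kernel's input/output channel dims relative to 'OIT' above.)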
@partial( jax.jit, backend=dev)
def jconvt( f, F):
    return jax.lax.conv_general_dilated( lhs=f, rhs=F,
        window_strides=(1,), lhs_dilation=(hp,), padding=((sz-1,sz-1),),
        dimension_numbers=('NCT','IOT','NCT'))

# Time them
x = jax.numpy.ones( (16, 1, sz*100))
F = jax.numpy.ones( (sz//2+1, 1, sz))
print( 'Jax 1d conv :', end='')
time_it( lambda: jconvf( x, F))

f = jax.numpy.ones( (16, sz//2+1, sz*100//hp))
F = jax.numpy.ones( (sz//2+1, 1, sz))
print( 'Jax 1d convt:', end='')
time_it( lambda: jconvt( f, F))

"""# Torch convolutions

Both convolution types have comparable runtimes. The regular convolution is on par with Jax; the transpose is far faster than Jax's.
"""

print( 'Using torch', torch.__version__, 'on', dev)

# Torch regular 1d conv
def tconvf( x, F):
    return torch.nn.functional.conv1d( x, F, stride=hp, padding=sz)

# Torch transpose 1d conv
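# (stride=hp with no padding gives output length (T-1)*hp + sz, the same as jconvt above)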
def tconvt( f, F):
    return torch.nn.functional.conv_transpose1d( f, F, stride=hp)

# Time them
x = torch.ones( (16, 1, sz*100), device=dev)
F = torch.ones( (sz//2+1, 1, sz), device=dev)
print( 'Torch 1d conv :', end='')
time_it( lambda: tconvf( x, F))

f = torch.ones( (16, sz//2+1, sz*100//hp), device=dev)
F = torch.ones( (sz//2+1, 1, sz), device=dev)
print( 'Torch 1d convt:', end='')
time_it( lambda: tconvt( f, F))
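
For reference, the same transpose convolution can also be written with the jax.lax.conv_transpose wrapper. As far as I can tell it just forwards to conv_general_dilated with lhs_dilation set to the strides, so it should hit the same code path and presumably show the same slowdown. The sketch below is not part of the timings above; it reuses the shapes, sizes, and explicit padding of jconvt from the script:

# Sketch: same transpose conv via the conv_transpose wrapper
@partial( jax.jit, backend=dev)
def jconvt2( f, F):
    # strides here become lhs_dilation inside conv_general_dilated
    return jax.lax.conv_transpose( f, F,
        strides=(hp,), padding=((sz-1,sz-1),),
        dimension_numbers=('NCT','IOT','NCT'))

f = jax.numpy.ones( (16, sz//2+1, sz*100//hp))
F = jax.numpy.ones( (sz//2+1, 1, sz))
print( 'Jax 1d convt (conv_transpose):', end='')
time_it( lambda: jconvt2( f, F))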

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='f4a29f286e8e', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri Sep 20 00:07:14 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P0              31W /  70W |  11493MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
@psmaragdis added the bug label Sep 20, 2024
@hawkinsp added the CPU label Sep 20, 2024
@hawkinsp (Collaborator)

Assigning @penpornk for the CPU part.

(The CUDA part probably should receive a look also, but the CPU problem is much worse. It probably means we're falling back to a naive implementation rather than using an optimized kernel.)

@hawkinsp added the NVIDIA GPU label Sep 20, 2024