
ops.image.affine_transform() does not work as a layer in GPU #20191

Open
kwchan7 opened this issue Aug 30, 2024 · 5 comments
Assignees
Labels
stat:awaiting keras-eng Awaiting response from Keras engineer type:Bug

Comments

@kwchan7

kwchan7 commented Aug 30, 2024

Hi,

I've noticed that ops.image.affine_transform() does not work as part of a model on GPU.
TF version: 2.16.1
Keras version: 3.5.0

Some observations from testing:

  1. model.predict_step() works, but model.predict() does not
  2. it works when using CPU only
  3. other similar functions, such as ops.image.resize() and ops.image.pad_images(), work fine

Sample code below:

import tensorflow as tf
import keras
from keras import layers
from keras import ops
import numpy as np

print('TF version: {:s}'.format(tf.__version__))
print('keras version: {:s}'.format(keras.__version__))

shape = (20, 18, 1)
inputs = layers.Input(shape)
# 8-parameter affine transform (identity)
transform = ops.stack([1, 0, 0, 0, 1, 0, 0, 0], axis=0)

# Toggle between the failing op (affine_transform) and a working one (resize)
if 1:
    img = ops.image.affine_transform(inputs, transform)
else:
    img = ops.image.resize(inputs, (10, 9))

y = layers.Flatten()(img)
outputs = layers.Dense(1)(y)
model = keras.Model(inputs, outputs)
model.summary()

x = np.random.uniform(-1, 1, (10000, *shape))
yp = model.predict_step(x)  # works
print(yp)
yp = model.predict(x)  # fails on GPU
print(yp)
@ghsanti
Contributor

ghsanti commented Aug 31, 2024

Just adding some extra info.

  • similar issue
  • It seems XLA GPU JIT cannot compile the affine-transform operation: the underlying op is unregistered (i.e. TF cannot lower this op to the XLA_GPU_JIT device):

InvalidArgumentError: Graph execution error:
Detected unsupported operations when trying to compile graph (...) on XLA_GPU_JIT: ImageProjectiveTransformV3 (No registered 'ImageProjectiveTransformV3' OpKernel for XLA_GPU_JIT ...

  • Disabling JIT works, but may not be desirable:

model.compile(jit_compile=False)

  • Some sources recommend using bilinear interpolation, but that doesn't help when jit_compile is enabled.

You could also try jit_scope, but it seems more involved.
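A minimal sketch of the jit_scope route, assuming TensorFlow's tf.xla.experimental.jit_scope API; the function name and the placeholder body are illustrative, not from the repro. Note that jit_scope only takes effect inside graph contexts such as tf.function:

```python
import tensorflow as tf

@tf.function
def affine_part(x):
    # compile_ops=False asks XLA not to compile ops inside this scope, so
    # the unsupported ImageProjectiveTransformV3 could stay on the regular
    # GPU kernel path while the rest of the graph is still compiled.
    with tf.xla.experimental.jit_scope(compile_ops=False):
        return x * 2.0  # placeholder for the affine_transform call

y = affine_part(tf.constant(3.0))
```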



@fchollet
Member

Can you try JAX? I'd like to see if this is an XLA issue or a TF issue.

@ghsanti
Contributor

ghsanti commented Sep 10, 2024

I've run several tests (gist); I'll try to summarise them below, @fchollet:

  • TF: the table here also shows that the operation has no GPU support; there have been tickets asking for it since 2021.

  • Torch CPU & GPU:

--> 365     transform = torch.reshape(transform, (batch_size, 3, 3))
    366     offset = transform[:, 0:2, 2].clone()
    367     offset = torch.nn.functional.pad(offset, pad=[0, 1, 0, 0])

RuntimeError: shape '[10000, 3, 3]' is invalid for input of size 9
  • JAX

Same error as Torch GPU for both CPU & GPU (see observation below).


Using transform = torch.reshape(transform, (1, 3, 3)) instead of torch.reshape(transform, (batch_size, 3, 3)) seems to run fine in JAX (CPU, GPU) and Torch (CPU).
The same happens when replacing batch_size=10000 with batch_size=1.

Torch GPU returns RuntimeError: "baddbmm_cuda" not implemented for 'Int'
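A possible user-side workaround, sketched under the assumption that the backends mishandle only the unbatched (8,) transform: pass one transform row per image, so the internal reshape to (batch_size, 3, 3) has enough elements. The names below are illustrative, not from the Keras source:

```python
import numpy as np

# Identity affine transform in the 8-parameter form
# (a0, a1, a2, b0, b1, b2, c0, c1) that affine_transform expects.
single = np.array([1, 0, 0, 0, 1, 0, 0, 0], dtype="float32")

batch_size = 10000  # matches the repro above
# One row per image -> shape (batch_size, 8); float32 also sidesteps
# the Torch 'Int' baddbmm error mentioned below.
batched = np.tile(single[None, :], (batch_size, 1))
```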

@fchollet
Member

Torch GPU returns RuntimeError: "baddbmm_cuda" not implemented for 'Int'

For that one, you can simply cast your input to float32 (and cast it back to int afterwards if you need ints)
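A minimal sketch of that cast, shown with NumPy so it stands alone; in the repro, keras.ops.cast(transform, "float32") would play the same role:

```python
import numpy as np

# Build the transform as float32 from the start, so the Torch backend's
# baddbmm sees floating-point inputs instead of ints.
transform = np.array([1, 0, 0, 0, 1, 0, 0, 0], dtype="float32")

# ...then pass `transform` to ops.image.affine_transform(inputs, transform);
# cast the result back with ops.cast(result, "int32") if ints are needed.
```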

Same error as Torch GPU for both CPU & GPU (see observation below).

What is the JAX error message? I don't understand the PyTorch error message (as is often the case with those).

@ghsanti
Contributor

ghsanti commented Sep 10, 2024

@fchollet

cast your input to float32

should that be done automatically within the layer?

What is the JAX error message? I don't understand the PyTorch error message (as is often the case with those).

The failing line (in this case it just seems to be misuse of the function, since it takes a state):

Error one (line and screenshot)

yp = model.predict_step(x) 

[Screenshot: error traceback from predict_step, 2024-09-10 18-11-56]

Error two (line and screenshot)

Commenting that line out, the next error, at model.predict(x), is:

[Screenshot: error traceback from predict, 2024-09-10 18-12-16]


The last error looks identical to Torch's, just with batch_size = 32.

(Note that I'm not the OP, just reading out of curiosity.)

@sachinprasadhs sachinprasadhs added stat:awaiting keras-eng Awaiting response from Keras engineer type:Bug labels Sep 10, 2024