Obscure validation failure due to _use_cached_eval_dataset #20177

Open

DLumi opened this issue Aug 28, 2024 · 5 comments

DLumi commented Aug 28, 2024

I'll preface by saying that I encountered this issue with tf_keras == 2.15, but since the current source code for evaluation is hardly different from v2.15, I feel it's still applicable here.

The issue is that, no matter what, fit forces evaluate to use the stored dataset object for the validation step instead of whatever object you supply to fit. This is very obscure, but it's probably done for performance reasons, so whatever.
Why is this an issue?
If you change something about your dataset mid-training (say, you initially forgot to turn on .ignore_errors()) and then pass the new dataset instance to fit, it completely ignores the new object. And in this particular case, it fails whenever errors arise in the dataset's preprocessing steps.

Yes, you can work around it with model._eval_data_handler = None, which forces evaluate to cache the new object, but to figure this out you have to spend some time diving into the source code.
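
For reference, a minimal sketch of that workaround (assuming tf_keras == 2.15; the toy model and datasets below are illustrative, not from my actual setup):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

train_ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((32, 4)), tf.random.normal((32, 1)))
).batch(8)
val_ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((16, 4)), tf.random.normal((16, 1)))
).batch(8)

# The first fit() caches an eval data handler built from val_ds.
model.fit(train_ds, validation_data=val_ds, epochs=1)

# Suppose the validation pipeline gets fixed mid-training (e.g. ignore_errors() added).
fixed_val_ds = val_ds.ignore_errors()

# Per the behavior described above, without this reset fit() would keep
# validating on the old cached dataset:
model._eval_data_handler = None  # private attribute; forces evaluate to re-cache

model.fit(train_ds, validation_data=fixed_val_ds, epochs=1)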

So what I propose is:

  1. a mention of this caching behavior in fit's documentation
  2. an actual public API for either clearing cached validation objects or disabling the caching behavior entirely

P.S. I'd provide a Colab link, but it turns out that making a tf.data.Dataset that fails exactly when I want it to is actually way harder than it seems.

@sachinprasadhs added the keras-team-review-pending and type:Bug labels Aug 28, 2024
@mattdangerw (Member) commented:

@DLumi I suspect this is not an issue on Keras 3 actually. Keras 2 caches an attribute on self, so it makes sense that things might get messed up if a fit() call fails in the middle. But Keras 3 just passes an additional kwarg, so there shouldn't be anything stateful to mess up.

# TODO: respect compiled trainable state
use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
if kwargs:
    raise ValueError(f"Arguments not recognized: {kwargs}")
if use_cached_eval_dataset:
    epoch_iterator = self._eval_epoch_iterator

If you think we could apply the same approach to tf-keras, you are welcome to open up a PR there. Otherwise we will probably stick to this being fixed on Keras 3. I will close this for now, but if you can recreate a bug on Keras 3 please re-open!

(Also, as to why this exists: yes, it's to avoid some overhead when creating the dataset iterator. But Keras 3 handles this much more elegantly than Keras 2.)



DLumi commented Aug 29, 2024

But Keras 3 just passes an additional kwarg, so there shouldn't be anything stateful to mess up.

Uh, I'm pretty sure it's functionally exactly the same as in Keras 2, as I see little to no change in the actual code. Here's the Keras 2 code for comparison:
https://github.com/keras-team/tf-keras/blob/c5f97730b2e495f5f56fc2267d22504075e46337/tf_keras/engine/training.py#L2236C1-L2241C55

Maybe I am missing something here, though?

Anyway, it would greatly help if I knew how to create a first-working-then-failing tf.data.Dataset at toy scale.
That way I could recreate the setup and potentially give you a definitive way to reproduce this with Keras 3.

@mattdangerw mattdangerw reopened this Aug 29, 2024
@mattdangerw (Member) commented:

Ah, my bad, I misread the code. Still, there is a key difference between Keras 2 and Keras 3 here. This line:

self._eval_epoch_iterator = None

In Keras 3, we always clear the cached dataset at the beginning of fit, which is not true in Keras 2. So I see how a crashing fit could cause an issue in Keras 2, but not in Keras 3.
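
To illustrate the difference, here is a toy sketch; the two classes below are hypothetical stand-ins, not the actual Keras internals:

class Keras2Like:
    # Caches the eval handler on self and never resets it at the start of fit,
    # so an old validation object can keep being reused.
    def fit(self, validation_data):
        if getattr(self, "_eval_data_handler", None) is None:
            self._eval_data_handler = validation_data
        return list(self._eval_data_handler)

class Keras3Like:
    # Clears the cached eval iterator at the start of every fit, so a newly
    # passed validation object is always picked up.
    def fit(self, validation_data):
        self._eval_epoch_iterator = None
        self._eval_epoch_iterator = validation_data
        return list(self._eval_epoch_iterator)

old_ds, new_ds = [1, 2], [3, 4]
m2, m3 = Keras2Like(), Keras3Like()
m2.fit(old_ds); m3.fit(old_ds)
print(m2.fit(new_ds))  # [1, 2] -- the stale cached object wins
print(m3.fit(new_ds))  # [3, 4] -- the new object is used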

As for a crashing dataset, maybe something like this:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.range(100))

@tf.py_function(Tout=tf.int32)
def crasher(x):
    if x > 50:
        raise ValueError
    return x

ds = ds.map(crasher)

# Iteration works at first, then fails once x exceeds 50.
for x in ds:
    print(x)
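
And a hedged sketch of how that dataset might be wired into fit() to try to reproduce the report on Keras 3, continuing from the snippet above (the toy model and the make_pairs helper are illustrative assumptions, not part of the original report):

import tensorflow as tf
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

def make_pairs(d):
    # Pair each scalar with a dummy label so the pipeline can feed fit()/evaluate().
    return d.map(
        lambda x: (tf.reshape(tf.cast(x, "float32"), (1,)), tf.zeros((1,)))
    ).batch(4)

train_ds = make_pairs(tf.data.Dataset.from_tensor_slices(tf.range(40)))
bad_val_ds = make_pairs(ds)  # `ds` is the crashing dataset from the snippet above

try:
    model.fit(train_ds, validation_data=bad_val_ds, epochs=1)
except Exception as e:
    print("validation crashed:", type(e).__name__)

# After fixing the pipeline (ignore_errors should drop the failing elements),
# pass the new dataset again; on Keras 3 the cached eval iterator is cleared at
# the start of fit(), so the fixed dataset should be picked up.
good_val_ds = make_pairs(ds.ignore_errors())
model.fit(train_ds, validation_data=good_val_ds, epochs=1)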

@sachinprasadhs added the stat:awaiting response from contributor label and removed the keras-team-review-pending label Aug 29, 2024

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions bot added the stale label Sep 13, 2024