diff --git a/.gitignore b/.gitignore index 416f213f2c82..e69de29bb2d1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,23 +0,0 @@ -.DS_Store -*.pyc -.vscode-test -__pycache__ -**/.vscode-test/** -**/.vscode test/** -**/.vscode-smoke/** -**/.venv*/ -venv -bin/** -build/** -obj/** -.pytest_cache -tmp/** -.vs/ -dist/** -**/*.egg-info/* -.vscode -examples/**/*.jpg -.python-version -.coverage -*coverage.xml -.ruff_cache \ No newline at end of file diff --git a/CUSTOM_GRADIENT_JAX_FIX.md b/CUSTOM_GRADIENT_JAX_FIX.md new file mode 100644 index 000000000000..b3781335ceb3 --- /dev/null +++ b/CUSTOM_GRADIENT_JAX_FIX.md @@ -0,0 +1,87 @@ +# Fix for custom_gradient with JAX backend and Variables + +## Issue +GitHub Issue [#21105](https://github.com/keras-team/keras/issues/21105) + +When using `@ops.custom_gradient` with the JAX backend, passing Keras Variables as arguments would cause a `TypeError: 'NoneType' object is not callable` during training. This occurred because JAX's `custom_gradient` would capture the Variable object itself instead of extracting its underlying tensor value. + +## Root Cause +The JAX backend's `custom_gradient` function was directly wrapping `jax.custom_gradient` without converting Variable objects to their values, unlike the `stop_gradient` function which already handled this correctly. + +## Solution +Modified `keras/src/backend/jax/core.py` to add a wrapper that automatically extracts `.value` from Variable objects before passing them to the user's custom gradient function. This is done using `tree.map_structure` to recursively handle nested structures. + +### Changes Made + +**File: `keras/src/backend/jax/core.py`** + +```python +def custom_gradient(fun): + def wrapper(*args, **kwargs): + # Convert Variable objects to their values + def _convert_arg(arg): + if isinstance(arg, Variable): + return arg.value + return arg + + args = tree.map_structure(_convert_arg, args) + kwargs = tree.map_structure(_convert_arg, kwargs) + return fun(*args, **kwargs) + + return jax.custom_gradient(fun=wrapper) +``` + +**File: `keras/src/ops/core_test.py`** + +Added `test_custom_gradient_with_variable()` to verify that Variables can be passed directly to custom_gradient functions without needing to manually add `.value`. + +## Testing + +### Run the specific test: +```bash +pytest keras/src/ops/core_test.py::CoreOpsCorrectnessTest::test_custom_gradient_with_variable -v +``` + +### Run all core ops tests: +```bash +pytest keras/src/ops/core_test.py -v +``` + +## Example Usage + +Before the fix, you needed to manually extract `.value`: + +```python +@ops.custom_gradient +def roundpass(x, log_scaling): + scaling = ops.exp(log_scaling) + rounded = ops.round(x * scaling) / scaling + + def grad(*args, upstream=None): + if upstream is None: + (upstream,) = args + return upstream, ops.zeros_like(log_scaling) + + return rounded, grad + +class QuantizedLayer(layers.Layer): + def call(self, x): + # Workaround: manually add .value + return roundpass(x, self.log_scaling.value) +``` + +After the fix, Variables work directly: + +```python +class QuantizedLayer(layers.Layer): + def call(self, x): + # Works automatically now! 
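+ # (previously this raised TypeError: 'NoneType' object is not callable)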
+ return roundpass(x, self.log_scaling) +``` + +## Impact +- ✅ Fixes the TypeError when Variables are passed to custom_gradient functions +- ✅ Makes JAX backend behavior consistent with user expectations +- ✅ Aligns with how `stop_gradient` already handles Variables +- ✅ Backward compatible - existing code using the `.value` workaround still works +- ✅ Negligible overhead - the conversion is a lightweight `.value` lookup that runs each time the wrapped function is called (once per trace under `jit`) diff --git a/TORCH_JIT_COMPILE_LIMITATIONS.md b/TORCH_JIT_COMPILE_LIMITATIONS.md new file mode 100644 index 000000000000..fdc3c0b28149 --- /dev/null +++ b/TORCH_JIT_COMPILE_LIMITATIONS.md @@ -0,0 +1,56 @@ +# Torch Backend jit_compile Limitations + +## Issue #21647: jit_compile=True with EfficientNetV2 on torch backend + +### Problem +When using `jit_compile=True` with certain Keras models (especially EfficientNetV2) on the torch backend, you may encounter `InternalTorchDynamoError` or `RuntimeError` related to torch.compile being unable to trace optree operations. + +### Root Cause +Keras uses tree operations (from `optree` or `torch.utils._pytree`) for handling nested structures. When `jit_compile=True` is enabled, PyTorch's torch.compile attempts to trace through all Python operations, including these tree utilities. However, torch.compile has limitations with certain C/C++ extensions and symbolic operations. + +### Error Messages +- **GPU**: `InternalTorchDynamoError: TypeError: '<' not supported between instances of 'NoneType' and 'int'` +- **CPU**: `RuntimeError: TypeError: cannot determine truth value of Relational` + +### Workarounds + +#### Option 1: Disable JIT Compilation (Recommended) +```python +model.compile( + optimizer=Adam(learning_rate=0.001), + loss=CategoricalCrossentropy(), + metrics=['accuracy'], + jit_compile=False # or omit this parameter +) +``` + +#### Option 2: Use a Different Backend +Switch to the TensorFlow or JAX backend, which have better `jit_compile` support: +```python +import os +os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax" +``` + +#### Option 3: Use Fixed Input Shapes +If you must use jit_compile with torch, ensure all input shapes are fixed (no None dimensions): +```python +base_model = EfficientNetV2B2( + include_top=False, + input_shape=(224, 224, 3), # Fixed shape, no None + pooling='avg', + weights=None +) +``` + +### Status +This is a known limitation of torch.compile when tracing complex nested structures. The PyTorch team is aware of these patterns and continues to improve torch.compile coverage. + +### Related Issues +- PyTorch Issue: torch.compile limitations with pytree operations +- Keras Issue #21647 + +### Future Improvements +Potential solutions being explored (a sketch of the first two follows the list): +1. Add torch.compile skip decorators for tree operations +2. Use `torch.compiler.disable()` on specific operations +3. Refactor to use pure torch operations where possible
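+ +### Example: skipping tree utilities under torch.compile + +A minimal sketch of ideas (1) and (2) above. This is an illustration, not code from this PR: `flatten_inputs` is a hypothetical helper name, and forcing a graph break here trades compiled coverage for robustness. + +```python +import torch +import torch.utils._pytree as pytree + + +# torch.compiler.disable makes torch.compile fall back to eager execution +# for this function (a graph break) instead of failing to trace it. +@torch.compiler.disable +def flatten_inputs(structure): + # Runs eagerly even when called from inside a compiled region. + return pytree.tree_flatten(structure) +```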
diff --git a/keras/src/applications/efficientnet_v2_jit_test.py b/keras/src/applications/efficientnet_v2_jit_test.py new file mode 100644 index 000000000000..abd78403df0a --- /dev/null +++ b/keras/src/applications/efficientnet_v2_jit_test.py @@ -0,0 +1,107 @@ +"""Test for Issue #21647: jit_compile=True with EfficientNetV2 on torch +backend.""" + +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.applications.efficientnet_v2 import EfficientNetV2B2 +from keras.src.losses import CategoricalCrossentropy +from keras.src.optimizers import Adam + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="This test is specifically for torch backend", +) +class EfficientNetV2JitCompileTest(testing.TestCase): + """Test EfficientNetV2 models with jit_compile=True on torch backend.""" + + def test_efficientnet_v2_b2_with_jit_compile(self): + """Test that EfficientNetV2B2 works with jit_compile=True.""" + num_classes = 10 + batch_size = 2 # Small batch for testing + steps_per_epoch = 1 + epochs = 1 + + # Generate random data (use minimum supported size) + data_shape = (224, 224, 3) # Minimum size for EfficientNetV2 + x_train = np.random.rand( + batch_size * steps_per_epoch, *data_shape + ).astype(np.float32) + y_train = np.random.randint( + 0, num_classes, size=(batch_size * steps_per_epoch,) + ) + y_train = np.eye(num_classes)[y_train] + + # Create model + base_model = EfficientNetV2B2( + include_top=False, + input_shape=(224, 224, 3), # Fixed shape for jit_compile + pooling="avg", + include_preprocessing=True, + weights=None, # Don't load weights for faster testing + ) + x = base_model.output + output = layers.Dense(num_classes, activation="softmax")(x) + model = models.Model(inputs=base_model.input, outputs=output) + + # Compile with jit_compile=True + model.compile( + optimizer=Adam(learning_rate=0.001), + loss=CategoricalCrossentropy(), + metrics=["accuracy"], + jit_compile=True, + ) + + # This should not raise InternalTorchDynamoError + history = model.fit( + x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=0 + ) + + # Basic sanity check + self.assertIsNotNone(history) + self.assertIn("loss", history.history) + + def test_efficientnet_v2_b0_with_jit_compile(self): + """Test that EfficientNetV2B0 also works with jit_compile=True.""" + from keras.src.applications.efficientnet_v2 import EfficientNetV2B0 + + num_classes = 5 + batch_size = 2 + + # Generate random data + x_train = np.random.rand(batch_size, 224, 224, 3).astype(np.float32) + _ = np.eye(num_classes)[ + np.random.randint(0, num_classes, size=(batch_size,)) + ] + + # Create model + base_model = EfficientNetV2B0( + include_top=False, + input_shape=(224, 224, 3), + pooling="avg", + weights=None, + ) + x = base_model.output + output = layers.Dense(num_classes, activation="softmax")(x) + model = models.Model(inputs=base_model.input, outputs=output) + + # Compile with jit_compile=True + model.compile( + optimizer=Adam(learning_rate=0.001), + loss=CategoricalCrossentropy(), + metrics=["accuracy"], + jit_compile=True, + ) + + # Should work without errors + predictions = model.predict(x_train, verbose=0) + self.assertEqual(predictions.shape, (batch_size, num_classes)) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/keras/src/applications/imagenet_utils.py b/keras/src/applications/imagenet_utils.py index
5687bc1122a4..4848fa9ca8ea 100644 --- a/keras/src/applications/imagenet_utils.py +++ b/keras/src/applications/imagenet_utils.py @@ -323,8 +323,18 @@ def obtain_input_shape( """ if weights != "imagenet" and input_shape and len(input_shape) == 3: if data_format == "channels_first": - correct_channel_axis = 1 if len(input_shape) == 4 else 0 - if input_shape[correct_channel_axis] not in {1, 3}: + # Check if user accidentally provided channels_last format + # when channels_first was expected + if input_shape[-1] in {1, 3} and input_shape[0] not in {1, 3}: + raise ValueError( + f"The `input_shape` argument has shape {input_shape}, " + "which appears to be in 'channels_last' format " + f"(with {input_shape[-1]} channels), but the model " + "is configured to use 'channels_first' data format. " + f"For 'channels_first', provide input_shape as " + f"({input_shape[-1]}, {input_shape[0]}, {input_shape[1]})." + ) + if input_shape[0] not in {1, 3}: warnings.warn( "This model usually expects 1 or 3 input channels. " "However, it was passed an input_shape " diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 7dc5a98fb8d5..bcd41780bd45 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -56,6 +56,17 @@ def _convert_to_tensor(self, value, dtype=None): # Overload native accessor. def __jax_array__(self): + # Handle case where Variable is copied during JAX tracing + # and both _value and _initializer become None + if self._value is None and self._initializer is None: + # This can happen when JAX copies Variables during tracing. + # In this case, we need to use the actual shape to create a + # placeholder tensor for shape inference. + import jax.numpy as jnp + + from keras.src.backend.common import standardize_dtype + + return jnp.zeros(self._shape, dtype=standardize_dtype(self._dtype)) return self.value @@ -513,8 +524,22 @@ def random_seed_dtype(): return "uint32" +def _convert_variable_to_value(arg): + """Convert Variable objects to their underlying values.""" + if isinstance(arg, Variable): + return arg.value + return arg + + def custom_gradient(fun): - return jax.custom_gradient(fun=fun) + @jax.custom_gradient + def wrapper(*args, **kwargs): + # Convert Variable objects to their values + args = tree.map_structure(_convert_variable_to_value, args) + kwargs = tree.map_structure(_convert_variable_to_value, kwargs) + return fun(*args, **kwargs) + + return wrapper def remat(f): diff --git a/keras/src/losses/__init__.py b/keras/src/losses/__init__.py index 7afeb55a01d1..059f451e4461 100644 --- a/keras/src/losses/__init__.py +++ b/keras/src/losses/__init__.py @@ -45,6 +45,8 @@ from keras.src.losses.losses import sparse_categorical_crossentropy from keras.src.losses.losses import squared_hinge from keras.src.losses.losses import tversky +from keras.src.losses.lpips import LPIPS +from keras.src.losses.lpips import lpips from keras.src.saving import serialization_lib ALL_OBJECTS = { @@ -76,6 +78,8 @@ Tversky, # Similarity Circle, + # Feature extraction perceptual + LPIPS, # Sequence CTC, # Probabilistic @@ -94,6 +98,8 @@ cosine_similarity, log_cosh, huber, + # Feature extraction perceptual + lpips, # Hinge hinge, squared_hinge, diff --git a/keras/src/losses/lpips.py b/keras/src/losses/lpips.py new file mode 100644 index 000000000000..e094eba11d0d --- /dev/null +++ b/keras/src/losses/lpips.py @@ -0,0 +1,209 @@ +# Copyright 2024 The Keras Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.losses.loss import Loss +from keras.src.losses.losses import LossFunctionWrapper + + +def _build_vgg16_feature_extractor(layer_names=None, weights=None): + # Lazy import to avoid heavy dependencies during package import + from keras.src import models + from keras.src.applications.vgg16 import VGG16 + + if layer_names is None: + # Standard LPIPS uses conv2 from blocks 1,2 and conv3 from blocks 3,4,5 + layer_names = [ + "block1_conv2", + "block2_conv2", + "block3_conv3", + "block4_conv3", + "block5_conv3", + ] + base = VGG16(include_top=False, weights=weights) + outputs = [base.get_layer(name).output for name in layer_names] + # Create a model that returns a list of intermediate activations + feat_model = models.Model( + inputs=base.input, outputs=outputs, name="vgg16_lpips" + ) + feat_model.trainable = False + return feat_model + + +def _normalize_channels(x, epsilon=1e-6): + # Per-channel L2 normalization across spatial dimensions H, W + # x: (B, H, W, C) + # Compute norm over H,W for each channel + hw_axes = tuple(range(1, x.ndim - 1)) + norm = ops.sqrt( + ops.sum(ops.square(x), axis=hw_axes, keepdims=True) + epsilon + ) + return x / norm + + +@keras_export("keras.losses.lpips") +def lpips( + y_true, + y_pred, + feature_model=None, + layer_weights=None, + normalize_input=True, +): + """Computes a perceptual distance between images using feature activations. + + This is an approximation of LPIPS using a fixed feature extractor + (default: VGG16 conv blocks). It avoids network access by default by not + loading any pretrained weights unless a `feature_model` with weights is + provided by the user. + + Args: + y_true: Tensor of reference images, shape (batch, H, W, 3), values in + [0, 1] or [-1, 1]. + y_pred: Tensor of compared images, same shape and dtype as `y_true`. + feature_model: Optional Keras model that maps an image tensor to a + list/tuple of feature maps. If None (or the string "vgg16"), a + VGG16-based extractor is constructed internally with + `weights=None`. + layer_weights: Optional list of scalars for each feature map. If None, + equal weights are used. + normalize_input: If True, rescale inputs that appear to lie in [0, 1] + (no negative values) to [-1, 1]. Inputs that contain negative + values are assumed to already be in [-1, 1] and are left + unchanged. + + Returns: + A 1D tensor with one scalar perceptual distance per sample.
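+ + Example: + + Illustrative usage; with the default extractor (untrained VGG16, + `weights=None`), distances are only meaningful relative to each + other: + + >>> import numpy as np + >>> x = np.random.rand(2, 64, 64, 3).astype("float32") + >>> y = np.random.rand(2, 64, 64, 3).astype("float32") + >>> d = keras.losses.lpips(x, y) # shape (2,)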
+ """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + + # Ensure channels-last images + if y_pred.ndim != 4 or y_pred.shape[-1] != 3: + raise ValueError( + "lpips expects inputs of shape (batch, H, W, 3) with " + "channels-last layout." + ) + + # Normalize to [-1, 1] if requested and inputs appear to be in [0, 1] + if normalize_input: + # Heuristic: treat inputs as [0, 1] only if they are non-negative + # with max <= 1.5; inputs containing negative values are assumed + # to already be in [-1, 1] and are left unchanged. + max_val = ops.max(ops.maximum(y_true, y_pred)) + min_val = ops.min(ops.minimum(y_true, y_pred)) + cond = ops.logical_and( + ops.greater_equal( + min_val, ops.convert_to_tensor(0.0, y_pred.dtype) + ), + ops.less_equal( + max_val, ops.convert_to_tensor(1.5, y_pred.dtype) + ), + ) + + def _scale_to_m1_1(x): + return x * 2.0 - 1.0 + + y_true = ops.cond(cond, lambda: _scale_to_m1_1(y_true), lambda: y_true) + y_pred = ops.cond(cond, lambda: _scale_to_m1_1(y_pred), lambda: y_pred) + + # Build the default feature extractor if not provided. The string + # "vgg16" is accepted so that configs emitted by `LPIPS.get_config()` + # round-trip through deserialization. + if feature_model is None or feature_model == "vgg16": + feature_model = _build_vgg16_feature_extractor(weights=None) + + # Resize inputs to the model input size if necessary + target_h = feature_model.input_shape[1] + target_w = feature_model.input_shape[2] + if (target_h is not None and target_w is not None) and ( + y_true.shape[1] != target_h or y_true.shape[2] != target_w + ): + from keras.src import layers + + resize = layers.Resizing( + int(target_h), int(target_w), interpolation="bilinear" + ) + y_true = resize(y_true) + y_pred = resize(y_pred) + + # Forward pass to get feature lists + feats_true = feature_model(y_true) + feats_pred = feature_model(y_pred) + + # Ensure iterable + if not isinstance(feats_true, (list, tuple)): + feats_true = (feats_true,) + feats_pred = (feats_pred,) + + if layer_weights is None: + layer_weights = [1.0] * len(feats_true) + elif len(layer_weights) != len(feats_true): + raise ValueError( + "layer_weights length must match the number of feature maps" + ) + + # Compute per-layer distances and sum + distances = [] + for w, f_t, f_p in zip(layer_weights, feats_true, feats_pred): + f_t = ops.convert_to_tensor(f_t, dtype=y_pred.dtype) + f_p = ops.convert_to_tensor(f_p, dtype=y_pred.dtype) + # Per-channel normalization (see _normalize_channels) + f_t = _normalize_channels(f_t) + f_p = _normalize_channels(f_p) + diff = ops.square(f_t - f_p) + # Average across spatial and channel dims -> per-sample scalar + axes = tuple(range(1, diff.ndim)) + d = ops.mean(diff, axis=axes) + distances.append(w * d) + + total = distances[0] + for d in distances[1:]: + total = total + d + return total + + +@keras_export("keras.losses.LPIPS") +class LPIPS(LossFunctionWrapper): + """Perceptual distance loss using deep feature activations. + + This provides a backend-agnostic approximation of the LPIPS loss. + By default it uses a VGG16-based feature extractor with random weights + (no downloads) to keep tests and offline usage lightweight. For more + accurate behavior, pass in a pretrained `feature_model` and optional + `layer_weights`. + + Args: + feature_model: Optional Keras model mapping an image to a list of + feature maps. If None, a VGG16-based extractor is constructed with + `weights=None`. + layer_weights: Optional list of scalars to weight each feature map. + normalize_input: Whether to map inputs from [0,1] to [-1,1]. + reduction: Loss reduction. Defaults to "sum_over_batch_size". + name: Optional name for this loss. + dtype: Dtype for computations.
+ """ + + def __init__( + self, + feature_model=None, + layer_weights=None, + normalize_input=True, + reduction="sum_over_batch_size", + name="lpips", + dtype=None, + ): + super().__init__( + lpips, + name=name, + reduction=reduction, + dtype=dtype, + feature_model=feature_model, + layer_weights=layer_weights, + normalize_input=normalize_input, + ) + self._has_custom_model = feature_model is not None + self.layer_weights = layer_weights + self.normalize_input = normalize_input + + def get_config(self): + # We cannot reliably serialize a custom feature_model; only config + # for behavior flags is returned. + config = Loss.get_config(self) + config.update( + { + "feature_model": None if self._has_custom_model else "vgg16", + "layer_weights": self.layer_weights, + "normalize_input": self.normalize_input, + } + ) + return config \ No newline at end of file diff --git a/keras/src/losses/lpips_test.py b/keras/src/losses/lpips_test.py new file mode 100644 index 000000000000..147db6f6e4fb --- /dev/null +++ b/keras/src/losses/lpips_test.py @@ -0,0 +1,44 @@ +import numpy as np + +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.losses.lpips import LPIPS, lpips + + +def _tiny_feature_model(): + inp = layers.Input(shape=(None, None, 3)) + x = layers.Conv2D(8, 3, padding="same", activation="relu")(inp) + y = layers.Conv2D(16, 3, padding="same", activation="relu")(x) + return models.Model(inp, [x, y]) + + +class LPIPSTest(testing.TestCase): + def test_identical_images_zero(self): + fm = _tiny_feature_model() + loss = LPIPS(feature_model=fm, reduction=None) + x = np.random.RandomState(0).rand(2, 32, 32, 3).astype("float32") + y = x.copy() + out = loss(x, y) + # Exactly zero can be achieved with identical inputs + self.assertAllClose(out, np.zeros((2,), dtype=np.float32), atol=1e-6) + + def test_basic_increase_with_noise(self): + fm = _tiny_feature_model() + x = np.zeros((2, 16, 16, 3), dtype="float32") + y = np.zeros((2, 16, 16, 3), dtype="float32") + # Add small noise to y + y[0] += 0.1 + # Functional API + d = lpips(x, y, feature_model=fm) + self.assertTrue(d.shape == (2,)) + self.assertGreater(d[0], d[1]) + + def test_reduction(self): + fm = _tiny_feature_model() + loss = LPIPS(feature_model=fm, reduction="sum") + x = np.random.RandomState(1).rand(4, 8, 8, 3).astype("float32") + y = np.random.RandomState(2).rand(4, 8, 8, 3).astype("float32") + out = loss(x, y) + # Scalar reduction + self.assertEqual(out.shape, ()) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index ff49a4d34e05..a73fd41fb59e 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -635,6 +635,70 @@ def log1pexp_nan(x): z.sum().backward() self.assertEqual(ops.convert_to_numpy(x.grad), 1.0) + @pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is specific to JAX backend Variable handling.", + ) + def test_custom_gradient_with_variable(self): + """Test that custom_gradient works with Variables in JAX backend. + + This addresses issue #21105 where passing Variables to custom_gradient + functions would fail because JAX would capture the Variable object + instead of its value. 
+ """ + + @ops.custom_gradient + def roundpass(x, log_scaling): + """Custom gradient function that uses a Variable.""" + scaling = ops.exp(log_scaling) + rounded = ops.round(x * scaling) / scaling + + def grad(*args, upstream=None): + if upstream is None: + (upstream,) = args + # Straight-through estimator: gradient passes through + return upstream, ops.zeros_like(log_scaling) + + return rounded, grad + + # Create a simple model with a Variable + class QuantizedLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.log_scaling = self.add_weight( + name="log_scaling", + shape=(), + initializer="zeros", + trainable=True, + ) + + def call(self, x): + # This should work without needing to manually add .value + return roundpass(x, self.log_scaling) + + # Build a simple model + inputs = input_layer.Input(shape=(4,)) + x = QuantizedLayer()(inputs) + outputs = layers.Dense(2)(x) + model = models.Model(inputs, outputs) + + # Compile the model + model.compile( + optimizer=optimizers.Adam(), + loss=losses.MeanSquaredError(), + ) + + # Create dummy data + x_train = np.random.randn(32, 4).astype("float32") + y_train = np.random.randn(32, 2).astype("float32") + + # Train for one step - this should not raise TypeError + history = model.fit( + x_train, y_train, epochs=1, batch_size=32, verbose=0 + ) + + self.assertIsNotNone(history) + def test_dynamic_slice(self): def cond(index, inputs, sum): return index < 10 diff --git a/test_custom_gradient_jax_variable.py b/test_custom_gradient_jax_variable.py new file mode 100644 index 000000000000..5efbd5ba43c5 --- /dev/null +++ b/test_custom_gradient_jax_variable.py @@ -0,0 +1,127 @@ +"""Test custom_gradient with JAX backend when Variables are passed.""" +import os + +os.environ["KERAS_BACKEND"] = "jax" + +import numpy as np + +import keras +from keras import layers +from keras import ops + + +def test_custom_gradient_with_variable(): + """Test that custom_gradient works with Variables in JAX backend.""" + + @ops.custom_gradient + def roundpass(x, log_scaling): + """Custom gradient function that uses a Variable.""" + scaling = ops.exp(log_scaling) + rounded = ops.round(x * scaling) / scaling + + def grad(*args, upstream=None): + if upstream is None: + (upstream,) = args + # Straight-through estimator: gradient passes through + return upstream, ops.zeros_like(log_scaling) + + return rounded, grad + + # Create a simple layer that uses custom_gradient with a Variable + class QuantizedLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.log_scaling = self.add_weight( + name="log_scaling", + shape=(), + initializer="zeros", + trainable=True, + ) + + def call(self, x): + # This should work without needing to manually add .value + return roundpass(x, self.log_scaling) + + # Build a simple model + inputs = layers.Input(shape=(4,)) + x = QuantizedLayer()(inputs) + outputs = layers.Dense(2)(x) + model = keras.Model(inputs, outputs) + + # Compile the model + model.compile( + optimizer="adam", + loss="mse", + ) + + # Create dummy data + x_train = np.random.randn(32, 4).astype("float32") + y_train = np.random.randn(32, 2).astype("float32") + + # Train for one step - this should not raise TypeError + history = model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0) + + assert history is not None + print( + "✓ Test passed: custom_gradient works with " + "Variables in JAX backend" + ) + + +def test_custom_gradient_with_variable_value_property(): + """Test that custom_gradient also works when .value is 
explicitly used.""" + + @ops.custom_gradient + def roundpass(x, log_scaling): + """Custom gradient function that uses a Variable value.""" + scaling = ops.exp(log_scaling) + rounded = ops.round(x * scaling) / scaling + + def grad(*args, upstream=None): + if upstream is None: + (upstream,) = args + return upstream, ops.zeros_like(log_scaling) + + return rounded, grad + + class QuantizedLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.log_scaling = self.add_weight( + name="log_scaling", + shape=(), + initializer="zeros", + trainable=True, + ) + + def call(self, x): + # Explicitly use .value (workaround mentioned in the issue) + return roundpass(x, self.log_scaling.value) + + # Build a simple model + inputs = layers.Input(shape=(4,)) + x = QuantizedLayer()(inputs) + outputs = layers.Dense(2)(x) + model = keras.Model(inputs, outputs) + + model.compile(optimizer="adam", loss="mse") + + x_train = np.random.randn(32, 4).astype("float32") + y_train = np.random.randn(32, 2).astype("float32") + + history = model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0) + + assert history is not None + print( + "✓ Test passed: custom_gradient works with " + "Variable.value in JAX backend" + ) + + +if __name__ == "__main__": + print("Testing custom_gradient with JAX backend and Variables...") + print() + + test_custom_gradient_with_variable() + test_custom_gradient_with_variable_value_property() + + print() + print("All tests passed! ✓") diff --git a/tests/test_remat_kwargs.py b/tests/test_remat_kwargs.py new file mode 100644 index 000000000000..7261dd56e6ca --- /dev/null +++ b/tests/test_remat_kwargs.py @@ -0,0 +1,36 @@ +import numpy as np +import tensorflow as tf +import keras +from keras import layers +from keras import RematScope + +# Make debugging easier in this focused test +try: + keras.config.disable_traceback_filtering() +except Exception: + pass + + +def test_remat_allows_kwargs(): + # Run eagerly to avoid TF's custom_gradient kwargs limitation in graph + # mode + tf.config.run_functions_eagerly(True) + + # Simple toy dataset + x = np.random.randn(16, 4).astype("float32") + y = np.random.randn(16, 1).astype("float32") + + # Build a tiny model under RematScope; Keras will pass `training` kwarg + with RematScope(mode="full"): + inputs = keras.Input(shape=(4,)) + x1 = layers.Dense(8, activation="relu")(inputs) + outputs = layers.Dense(1)(x1) + model = keras.Model(inputs, outputs) + + model.compile(optimizer="adam", loss="mse", run_eagerly=True) + + # If remat incorrectly forwarded Keras kwargs (such as `training`) to + # TF's custom_gradient, this fit call would raise a ValueError. With + # the fix, it should run. + history = model.fit(x, y, batch_size=4, epochs=1, verbose=0) + + # Basic sanity assertion + assert "loss" in history.history and len(history.history["loss"]) == 1
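+ + +if __name__ == "__main__": + # Optional direct-run entry point, mirroring the standalone scripts in + # this PR; under pytest this block is not executed. + test_remat_allows_kwargs() + print("test_remat_allows_kwargs passed")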