Skip to content

Commit

Permalink
fix #2134 (#2248)
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli authored Dec 15, 2023
1 parent e2b627e commit b980f68
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions keras_cv/models/object_detection/predict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
import tensorflow as tf

try:
# To-do: these imports need to fixed - Issue 2134
# https://github.com/keras-team/keras-cv/issues/2134
# from keras.src.engine.training import _minimum_control_deps
# from keras.src.engine.training import reduce_per_replica
from keras.src.utils import tf_utils
except ImportError:
# from keras.engine.training import _minimum_control_deps
# from keras.engine.training import reduce_per_replica
from keras.utils import tf_utils


def _minimum_control_deps(outputs):
"""Returns the minimum control dependencies to ensure step succeeded."""
if tf.executing_eagerly():
return [] # Control dependencies not needed.
outputs = tf.nest.flatten(outputs, expand_composites=True)
for out in outputs:
# Variables can't be control dependencies.
if not isinstance(out, tf.Variable):
return [out] # Return first Tensor or Op from outputs.
return [] # No viable Tensor or Op to use for control deps.


def make_predict_function(model, force=False):
if model.predict_function is not None and not force:
return model.predict_function
Expand All @@ -36,8 +42,8 @@ def step_function(iterator):
def run_step(data):
outputs = model.predict_step(data)
# Ensure counter is updated only if `test_step` succeeds.
# with tf.control_dependencies(_minimum_control_deps(outputs)):
model._predict_counter.assign_add(1)
with tf.control_dependencies(_minimum_control_deps(outputs)):
model._predict_counter.assign_add(1)
return outputs

if model._jit_compile:
Expand All @@ -47,9 +53,7 @@ def run_step(data):

data = next(iterator)
outputs = model.distribute_strategy.run(run_step, args=(data,))
# outputs = reduce_per_replica(
# outputs, model.distribute_strategy, reduction="concat"
# )
outputs = model.distribute_strategy.gather(outputs, axis=0)
# Note that this is the only deviation from the base keras.Model
# implementation. We add the decode_step inside of the computation
# graph but outside of the distribute_strategy (i.e on host CPU).
Expand Down

0 comments on commit b980f68

Please sign in to comment.