Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Sep 2, 2024
1 parent cb51591 commit 7b714d8
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions doctr/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@
logging.info("Disabling PyTorch because USE_TF is set")
_torch_available = False

# Compatibility fix to make sure tf.keras stays at Keras 2
if "TF_USE_LEGACY_KERAS" not in os.environ:
os.environ["TF_USE_LEGACY_KERAS"] = "1"

elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
raise ValueError(

Check warning on line 43 in doctr/file_utils.py

View check run for this annotation

Codecov / codecov/patch

doctr/file_utils.py#L42-L43

Added lines #L42 - L43 were not covered by tests
"docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
)


def ensure_keras_v2() -> None: # pragma: no cover
if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
os.environ["TF_USE_LEGACY_KERAS"] = "1"


if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
Expand Down Expand Up @@ -65,6 +79,7 @@
_tf_available = False
else:
logging.info(f"TensorFlow version {_tf_version} available.")
ensure_keras_v2()
import tensorflow as tf

# Enable eager execution - this is required for some models to work properly
Expand All @@ -80,20 +95,6 @@
" is installed and that either USE_TF or USE_TORCH is enabled."
)

# Compatibility fix to make sure tf.keras stays at Keras 2
if "TF_USE_LEGACY_KERAS" not in os.environ:
os.environ["TF_USE_LEGACY_KERAS"] = "1"

elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
raise ValueError(
"docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
)


def ensure_keras_v2() -> None: # pragma: no cover
if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
os.environ["TF_USE_LEGACY_KERAS"] = "1"


def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover
"""
Expand Down

0 comments on commit 7b714d8

Please sign in to comment.