Skip to content

Commit

Permalink
[To keep] -- [build] tf upgrade by keeping keras v2 (#1542)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Oct 7, 2024
1 parent da88880 commit 90c3fff
Show file tree
Hide file tree
Showing 41 changed files with 147 additions and 82 deletions.
2 changes: 1 addition & 1 deletion docs/source/using_doctr/using_model_export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Advantages:
.. code:: python3
import tensorflow as tf
from keras import mixed_precision
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True)
Expand Down
19 changes: 19 additions & 0 deletions doctr/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@
logging.info("Disabling PyTorch because USE_TF is set")
_torch_available = False

# Compatibility fix to make sure tensorflow.keras stays at Keras 2
if "TF_USE_LEGACY_KERAS" not in os.environ:
os.environ["TF_USE_LEGACY_KERAS"] = "1"

elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
raise ValueError(
"docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
)


def ensure_keras_v2() -> None: # pragma: no cover
if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
os.environ["TF_USE_LEGACY_KERAS"] = "1"


if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
Expand Down Expand Up @@ -65,6 +79,11 @@
_tf_available = False
else:
logging.info(f"TensorFlow version {_tf_version} available.")
ensure_keras_v2()
import tensorflow as tf

# Enable eager execution - this is required for some models to work properly
tf.config.run_functions_eagerly(True)
else: # pragma: no cover
logging.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
Expand Down
2 changes: 1 addition & 1 deletion doctr/io/image/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import numpy as np
import tensorflow as tf
from keras.utils import img_to_array
from PIL import Image
from tensorflow.keras.utils import img_to_array

from doctr.utils.common_types import AbstractPath

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/classification/magc_resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from keras import activations, layers
from keras.models import Sequential
from tensorflow.keras import activations, layers
from tensorflow.keras.models import Sequential

from doctr.datasets import VOCABS

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import tensorflow as tf
from keras import layers
from keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

from ....datasets import VOCABS
from ...utils import conv_sequence, load_pretrained_params
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import tensorflow as tf
from keras import Model
from tensorflow.keras import Model

from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject
Expand Down
6 changes: 3 additions & 3 deletions doctr/models/classification/resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import tensorflow as tf
from keras import layers
from keras.applications import ResNet50
from keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential

from doctr.datasets import VOCABS

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/classification/textnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from keras import Sequential, layers
from tensorflow.keras import Sequential, layers

from doctr.datasets import VOCABS

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/classification/vgg/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from keras import layers
from keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

from doctr.datasets import VOCABS

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/classification/vit/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, Optional, Tuple

import tensorflow as tf
from keras import Sequential, layers
from tensorflow.keras import Sequential, layers

from doctr.datasets import VOCABS
from doctr.models.modules.transformer import EncoderBlock
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import numpy as np
import tensorflow as tf
from keras import Model, Sequential, layers, losses
from keras.applications import ResNet50
from tensorflow.keras import Model, Sequential, layers, losses
from tensorflow.keras.applications import ResNet50

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import tensorflow as tf
from keras import Model, Sequential, layers
from tensorflow.keras import Model, Sequential, layers

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import tensorflow as tf
from keras import Model, Sequential, layers, losses
from tensorflow.keras import Model, Sequential, layers, losses

from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import tensorflow as tf
from keras import Model
from tensorflow.keras import Model

from doctr.models.detection._utils import _remove_padding
from doctr.models.preprocessor import PreProcessor
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/modules/layers/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import tensorflow as tf
from keras import layers
from tensorflow.keras import layers

from doctr.utils.repr import NestedObject

Expand Down
4 changes: 1 addition & 3 deletions doctr/models/modules/transformer/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
from typing import Any, Callable, Optional, Tuple

import tensorflow as tf
from keras import layers
from tensorflow.keras import layers

from doctr.utils.repr import NestedObject

__all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"]

tf.config.run_functions_eagerly(True)


class PositionalEncoding(layers.Layer, NestedObject):
"""Compute positional encoding"""
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/modules/vision_transformer/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Tuple

import tensorflow as tf
from keras import layers
from tensorflow.keras import layers

from doctr.utils.repr import NestedObject

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/crnn/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import tensorflow as tf
from keras import layers
from keras.models import Model, Sequential
from tensorflow.keras import layers
from tensorflow.keras.models import Model, Sequential

from doctr.datasets import VOCABS

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/master/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from keras import Model, layers
from tensorflow.keras import Model, layers

from doctr.datasets import VOCABS
from doctr.models.classification import magc_resnet31
Expand Down
5 changes: 1 addition & 4 deletions doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import tensorflow as tf
from keras import Model, layers
from tensorflow.keras import Model, layers

from doctr.datasets import VOCABS
from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
Expand Down Expand Up @@ -167,7 +167,6 @@ def __init__(

self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)

@tf.function
def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
# Generates permutations of the target sequence.
# Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
Expand Down Expand Up @@ -214,7 +213,6 @@ def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
)
return combined

@tf.function
def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
# Generate source and target mask for the decoder attention.
sz = permutation.shape[0]
Expand All @@ -234,7 +232,6 @@ def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple
target_mask = mask[1:, :-1]
return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)

@tf.function
def decode(
self,
target: tf.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/sar/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from keras import Model, Sequential, layers
from tensorflow.keras import Model, Sequential, layers

from doctr.datasets import VOCABS
from doctr.utils.repr import NestedObject
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/vitstr/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from keras import Model, layers
from tensorflow.keras import Model, layers

from doctr.datasets import VOCABS

Expand Down
6 changes: 3 additions & 3 deletions doctr/models/utils/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import tensorflow as tf
import tf2onnx
from keras import Model, layers
from tensorflow.keras import Model, layers

from doctr.utils.data import download_from_url

Expand Down Expand Up @@ -77,7 +77,7 @@ def conv_sequence(
) -> List[layers.Layer]:
"""Builds a convolutional-based layer sequence
>>> from keras import Sequential
>>> from tensorflow.keras import Sequential
>>> from doctr.models import conv_sequence
>>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
Expand Down Expand Up @@ -113,7 +113,7 @@ def conv_sequence(
class IntermediateLayerGetter(Model):
"""Implements an intermediate layer getter
>>> from keras.applications import ResNet50
>>> from tensorflow.keras.applications import ResNet50
>>> from doctr.models import IntermediateLayerGetter
>>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"]
>>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)
Expand Down
1 change: 0 additions & 1 deletion doctr/transforms/modules/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,6 @@ def __init__(self, kernel_shape: Union[int, Iterable[int]], std: Tuple[float, fl
def extra_repr(self) -> str:
return f"kernel_shape={self.kernel_shape}, std={self.std}"

@tf.function
def __call__(self, img: tf.Tensor) -> tf.Tensor:
return tf.squeeze(
_gaussian_filter(
Expand Down
10 changes: 4 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ dependencies = [

[project.optional-dependencies]
tf = [
# cf. https://github.com/mindee/doctr/pull/1182
# cf. https://github.com/mindee/doctr/pull/1461
"tensorflow>=2.11.0,<2.16.0",
"tensorflow>=2.15.0,<3.0.0",
"tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility
"tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0
# TODO: This is a temporary fix until we can upgrade to a newer version of tensorflow
"numpy>=1.16.0,<2.0.0",
]
torch = [
"torch>=1.12.0,<3.0.0",
Expand Down Expand Up @@ -98,9 +96,9 @@ docs = [
]
dev = [
# Tensorflow
# cf. https://github.com/mindee/doctr/pull/1182
# cf. https://github.com/mindee/doctr/pull/1461
"tensorflow>=2.11.0,<2.16.0",
"tensorflow>=2.15.0,<3.0.0",
"tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility
"tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0
# PyTorch
"torch>=1.12.0,<3.0.0",
Expand Down
4 changes: 2 additions & 2 deletions references/classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ Feel free to inspect the multiple script option to customize your training to yo

Character classification:

```python
```shell
python references/classification/train_tensorflow_character.py --help
```

Orientation classification:

```python
```shell
python references/classification/train_tensorflow_orientation.py --help
```
8 changes: 6 additions & 2 deletions references/classification/latency_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
import os
import time

import numpy as np
import tensorflow as tf
from doctr.file_utils import ensure_keras_v2

ensure_keras_v2()

os.environ["USE_TF"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import tensorflow as tf

from doctr.models import classification


Expand Down
Loading

0 comments on commit 90c3fff

Please sign in to comment.