# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.utils.tensor_utils import preprocessing_function

try:
    import tensorflow as tf
except ImportError:
    tf = None


@keras_hub_export("keras_hub.models.CLIPPreprocessor")
class CLIPPreprocessor(Preprocessor):
    """CLIP preprocessing layer which tokenizes and packs inputs.

    This preprocessing layer will do two things:

    - Tokenize the inputs using the `tokenizer`.
    - Construct a dictionary with keys `"token_ids"` and `"padding_mask"`.

    This layer can be used directly with `tf.data.Dataset.map` to preprocess
    string data in the `(x, y, sample_weight)` format used by
    `keras.Model.fit`.

    The call method of this layer accepts three arguments, `x`, `y`, and
    `sample_weight`. `x` can be a python string or tensor representing a single
    segment, a list of python strings representing a batch of single segments,
    or a list of tensors representing multiple segments to be packed together.
    `y` and `sample_weight` are both optional, can have any format, and will be
    passed through unaltered.

    `CLIPPreprocessor` forces the input to have only one segment, as CLIP's
    text encoder consumes a single text segment per example. For tasks with
    multi-segment inputs like "glue/mnli", please use a model designed for
    classification purposes such as BERT or RoBERTa.

    Args:
        tokenizer: A `keras_hub.models.CLIPTokenizer` instance.
        sequence_length: The length of the packed inputs.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence.
        to_lower: bool. Whether to lowercase the inputs.

    Call arguments:
        x: A string, `tf.Tensor` or list of python strings.
        y: Any label data. Will be passed through unaltered.
        sample_weight: Any label weight data. Will be passed through unaltered.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.
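
    Examples:

    The snippet below is a usage sketch; `"clip_preset"` stands in for a
    real preset name, which this file does not define.
    ```python
    # Construct a preprocessor from a (placeholder) preset.
    preprocessor = keras_hub.models.CLIPPreprocessor.from_preset("clip_preset")

    # Tokenize and pack a single string.
    preprocessor("a photo of a cat")

    # Tokenize and pack a batch of strings.
    preprocessor(["a photo of a cat", "a photo of a dog"])

    # Map over a `tf.data.Dataset` of unlabeled strings.
    features = ["a photo of a cat", "a photo of a dog"]
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```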
    """

    # TODO: Add example once we have a CLIP model.

    tokenizer_cls = CLIPTokenizer

    def __init__(
        self,
        tokenizer,
        sequence_length=77,
        add_start_token=True,
        add_end_token=True,
        to_lower=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.packer = None
        self.sequence_length = sequence_length
        self.add_start_token = add_start_token
        self.add_end_token = add_end_token
        self.to_lower = to_lower

    def build(self, input_shape):
        # Defer packer creation to `build()` so that we can be sure tokenizer
        # assets have loaded when restoring a saved model.
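        # A rough sketch of the packing behavior with made-up token ids:
        # with start=1, end=2, and `sequence_length=5`, the tokenized input
        # [5, 6] becomes token_ids [1, 5, 6, 2, 2] with padding_mask
        # [1, 1, 1, 1, 0]. Note that padding reuses the end token id here.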
        self.packer = StartEndPacker(
            start_value=self.tokenizer.start_token_id,
            end_value=self.tokenizer.end_token_id,
            pad_value=self.tokenizer.end_token_id,
            sequence_length=self.sequence_length,
            return_padding_mask=True,
        )
        self.built = True

    @preprocessing_function
    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        sequence_length = sequence_length or self.sequence_length
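        # CLIP's byte-pair vocabulary is lowercase, so lowercasing inputs
        # (the default) keeps them in-vocabulary.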
        if self.to_lower:
            x = tf.strings.lower(x)
        token_ids, padding_mask = self.packer(
            self.tokenizer(x),
            sequence_length=sequence_length,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
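        # Both outputs are dense tensors of shape
        # `(batch_size, sequence_length)`; the padding mask marks which
        # positions hold real tokens versus padding.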
        x = {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "add_start_token": self.add_start_token,
                "add_end_token": self.add_end_token,
                "to_lower": self.to_lower,
            }
        )
        return config

    @property
    def sequence_length(self):
        """The padded length of model input sequences."""
        return self._sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self._sequence_length = value
        if self.packer is not None:
            self.packer.sequence_length = value