Skip to content

Commit 270991a

Browse files
committed
Address comments
1 parent 9faec7f commit 270991a

File tree

3 files changed

+22
-57
lines changed

3 files changed

+22
-57
lines changed

keras_nlp/models/roberta/roberta_masked_lm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""RoBERTa classification model."""
14+
"""RoBERTa masked lm model."""
1515

1616
import copy
1717

@@ -63,13 +63,13 @@ class RobertaMaskedLM(Task):
6363
6464
# Create a RobertaMaskedLM with a pretrained backbone and further train
6565
# on an MLM task.
66-
classifier = keras_nlp.models.RobertaMaskedLM.from_preset(
66+
masked_lm = keras_nlp.models.RobertaMaskedLM.from_preset(
6767
"roberta_base_en",
6868
)
69-
classifier.compile(
69+
masked_lm.compile(
7070
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
7171
)
72-
classifier.fit(x=features, batch_size=2)
72+
masked_lm.fit(x=features, batch_size=2)
7373
```
7474
7575
Preprocessed inputs and custom backbone.

keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,16 @@
1414

1515
"""RoBERTa masked language model preprocessor layer."""
1616

17-
import copy
18-
17+
from absl import logging
1918
from tensorflow import keras
2019

2120
from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
22-
from keras_nlp.models.preprocessor import Preprocessor
23-
from keras_nlp.models.roberta.roberta_multi_segment_packer import (
24-
RobertaMultiSegmentPacker,
25-
)
26-
from keras_nlp.models.roberta.roberta_presets import backbone_presets
27-
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
28-
from keras_nlp.utils.keras_utils import (
29-
convert_inputs_to_list_of_tensor_segments,
30-
)
21+
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
3122
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
32-
from keras_nlp.utils.python_utils import classproperty
3323

3424

3525
@keras.utils.register_keras_serializable(package="keras_nlp")
36-
class RobertaMaskedLMPreprocessor(Preprocessor):
26+
class RobertaMaskedLMPreprocessor(RobertaPreprocessor):
3727
"""RoBERTa preprocessing for the masked language modeling task.
3828
3929
This preprocessing layer will prepare inputs for a masked language modeling
@@ -114,21 +104,18 @@ def __init__(
114104
self,
115105
tokenizer,
116106
sequence_length=512,
107+
truncate="round_robin",
117108
mask_selection_rate=0.15,
118109
mask_selection_length=96,
119-
truncate="round_robin",
120110
**kwargs,
121111
):
122-
super().__init__(**kwargs)
123-
124-
self._tokenizer = tokenizer
125-
self.packer = RobertaMultiSegmentPacker(
126-
start_value=tokenizer.start_token_id,
127-
end_value=tokenizer.end_token_id,
128-
pad_value=tokenizer.pad_token_id,
129-
truncate=truncate,
112+
super().__init__(
113+
tokenizer,
130114
sequence_length=sequence_length,
115+
truncate=truncate,
116+
**kwargs,
131117
)
118+
132119
self.masker = MaskedLMMaskGenerator(
133120
mask_selection_rate=mask_selection_rate,
134121
mask_selection_length=mask_selection_length,
@@ -145,40 +132,29 @@ def get_config(self):
145132
config = super().get_config()
146133
config.update(
147134
{
148-
"sequence_length": self.packer.sequence_length,
149135
"mask_selection_rate": self.masker.mask_selection_rate,
150136
"mask_selection_length": self.masker.mask_selection_length,
151-
"truncate": self.packer.truncate,
152137
}
153138
)
154139
return config
155140

156141
def call(self, x, y=None, sample_weight=None):
157-
if y is not None:
158-
raise ValueError(
159-
"`RobertaMaskedLMPreprocessor` received labeled data (`y` is "
160-
"not `None`). No labels should be passed in as "
161-
"this layer generates training labels dynamically from raw "
162-
"text features passed as `x`. Received: `y={y}`."
142+
if y is not None or sample_weight is not None:
143+
logging.warning(
144+
f"{self.__class__.__name__} generates `y` and `sample_weight` "
145+
"based on your input data, but your data already contains `y` "
146+
"or `sample_weight`. Your `y` and `sample_weight` will be "
147+
"ignored."
163148
)
164149

165-
x = convert_inputs_to_list_of_tensor_segments(x)
166-
x = [self.tokenizer(segment) for segment in x]
167-
token_ids = self.packer(x)
150+
x = super().call(x)
151+
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
168152
masker_outputs = self.masker(token_ids)
169153
x = {
170154
"token_ids": masker_outputs["token_ids"],
171-
"padding_mask": token_ids != self.tokenizer.pad_token_id,
155+
"padding_mask": padding_mask,
172156
"mask_positions": masker_outputs["mask_positions"],
173157
}
174158
y = masker_outputs["mask_ids"]
175159
sample_weight = masker_outputs["mask_weights"]
176160
return pack_x_y_sample_weight(x, y, sample_weight)
177-
178-
@classproperty
179-
def tokenizer_cls(cls):
180-
return RobertaTokenizer
181-
182-
@classproperty
183-
def presets(cls):
184-
return copy.deepcopy(backbone_presets)

keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,6 @@ def test_tokenize_list_of_strings(self):
110110
)
111111
self.assertAllEqual(sw, [[1.0, 1.0, 0.0, 0.0]] * 4)
112112

113-
def test_tokenize_labeled_errors(self):
114-
x = tf.constant([" airplane at airport"] * 4)
115-
y = tf.constant([1] * 4)
116-
with self.assertRaises(ValueError):
117-
self.preprocessor(x, y)
118-
119113
def test_tokenize_dataset(self):
120114
sentences = tf.constant([" airplane at airport"] * 4)
121115
ds = tf.data.Dataset.from_tensor_slices(sentences)
@@ -168,11 +162,6 @@ def test_mask_multiple_sentences(self):
168162
self.assertAllEqual(y, [3, 5, 10, 0])
169163
self.assertAllEqual(sw, [1.0, 1.0, 1.0, 0.0])
170164

171-
def test_errors_for_2d_list_input(self):
172-
ambiguous_input = [["one", "two"], ["three", "four"]]
173-
with self.assertRaises(ValueError):
174-
self.preprocessor(ambiguous_input)
175-
176165
@parameterized.named_parameters(
177166
("tf_format", "tf", "model"),
178167
("keras_format", "keras_v3", "model.keras"),

0 commit comments

Comments (0)