Skip to content

Commit 8581352

Browse files
authored
Default compilation for BERT/RoBERTa classifiers (#695)
* Default compilation for BERT/RoBERTa classifiers * Check xla compatibilty before setting jit_compile
1 parent 9a9be88 commit 8581352

File tree

5 files changed

+39
-0
lines changed

5 files changed

+39
-0
lines changed

keras_nlp/models/bert/bert_classifier.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from keras_nlp.models.bert.bert_presets import backbone_presets
2424
from keras_nlp.models.bert.bert_presets import classifier_presets
2525
from keras_nlp.models.task import Task
26+
from keras_nlp.utils.keras_utils import is_xla_compatible
2627
from keras_nlp.utils.python_utils import classproperty
2728

2829

@@ -190,6 +191,14 @@ def __init__(
190191
self.num_classes = num_classes
191192
self.dropout = dropout
192193

194+
# Default compilation
195+
self.compile(
196+
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
197+
optimizer=keras.optimizers.Adam(5e-5),
198+
metrics=keras.metrics.SparseCategoricalAccuracy(),
199+
jit_compile=is_xla_compatible(self),
200+
)
201+
193202
def get_config(self):
194203
config = super().get_config()
195204
config.update(

keras_nlp/models/bert/bert_classifier_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def test_bert_classifier_predict_no_preprocessing(self, jit_compile):
8484
self.classifier_no_preprocessing.compile(jit_compile=jit_compile)
8585
self.classifier_no_preprocessing.predict(self.preprocessed_batch)
8686

87+
def test_bert_classifier_fit_default_compile(self):
88+
self.classifier.fit(self.raw_dataset)
89+
8790
@parameterized.named_parameters(
8891
("jit_compile_false", False), ("jit_compile_true", True)
8992
)

keras_nlp/models/roberta/roberta_classifier.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
2323
from keras_nlp.models.roberta.roberta_presets import backbone_presets
2424
from keras_nlp.models.task import Task
25+
from keras_nlp.utils.keras_utils import is_xla_compatible
2526
from keras_nlp.utils.python_utils import classproperty
2627

2728

@@ -194,6 +195,14 @@ def __init__(
194195
self.hidden_dim = hidden_dim
195196
self.dropout = dropout
196197

198+
# Default compilation
199+
self.compile(
200+
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
201+
optimizer=keras.optimizers.Adam(2e-5),
202+
metrics=keras.metrics.SparseCategoricalAccuracy(),
203+
jit_compile=is_xla_compatible(self),
204+
)
205+
197206
def get_config(self):
198207
config = super().get_config()
199208
config.update(

keras_nlp/models/roberta/roberta_classifier_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ def test_roberta_classifier_predict_no_preprocessing(self, jit_compile):
102102
self.classifier_no_preprocessing.compile(jit_compile=jit_compile)
103103
self.classifier_no_preprocessing.predict(self.preprocessed_batch)
104104

105+
def test_roberta_classifier_fit_default_compile(self):
106+
self.classifier.fit(self.raw_dataset)
107+
105108
@parameterized.named_parameters(
106109
("jit_compile_false", False), ("jit_compile_true", True)
107110
)

keras_nlp/utils/keras_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
import platform
17+
1618
import tensorflow as tf
1719
from tensorflow import keras
1820

@@ -96,3 +98,16 @@ def convert_inputs_to_list_of_tensor_segments(x):
9698
f"list of tensors. Received `x={x}`"
9799
)
98100
return x
101+
102+
103+
def is_xla_compatible(model):
104+
"""Determine if model and platform xla-compatible."""
105+
return not (
106+
platform.system() == "Darwin" and "arm" in platform.processor().lower()
107+
) and not isinstance(
108+
model.distribute_strategy,
109+
(
110+
tf.compat.v1.distribute.experimental.TPUStrategy,
111+
tf.distribute.TPUStrategy,
112+
),
113+
)

0 commit comments

Comments
 (0)