diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py index 46dd289124..8bdf9cd34a 100644 --- a/keras_nlp/models/bert/bert_classifier.py +++ b/keras_nlp/models/bert/bert_classifier.py @@ -23,6 +23,7 @@ from keras_nlp.models.bert.bert_presets import backbone_presets from keras_nlp.models.bert.bert_presets import classifier_presets from keras_nlp.models.task import Task +from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -190,6 +191,14 @@ def __init__( self.num_classes = num_classes self.dropout = dropout + # Default compilation, so `fit()` works without an explicit `compile()`. + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + metrics=keras.metrics.SparseCategoricalAccuracy(), + jit_compile=is_xla_compatible(self), + ) + def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/bert/bert_classifier_test.py b/keras_nlp/models/bert/bert_classifier_test.py index 0878b3108d..a56cecba3c 100644 --- a/keras_nlp/models/bert/bert_classifier_test.py +++ b/keras_nlp/models/bert/bert_classifier_test.py @@ -84,6 +84,9 @@ def test_bert_classifier_predict_no_preprocessing(self, jit_compile): self.classifier_no_preprocessing.compile(jit_compile=jit_compile) self.classifier_no_preprocessing.predict(self.preprocessed_batch) + def test_bert_classifier_fit_default_compile(self): + self.classifier.fit(self.raw_dataset) + @parameterized.named_parameters( ("jit_compile_false", False), ("jit_compile_true", True) ) diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py index 30e80cffd3..d7b6641cc0 100644 --- a/keras_nlp/models/roberta/roberta_classifier.py +++ b/keras_nlp/models/roberta/roberta_classifier.py @@ -22,6 +22,7 @@ from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor from keras_nlp.models.roberta.roberta_presets import 
backbone_presets from keras_nlp.models.task import Task +from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -194,6 +195,14 @@ def __init__( self.hidden_dim = hidden_dim self.dropout = dropout + # Default compilation, so `fit()` works without an explicit `compile()`. + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(2e-5), + metrics=keras.metrics.SparseCategoricalAccuracy(), + jit_compile=is_xla_compatible(self), + ) + def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/roberta/roberta_classifier_test.py b/keras_nlp/models/roberta/roberta_classifier_test.py index 8b655a9723..201b808bc8 100644 --- a/keras_nlp/models/roberta/roberta_classifier_test.py +++ b/keras_nlp/models/roberta/roberta_classifier_test.py @@ -101,6 +101,9 @@ def test_roberta_classifier_predict_no_preprocessing(self, jit_compile): self.classifier_no_preprocessing.compile(jit_compile=jit_compile) self.classifier_no_preprocessing.predict(self.preprocessed_batch) + def test_roberta_classifier_fit_default_compile(self): + self.classifier.fit(self.raw_dataset) + @parameterized.named_parameters( ("jit_compile_false", False), ("jit_compile_true", True) ) diff --git a/keras_nlp/utils/keras_utils.py b/keras_nlp/utils/keras_utils.py index 47e6079235..8d41681bfa 100644 --- a/keras_nlp/utils/keras_utils.py +++ b/keras_nlp/utils/keras_utils.py @@ -13,6 +13,8 @@ # limitations under the License. +import platform + import tensorflow as tf from tensorflow import keras @@ -96,3 +98,16 @@ def convert_inputs_to_list_of_tensor_segments(x): f"list of tensors. 
Received `x={x}`"
         )
     return x
+
+
+def is_xla_compatible(model):
+    """Report whether `model` on this platform is treated as XLA-compatible."""
+    # macOS on an ARM processor is excluded.
+    if platform.system() == "Darwin" and "arm" in platform.processor().lower():
+        return False
+    # So is any model whose `distribute_strategy` is a TPU strategy.
+    tpu_strategies = (
+        tf.compat.v1.distribute.experimental.TPUStrategy,
+        tf.distribute.TPUStrategy,
+    )
+    return not isinstance(model.distribute_strategy, tpu_strategies)