Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions keras_nlp/models/bert/bert_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ def __init__(
self.num_classes = num_classes
self.dropout = dropout

# Default compilation
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
metrics=keras.metrics.SparseCategoricalAccuracy(),
jit_compile=True,
)

def get_config(self):
config = super().get_config()
config.update(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/models/bert/bert_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def test_bert_classifier_predict_no_preprocessing(self, jit_compile):
self.classifier_no_preprocessing.compile(jit_compile=jit_compile)
self.classifier_no_preprocessing.predict(self.preprocessed_batch)

def test_bert_classifier_fit_default_compile(self):
self.classifier.fit(self.raw_dataset)

@parameterized.named_parameters(
("jit_compile_false", False), ("jit_compile_true", True)
)
Expand Down
8 changes: 8 additions & 0 deletions keras_nlp/models/roberta/roberta_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ def __init__(
self.hidden_dim = hidden_dim
self.dropout = dropout

# Default compilation
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=keras.metrics.SparseCategoricalAccuracy(),
jit_compile=True,
)

def get_config(self):
config = super().get_config()
config.update(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/models/roberta/roberta_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def test_roberta_classifier_predict_no_preprocessing(self, jit_compile):
self.classifier_no_preprocessing.compile(jit_compile=jit_compile)
self.classifier_no_preprocessing.predict(self.preprocessed_batch)

def test_roberta_classifier_fit_default_compile(self):
self.classifier.fit(self.raw_dataset)

@parameterized.named_parameters(
("jit_compile_false", False), ("jit_compile_true", True)
)
Expand Down