Skip to content

Commit e5c7e66

Browse files
committed
Turn off torch compilation
1 parent d66d880 commit e5c7e66

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

keras_nlp/models/bart/bart_seq_2_seq_lm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ def __init__(
207207
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
208208
optimizer=keras.optimizers.Adam(2e-5),
209209
metrics=[keras.metrics.SparseCategoricalAccuracy()],
210-
jit_compile=True,
211210
)
212211

213212
@classproperty

keras_nlp/models/task.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def _check_for_loss_mismatch(self, loss):
7979
)
8080

8181
def compile(self, optimizer="rmsprop", loss=None, **kwargs):
82+
# Temporarily disable jit compilation on torch.
83+
if config.backend() == "torch":
84+
kwargs["jit_compile"] = False
8285
self._check_for_loss_mismatch(loss)
8386
super().compile(optimizer=optimizer, loss=loss, **kwargs)
8487

keras_nlp/tests/test_case.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def call(self, x):
141141
return self.layer(x)
142142

143143
model = TestModel(layer)
144-
model.compile(optimizer="sgd", loss="mse", jit_compile=True)
144+
# Temporarily disable jit compilation on torch backend.
145+
jit_compile = config.backend() != "torch"
146+
model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile)
145147
model.fit(input_data, output_data, verbose=0)
146148

147149
if config.multi_backend():

0 commit comments

Comments
 (0)