diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 042798fbe8c1..43a4477fb49a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3868,6 +3868,13 @@ def __init__(self, *args, **kwargs): SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST = None +class SplinterForPreTraining(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SplinterForQuestionAnswering(metaclass=DummyObject): _backends = ["torch"]