diff --git a/simplet5/simplet5.py b/simplet5/simplet5.py
index 4a9eb21..74789a5 100644
--- a/simplet5/simplet5.py
+++ b/simplet5/simplet5.py
@@ -8,8 +8,9 @@
     PreTrainedTokenizer,
     T5TokenizerFast as T5Tokenizer,
     MT5TokenizerFast as MT5Tokenizer,
+    AutoTokenizer,
+    LongT5ForConditionalGeneration
 )
-from transformers import AutoTokenizer
 from torch.optim import AdamW
 from torch.utils.data import Dataset, DataLoader
 from transformers import AutoModelWithLMHead, AutoTokenizer
@@ -311,6 +312,9 @@ def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
             self.model = T5ForConditionalGeneration.from_pretrained(
                 f"{model_name}", return_dict=True
             )
+        elif model_type == 'longt5':
+            self.model = LongT5ForConditionalGeneration.from_pretrained(f"{model_name}")
+            self.tokenizer = AutoTokenizer.from_pretrained(f"{model_name}")

     def train(
         self,
@@ -413,6 +417,9 @@ def load_model(
         elif model_type == "byt5":
             self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
             self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
+        elif model_type == 'longt5':
+            self.model = LongT5ForConditionalGeneration.from_pretrained(f"{model_dir}")
+            self.tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")

     if use_gpu:
         if torch.cuda.is_available():