From cdeae90772188fa276bf6f0f014cc02cb4c422b7 Mon Sep 17 00:00:00 2001
From: Sai Kaushik
Date: Fri, 24 Jun 2022 20:51:44 +0530
Subject: [PATCH] longt5

---
 simplet5/simplet5.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/simplet5/simplet5.py b/simplet5/simplet5.py
index 4a9eb21..74789a5 100644
--- a/simplet5/simplet5.py
+++ b/simplet5/simplet5.py
@@ -8,8 +8,9 @@
     PreTrainedTokenizer,
     T5TokenizerFast as T5Tokenizer,
     MT5TokenizerFast as MT5Tokenizer,
+    AutoTokenizer,
+    LongT5ForConditionalGeneration
 )
-from transformers import AutoTokenizer
 from torch.optim import AdamW
 from torch.utils.data import Dataset, DataLoader
 from transformers import AutoModelWithLMHead, AutoTokenizer
@@ -311,6 +312,9 @@ def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
             self.model = T5ForConditionalGeneration.from_pretrained(
                 f"{model_name}", return_dict=True
             )
+        elif model_type == 'longt5':
+            self.model = LongT5ForConditionalGeneration.from_pretrained(f"{model_name}")
+            self.tokenizer = AutoTokenizer.from_pretrained(f"{model_name}")
 
     def train(
         self,
@@ -413,6 +417,9 @@ def load_model(
         elif model_type == "byt5":
             self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
             self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
+        elif model_type == 'longt5':
+            self.model = LongT5ForConditionalGeneration.from_pretrained(f"{model_dir}")
+            self.tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")
 
         if use_gpu:
             if torch.cuda.is_available():
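
Note: a minimal usage sketch of the `longt5` model_type this patch adds, assuming the existing SimpleT5 interface (`from_pretrained`, `load_model`) is otherwise unchanged. The checkpoint name `google/long-t5-tglobal-base` is one published LongT5 checkpoint on the Hugging Face Hub; any compatible LongT5 checkpoint should work the same way.

```python
from simplet5 import SimpleT5

model = SimpleT5()

# Load a LongT5 model via the new branch in from_pretrained();
# internally this calls LongT5ForConditionalGeneration and AutoTokenizer.
model.from_pretrained(
    model_type="longt5",
    model_name="google/long-t5-tglobal-base",  # assumed checkpoint name
)

# Reloading a fine-tuned checkpoint from disk goes through the matching
# branch added to load_model().
model.load_model(model_type="longt5", model_dir="outputs", use_gpu=False)
```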