diff --git a/CHANGELOG.md b/CHANGELOG.md index 3eeee3aa58..7f383676bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fix DataPipeline resolution in Task ([#212](https://github.com/PyTorchLightning/lightning-flash/pull/212)) +- Fixed a bug where the backbone used in summarization was not correctly passed to the postprocess ([#296](https://github.com/PyTorchLightning/lightning-flash/pull/296)) ## [0.2.3] - 2021-04-17 diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index f7968ee4a7..9735e8750d 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from dataclasses import dataclass from functools import partial from typing import Any, Callable, Dict, List, Optional, Union @@ -24,6 +25,7 @@ from flash.data.data_module import DataModule from flash.data.data_source import DataSource, DefaultDataSources from flash.data.process import Preprocess +from flash.data.properties import ProcessState class Seq2SeqDataSource(DataSource): @@ -158,6 +160,15 @@ def load_data( return [self._tokenize_fn(s) for s in data] +@dataclass(unsafe_hash=True, frozen=True) +class Seq2SeqBackboneState(ProcessState): + """The ``Seq2SeqBackboneState`` stores the backbone in use by the + :class:`~flash.text.seq2seq.core.data.Seq2SeqPreprocess` + """ + + backbone: str + + class Seq2SeqPreprocess(Preprocess): def __init__( @@ -204,6 +215,8 @@ def __init__( default_data_source="sentences", ) + self.set_state(Seq2SeqBackboneState(self.backbone)) + def get_state_dict(self) -> Dict[str, Any]: return { **self.transforms, diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 791c98a32f..e9f958b0e0 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -16,19 +16,29 @@ from transformers import AutoTokenizer from flash.data.process import Postprocess -from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess +from flash.text.seq2seq.core.data import Seq2SeqBackboneState, Seq2SeqData, Seq2SeqPreprocess class SummarizationPostprocess(Postprocess): - def __init__( - self, - backbone: str = "sshleifer/tiny-mbart", - ): + def __init__(self): super().__init__() - # TODO: Should share the backbone or tokenizer over state - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) + self._backbone = None + self._tokenizer = None + + @property + def backbone(self): + backbone_state = self.get_state(Seq2SeqBackboneState) + if backbone_state is not None: + return backbone_state.backbone + + @property + def tokenizer(self): + if self.backbone is not None and self.backbone != self._backbone: + self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + self._backbone = self.backbone + return self._tokenizer def uncollate(self, generated_tokens: Any) -> Any: pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) diff --git a/tests/text/summarization/test_data.py b/tests/text/summarization/test_data.py index 67b88bc937..cd865afe13 100644 --- a/tests/text/summarization/test_data.py +++ b/tests/text/summarization/test_data.py @@ -64,7 +64,7 @@ def test_from_files(tmpdir): train_file=csv_path, val_file=csv_path, test_file=csv_path, - batch_size=1 + batch_size=1, ) batch = next(iter(dm.val_dataloader())) assert "labels" in batch @@ -75,6 +75,25 @@ def test_from_files(tmpdir): assert "input_ids" in batch +def test_postprocess_tokenizer(tmpdir): + """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different backbone is + used. + """ + backbone = "sshleifer/bart-tiny-random" + csv_path = csv_data(tmpdir) + dm = SummarizationData.from_csv( + "input", + "target", + backbone=backbone, + train_file=csv_path, + batch_size=1, + ) + pipeline = dm.data_pipeline + pipeline.initialize() + assert pipeline._postprocess_pipeline.backbone == backbone + assert pipeline._postprocess_pipeline.tokenizer is not None + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): json_path = json_data(tmpdir)