
Fix summarization postprocess backbone #296

Merged · 3 commits · May 14, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fix DataPipeline resolution in Task ([#212](https://github.com/PyTorchLightning/lightning-flash/pull/212))
+- Fixed a bug where the backbone used in summarization was not correctly passed to the postprocess ([#296](https://github.com/PyTorchLightning/lightning-flash/pull/296))


 ## [0.2.3] - 2021-04-17
13 changes: 13 additions & 0 deletions flash/text/seq2seq/core/data.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -24,6 +25,7 @@
 from flash.data.data_module import DataModule
 from flash.data.data_source import DataSource, DefaultDataSources
 from flash.data.process import Preprocess
+from flash.data.properties import ProcessState


 class Seq2SeqDataSource(DataSource):
@@ -158,6 +160,15 @@ def load_data(
         return [self._tokenize_fn(s) for s in data]


+@dataclass(unsafe_hash=True, frozen=True)
+class Seq2SeqBackboneState(ProcessState):
+    """The ``Seq2SeqBackboneState`` stores the backbone in use by the
+    :class:`~flash.text.seq2seq.core.data.Seq2SeqPreprocess`.
+    """
+
+    backbone: str
+
+
 class Seq2SeqPreprocess(Preprocess):

     def __init__(
@@ -204,6 +215,8 @@ def __init__(
             default_data_source="sentences",
         )

+        self.set_state(Seq2SeqBackboneState(self.backbone))
+
     def get_state_dict(self) -> Dict[str, Any]:
         return {
             **self.transforms,
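For readers unfamiliar with Flash's process-state mechanism, the following is a minimal sketch of the pattern the new Seq2SeqBackboneState relies on: one component publishes an immutable, hashable state object into a registry shared across the pipeline, and any other component can later look it up by type. DataPipelineState below is a simplified stand-in for Flash's internals, not its actual implementation.

    from dataclasses import dataclass
    from typing import Dict, Optional, Type


    @dataclass(unsafe_hash=True, frozen=True)
    class ProcessState:
        """Base class for immutable, hashable state shared across a pipeline."""


    @dataclass(unsafe_hash=True, frozen=True)
    class Seq2SeqBackboneState(ProcessState):
        backbone: str


    class DataPipelineState:
        """Simplified stand-in: a registry of state objects keyed by their type."""

        def __init__(self) -> None:
            self._states: Dict[Type[ProcessState], ProcessState] = {}

        def set_state(self, state: ProcessState) -> None:
            self._states[type(state)] = state

        def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]:
            return self._states.get(state_type)


    # The preprocess and postprocess attach to the same pipeline state, so a
    # state set by one side is visible to the other.
    shared = DataPipelineState()
    shared.set_state(Seq2SeqBackboneState(backbone="t5-small"))  # preprocess side
    state = shared.get_state(Seq2SeqBackboneState)               # postprocess side
    assert state is not None and state.backbone == "t5-small"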
24 changes: 17 additions & 7 deletions flash/text/seq2seq/summarization/data.py
@@ -16,19 +16,29 @@
 from transformers import AutoTokenizer

 from flash.data.process import Postprocess
-from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess
+from flash.text.seq2seq.core.data import Seq2SeqBackboneState, Seq2SeqData, Seq2SeqPreprocess


 class SummarizationPostprocess(Postprocess):

-    def __init__(
-        self,
-        backbone: str = "sshleifer/tiny-mbart",
-    ):
+    def __init__(self):
         super().__init__()

-        # TODO: Should share the backbone or tokenizer over state
-        self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
+        self._backbone = None
+        self._tokenizer = None
+
+    @property
+    def backbone(self):
+        backbone_state = self.get_state(Seq2SeqBackboneState)
+        if backbone_state is not None:
+            return backbone_state.backbone
+
+    @property
+    def tokenizer(self):
+        if self.backbone is not None and self.backbone != self._backbone:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
+            self._backbone = self.backbone
+        return self._tokenizer

     def uncollate(self, generated_tokens: Any) -> Any:
         pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
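The new tokenizer property is a lazy cache: no tokenizer is built until a backbone can be resolved from the shared state, and it is rebuilt only when that backbone changes. Below is a standalone sketch of the same caching pattern using the real transformers.AutoTokenizer API; the LazyTokenizer class is illustrative and not part of Flash.

    from typing import Optional

    from transformers import AutoTokenizer


    class LazyTokenizer:
        """Illustrative cache: build a tokenizer on first use, rebuild on backbone change."""

        def __init__(self) -> None:
            self._backbone: Optional[str] = None
            self._tokenizer = None

        def resolve(self, backbone: Optional[str]):
            # Only hit the Hugging Face hub when the backbone is known and has changed.
            if backbone is not None and backbone != self._backbone:
                self._tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
                self._backbone = backbone
            return self._tokenizer


    lazy = LazyTokenizer()
    first = lazy.resolve("sshleifer/tiny-mbart")
    assert lazy.resolve("sshleifer/tiny-mbart") is first  # cached, not rebuilt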
21 changes: 20 additions & 1 deletion tests/text/summarization/test_data.py
@@ -64,7 +64,7 @@ def test_from_files(tmpdir):
         train_file=csv_path,
         val_file=csv_path,
         test_file=csv_path,
-        batch_size=1
+        batch_size=1,
     )
     batch = next(iter(dm.val_dataloader()))
     assert "labels" in batch
@@ -75,6 +75,25 @@
     assert "input_ids" in batch


+def test_postprocess_tokenizer(tmpdir):
+    """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a
+    different backbone is used.
+    """
+    backbone = "sshleifer/bart-tiny-random"
+    csv_path = csv_data(tmpdir)
+    dm = SummarizationData.from_csv(
+        "input",
+        "target",
+        backbone=backbone,
+        train_file=csv_path,
+        batch_size=1,
+    )
+    pipeline = dm.data_pipeline
+    pipeline.initialize()
+    assert pipeline._postprocess_pipeline.backbone == backbone
+    assert pipeline._postprocess_pipeline.tokenizer is not None
+
+
 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
 def test_from_json(tmpdir):
     json_path = json_data(tmpdir)
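For context on what this test guards against: before this change, SummarizationPostprocess hard-coded its tokenizer to "sshleifer/tiny-mbart", so predictions generated with any other backbone were decoded against the wrong vocabulary. A rough sketch of that failure mode (illustrative only, not code from this repository):

    from transformers import AutoTokenizer

    # Token ids produced by the backbone actually used for generation...
    generator_tok = AutoTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    ids = generator_tok("a tiny summary").input_ids

    # ...decoded with the previously hard-coded mBART tokenizer. The vocabularies
    # differ, so the round trip yields garbled text instead of the original string.
    decoder_tok = AutoTokenizer.from_pretrained("sshleifer/tiny-mbart")
    print(decoder_tok.batch_decode([ids], skip_special_tokens=True))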