
Commit fa2e62f

Fix summarization postprocess backbone (#296)
* Fix summarization postprocess backbone

* Update CHANGELOG.md

* Update
ethanwharris authored May 14, 2021
1 parent 1f50b3f commit fa2e62f
Showing 4 changed files with 51 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fix DataPipeline resolution in Task ([#212](https://github.com/PyTorchLightning/lightning-flash/pull/212))
+- Fixed a bug where the backbone used in summarization was not correctly passed to the postprocess ([#296](https://github.com/PyTorchLightning/lightning-flash/pull/296))
 
 
 ## [0.2.3] - 2021-04-17
13 changes: 13 additions & 0 deletions flash/text/seq2seq/core/data.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -24,6 +25,7 @@
 from flash.data.data_module import DataModule
 from flash.data.data_source import DataSource, DefaultDataSources
 from flash.data.process import Preprocess
+from flash.data.properties import ProcessState
 
 
 class Seq2SeqDataSource(DataSource):
@@ -158,6 +160,15 @@ def load_data(
         return [self._tokenize_fn(s) for s in data]
 
 
+@dataclass(unsafe_hash=True, frozen=True)
+class Seq2SeqBackboneState(ProcessState):
+    """The ``Seq2SeqBackboneState`` stores the backbone in use by the
+    :class:`~flash.text.seq2seq.core.data.Seq2SeqPreprocess`
+    """
+
+    backbone: str
+
+
 class Seq2SeqPreprocess(Preprocess):
 
     def __init__(
@@ -204,6 +215,8 @@ def __init__(
             default_data_source="sentences",
         )
 
+        self.set_state(Seq2SeqBackboneState(self.backbone))
+
     def get_state_dict(self) -> Dict[str, Any]:
         return {
             **self.transforms,
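The mechanism behind the fix is Flash's `ProcessState`: `Seq2SeqPreprocess` publishes the backbone name as an immutable, hashable piece of shared state, and any component attached to the same data pipeline can look it up later. Below is a minimal sketch of that pattern; the dict-backed `Properties` container and the `ProcessState` base class here are simplified stand-ins for illustration, not the actual `flash.data.properties` implementation.

from dataclasses import dataclass
from typing import Dict, Optional, Type


# Simplified stand-ins for Flash's state machinery (illustration only).
@dataclass(unsafe_hash=True, frozen=True)
class ProcessState:
    """Base class for immutable, hashable pieces of shared pipeline state."""


@dataclass(unsafe_hash=True, frozen=True)
class Seq2SeqBackboneState(ProcessState):
    backbone: str


class Properties:
    """A component that reads and writes one shared state dict, keyed by state type."""

    def __init__(self, state: Dict[Type[ProcessState], ProcessState]):
        self._state = state

    def set_state(self, state: ProcessState) -> None:
        self._state[type(state)] = state

    def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]:
        return self._state.get(state_type)


shared: Dict[Type[ProcessState], ProcessState] = {}
preprocess, postprocess = Properties(shared), Properties(shared)

# The preprocess publishes its backbone, as Seq2SeqPreprocess.__init__ does above...
preprocess.set_state(Seq2SeqBackboneState("sshleifer/tiny-mbart"))

# ...and the postprocess resolves it without being constructed with it.
state = postprocess.get_state(Seq2SeqBackboneState)
assert state is not None and state.backbone == "sshleifer/tiny-mbart"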
24 changes: 17 additions & 7 deletions flash/text/seq2seq/summarization/data.py
@@ -16,19 +16,29 @@
 from transformers import AutoTokenizer
 
 from flash.data.process import Postprocess
-from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess
+from flash.text.seq2seq.core.data import Seq2SeqBackboneState, Seq2SeqData, Seq2SeqPreprocess
 
 
 class SummarizationPostprocess(Postprocess):
 
-    def __init__(
-        self,
-        backbone: str = "sshleifer/tiny-mbart",
-    ):
+    def __init__(self):
         super().__init__()
 
-        # TODO: Should share the backbone or tokenizer over state
-        self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
+        self._backbone = None
+        self._tokenizer = None
+
+    @property
+    def backbone(self):
+        backbone_state = self.get_state(Seq2SeqBackboneState)
+        if backbone_state is not None:
+            return backbone_state.backbone
+
+    @property
+    def tokenizer(self):
+        if self.backbone is not None and self.backbone != self._backbone:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
+            self._backbone = self.backbone
+        return self._tokenizer
 
     def uncollate(self, generated_tokens: Any) -> Any:
         pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
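The replacement `tokenizer` property resolves lazily: it reloads a tokenizer only when the backbone recorded in the shared state differs from the one the cached tokenizer was built from, so repeated `uncollate` calls reuse a single instance. A minimal sketch of that caching pattern, with a hypothetical `load` function standing in for `AutoTokenizer.from_pretrained(name, use_fast=True)`:

from typing import Optional


def load(name: str) -> str:
    # Hypothetical stand-in for AutoTokenizer.from_pretrained (returns a string
    # here so the example is self-contained and runnable).
    return f"tokenizer<{name}>"


class LazyTokenizer:

    def __init__(self) -> None:
        self._backbone: Optional[str] = None  # backbone the cached tokenizer came from
        self._tokenizer: Optional[str] = None

    def resolve(self, current_backbone: Optional[str]) -> Optional[str]:
        # Reload only when the shared state names a different backbone.
        if current_backbone is not None and current_backbone != self._backbone:
            self._tokenizer = load(current_backbone)
            self._backbone = current_backbone
        return self._tokenizer


cache = LazyTokenizer()
first = cache.resolve("sshleifer/tiny-mbart")
again = cache.resolve("sshleifer/tiny-mbart")   # same backbone: no reload
assert first is again
other = cache.resolve("sshleifer/bart-tiny-random")  # backbone changed: reload
assert other != first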
21 changes: 20 additions & 1 deletion tests/text/summarization/test_data.py
@@ -64,7 +64,7 @@ def test_from_files(tmpdir):
         train_file=csv_path,
         val_file=csv_path,
         test_file=csv_path,
-        batch_size=1
+        batch_size=1,
     )
     batch = next(iter(dm.val_dataloader()))
     assert "labels" in batch
@@ -75,6 +75,25 @@
     assert "input_ids" in batch
 
 
+def test_postprocess_tokenizer(tmpdir):
+    """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a
+    different backbone is used.
+    """
+    backbone = "sshleifer/bart-tiny-random"
+    csv_path = csv_data(tmpdir)
+    dm = SummarizationData.from_csv(
+        "input",
+        "target",
+        backbone=backbone,
+        train_file=csv_path,
+        batch_size=1,
+    )
+    pipeline = dm.data_pipeline
+    pipeline.initialize()
+    assert pipeline._postprocess_pipeline.backbone == backbone
+    assert pipeline._postprocess_pipeline.tokenizer is not None
+
+
 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
 def test_from_json(tmpdir):
     json_path = json_data(tmpdir)
