Implement tensorizer saving to snapshot files (facebookresearch#739)
Summary:
Pull Request resolved: facebookresearch#739

As per design discussions, store tensorizers directly to snapshot file using pickle.

Reviewed By: neo315, seayoung1112

Differential Revision: D16045237

fbshipit-source-id: be6e6ebbbb6fc16c3220a4f68d3e9715d3f1589f
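A minimal sketch of the idea (hypothetical helper names, not PyText's actual snapshot API): because tensorizers now avoid unpicklable attributes, they can be serialized straight into the snapshot file alongside the model, since torch.save pickles its payload.

import torch

# Illustrative sketch only: save_snapshot/load_snapshot and the dict keys
# are hypothetical names, not the API added by this commit.
def save_snapshot(path, model, tensorizers):
    # torch.save pickles its argument, so picklable tensorizers can be
    # stored directly alongside the model weights in one file.
    torch.save({"model_state": model.state_dict(), "tensorizers": tensorizers}, path)

def load_snapshot(path, model):
    snapshot = torch.load(path)
    model.load_state_dict(snapshot["model_state"])
    return model, snapshot["tensorizers"]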
bethebunny authored and facebook-github-bot committed Jul 24, 2019
1 parent 7ae0b2a commit daae0c3
Showing 17 changed files with 368 additions and 125 deletions.
5 changes: 4 additions & 1 deletion demo/examples/tensorizer.py
@@ -20,10 +20,13 @@ def from_config(cls, config: Config):
return cls(column=config.column)

def __init__(self, column):
super().__init__([(column, str)])
self.column = column
self.vocab = None

@property
def column_schema(self):
return [(self.column, str)]

def _tokenize(self, row):
raw_text = row[self.column]
return raw_text.split()
4 changes: 4 additions & 0 deletions pytext/config/pytext_config.py
@@ -74,6 +74,10 @@ def __str__(self):
lines += f"{key}: {val}".split("\n")
return "\n ".join(lines)

def __eq__(self, other):
"""Mainly a convenience utility for unit testing."""
return type(self) == type(other) and self._asdict() == other._asdict()


class PlaceHolder:
pass
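With __eq__ defined, unit tests can compare whole configs directly. A small sketch, assuming ConfigBase accepts field values as keyword arguments (ExampleConfig is a hypothetical subclass):

from pytext.config.pytext_config import ConfigBase

class ExampleConfig(ConfigBase):
    column: str = "text"

# Equal when the types match and the field dicts match; unequal otherwise.
assert ExampleConfig(column="text") == ExampleConfig(column="text")
assert ExampleConfig(column="text") != ExampleConfig(column="label")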
5 changes: 4 additions & 1 deletion pytext/data/bert_tensorizer.py
@@ -66,7 +66,10 @@ def __init__(self, columns, **kwargs):
super().__init__(text_column=None, **kwargs)
self.columns = columns
# Manually initialize column_schema since we are sending None to TokenTensorizer
self.column_schema = [(column, str) for column in columns]

@property
def column_schema(self):
return [(column, str) for column in self.columns]

def numberize(self, row):
"""Tokenize, look up in vocabulary."""
28 changes: 24 additions & 4 deletions pytext/data/sources/data_source.py
@@ -3,13 +3,17 @@

import json
import logging
from typing import Dict, List, TypeVar
import re
from typing import Dict, List, Type, TypeVar

from pytext.config.component import Component, ComponentType
from pytext.data.utils import shard
from pytext.utils.data import Slot, parse_slot_string


Schema = Dict[str, Type]


class RawExample(dict):
"""A wrapper class for a single example row with a dict interface.
This is here for any logic we want row objects to have that dicts don't do."""
@@ -127,7 +131,7 @@ class DataSource(Component):
__COMPONENT_TYPE__ = ComponentType.DATA_SOURCE
__EXPANSIBLE__ = True

def __init__(self, schema):
def __init__(self, schema: Schema):
self.schema = schema

@generator_property
@@ -191,7 +195,7 @@ class Config(Component.Config):
#: remap names from the raw data source to names in the schema.
column_mapping: Dict[str, str] = {}

def __init__(self, schema, column_mapping=()):
def __init__(self, schema: Schema, column_mapping: Dict[str, str] = ()):
super().__init__(schema)
self.column_mapping = dict(column_mapping)

@@ -299,13 +303,29 @@ def load_slots(s):


@RootDataSource.register_type(Gazetteer)
@RootDataSource.register_type(List[float])
@RootDataSource.register_type(List[str])
@RootDataSource.register_type(List[int])
def load_json(s):
return json.loads(s)


@RootDataSource.register_type(List[float])
def load_float_list(s):
# replace spaces between float numbers with commas (regex101.com/r/C2705x/1)
processed = re.sub(r"(?<=[\d.])\s*,?\s+(?=[+-]?[\d.])", ",", s)
# remove dot not followed with a digit (regex101.com/r/goSmuG/1/)
processed = re.sub(r"(?<=\d)\.(?![\d])", "", processed)
try:
parsed = json.loads(processed)
except json.decoder.JSONDecodeError as e:
raise ValueError(
f"Unable to parse float list `{s}` (normalized to `{processed}`)"
) from e
if not isinstance(parsed, list):
raise ValueError(f"Expected float list for float feature, got {parsed}")
return [float(f) for f in parsed]


@RootDataSource.register_type(JSONString)
def load_json_string(s):
parsed = json.loads(s)
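The new load_float_list normalizes loosely formatted float lists to JSON before parsing. A few illustrative inputs that should parse identically under the two substitutions above (assumed behavior, not test cases from this commit):

# Spaces between numbers become commas, and a dot with no digit after it
# is dropped, so these all normalize to the same JSON list.
assert load_float_list("[1.0 2.0 3.0]") == [1.0, 2.0, 3.0]
assert load_float_list("[1., 2., 3.]") == [1.0, 2.0, 3.0]
assert load_float_list("[1.0, 2.0, 3.0]") == [1.0, 2.0, 3.0]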
109 changes: 68 additions & 41 deletions pytext/data/tensorizers.py
@@ -55,8 +55,12 @@ class Tensorizer(Component):
class Config(Component.Config):
pass

def __init__(self, column_schema: List[Tuple[str, Type]]):
self.column_schema = column_schema
@property
def column_schema(self):
"""Generic types don't pickle well pre-3.7, so we don't actually want
to store the schema as an attribute. We're already storing all of the
columns anyway, so until there's a better solution, schema is a property."""
return []

def numberize(self, row):
raise NotImplementedError
@@ -157,7 +161,6 @@ def __init__(
vocab_config=None,
vocab=None,
):
super().__init__([(text_column, str)])
self.text_column = text_column
self.tokenizer = tokenizer or Tokenizer()
self.vocab = vocab
@@ -168,6 +171,10 @@ def __init__(
self.vocab_builder = None
self.vocab_config = vocab_config or VocabConfig()

@property
def column_schema(self):
return [(self.text_column, str)]

def _lookup_tokens(self, text=None, pre_tokenized=None):
tokenized = pre_tokenized or self.tokenizer.tokenize(text)[: self.max_seq_len]
if self.add_bos_token:
@@ -271,11 +278,14 @@ def from_config(cls, config: Config):
return cls(config.column, config.lower, config.max_seq_len)

def __init__(self, text_column, lower=True, max_seq_len=None):
super().__init__([(text_column, str)])
self.text_column = text_column
self.lower = lower
self.max_seq_len = max_seq_len

@property
def column_schema(self):
return [(self.text_column, str)]

def numberize(self, row):
"""Convert text to characters."""
text = row[self.text_column]
@@ -334,13 +344,16 @@ def __init__(
max_byte_len=Config.max_byte_len,
offset_for_non_padding=Config.offset_for_non_padding,
):
super().__init__([(text_column, str)])
self.text_column = text_column
self.tokenizer = tokenizer or Tokenizer()
self.max_seq_len = max_seq_len or 2 ** 30 # large number
self.max_byte_len = max_byte_len
self.offset_for_non_padding = offset_for_non_padding

@property
def column_schema(self):
return [(self.text_column, str)]

def numberize(self, row):
"""Convert text to bytes, pad batch."""
tokens = self.tokenizer.tokenize(row[self.text_column])[: self.max_seq_len]
@@ -437,7 +450,6 @@ def __init__(
pad_in_vocab: bool = False,
label_vocab: Optional[List[str]] = None,
):
super().__init__([(label_column, str)])
self.label_column = label_column
self.pad_in_vocab = pad_in_vocab
self.vocab_builder = VocabBuilder()
@@ -449,6 +461,10 @@ def __init__(
self.vocab_builder.add_all(label_vocab)
self.vocab, self.pad_idx = self._create_vocab()

@property
def column_schema(self):
return [(self.label_column, str)]

def initialize(self):
"""
Look through the dataset for all labels and create a vocab map for them.
@@ -487,7 +503,10 @@ class LabelListTensorizer(LabelTensorizer):

def __init__(self, label_column: str = "label", *args, **kwargs):
super().__init__(label_column, *args, **kwargs)
self.column_schema = [(label_column, List[str])]

@property
def column_schema(self):
return [(self.label_column, List[str])]

def numberize(self, row):
labels = super().numberize(row)
@@ -537,17 +556,19 @@ def __init__(
labels_column: str = "target_labels",
):
super().__init__(label_column, allow_unknown, pad_in_vocab, label_vocab)
column_schema = [
(label_column, str),
(probs_column, List[float]),
(logits_column, List[float]),
(labels_column, List[str]),
]
Tensorizer.__init__(self, column_schema)
self.probs_column = probs_column
self.logits_column = logits_column
self.labels_column = labels_column

@property
def column_schema(self):
return [
(self.label_column, str),
(self.probs_column, List[float]),
(self.logits_column, List[float]),
(self.labels_column, List[str]),
]

def numberize(self, row):
"""Numberize hard and soft labels"""
label = self.vocab.lookup_all(row[self.label_column])
@@ -584,13 +605,16 @@ def __init__(
label_column: str = Config.column,
rescale_range: Optional[List[float]] = Config.rescale_range,
):
super().__init__([(label_column, str)])
self.label_column = label_column
if rescale_range is not None:
assert len(rescale_range) == 2
assert rescale_range[0] < rescale_range[1]
self.rescale_range = rescale_range

@property
def column_schema(self):
return [(self.label_column, str)]

def numberize(self, row):
"""Numberize labels."""
label = float(row[self.label_column])
@@ -618,33 +642,22 @@ def from_config(cls, config: Config):
return cls(config.column, config.error_check, config.dim)

def __init__(self, column: str, error_check: bool, dim: Optional[int]):
super().__init__([(column, str)])
self.column = column
self.error_check = error_check
self.dim = dim
assert not self.error_check or self.dim is not None, "Error check requires dim"

@property
def column_schema(self):
return [(self.column, List[float])]

def numberize(self, row):
str = row[self.column]
# replace spaces between float numbers with commas (regex101.com/r/C2705x/1)
str = re.sub(r"(?<=[\d.])\s*,?\s+(?=[+-]?[\d.])", ",", str)
# remove dot not followed with a digit (regex101.com/r/goSmuG/1/)
str = re.sub(r"(?<=\d)\.(?![\d])", "", str)
try:
res = json.loads(str)
except json.decoder.JSONDecodeError as e:
raise Exception(
f"Unable to parse dense feature:{row[self.column]}," f" re output:{str}"
) from e
if type(res) is not list:
raise ValueError(f"{res} is not a valid float list")
dense = row[self.column]
if self.error_check:
assert len(res) == self.dim, (
f"Expected dimension:{self.dim}"
f", got:{len(res)}"
f", dense-feature:{res}"
)
return [float(n) for n in res]
assert (
len(dense) == self.dim
), f"Dense feature didn't match expected dimension {self.dim}: {dense}"
return dense

def tensorize(self, batch):
return pad_and_tensorize(batch, dtype=torch.float)
@@ -684,13 +697,16 @@ def __init__(
tokenizer: Tokenizer = None,
allow_unknown: bool = Config.allow_unknown,
):
super().__init__([(text_column, str), (slot_column, List[Slot])])
self.slot_column = slot_column
self.text_column = text_column
self.allow_unknown = allow_unknown
self.tokenizer = tokenizer or Tokenizer()
self.pad_idx = Padding.DEFAULT_LABEL_PAD_IDX

@property
def column_schema(self):
return [(self.text_column, str), (self.slot_column, List[Slot])]

def initialize(self):
"""Look through the dataset for all labels and create a vocab map for them."""
builder = VocabBuilder()
@@ -788,11 +804,14 @@ def __init__(
dict_column: str = Config.dict_column,
tokenizer: Tokenizer = None,
):
super().__init__([(text_column, str), (dict_column, Gazetteer)])
self.text_column = text_column
self.dict_column = dict_column
self.tokenizer = tokenizer or Tokenizer()

@property
def column_schema(self):
return [(self.text_column, str), (self.dict_column, Gazetteer)]

def initialize(self):
"""
Look through the dataset for all dict features to create vocab.
@@ -961,7 +980,6 @@ def __init__(
max_seq_len=Config.max_seq_len,
vocab=None,
):
super().__init__([(column, List[str])])
self.column = column
self.tokenizer = tokenizer or Tokenizer()
self.vocab = vocab
@@ -973,6 +991,10 @@ def __init__(
self.use_eol_token_for_bol = use_eol_token_for_bol
self.max_seq_len = max_seq_len or 2 ** 30 # large number

@property
def column_schema(self):
return [(self.column, List[str])]

def initialize(self, vocab_builder=None):
"""Build vocabulary based on training corpus."""
if self.vocab:
@@ -1063,10 +1085,13 @@ def from_config(cls, config: Config):
return cls(column=config.column)

def __init__(self, column: str = Config.column, vocab=None):
super().__init__([(column, str)])
self.column = column
self.vocab = vocab

@property
def column_schema(self):
return [(self.column, List[str])]

def initialize(self, vocab_builder=None):
"""Build vocabulary based on training corpus."""
if self.vocab:
@@ -1116,7 +1141,6 @@ def from_config(cls, config: Config):
return cls(config.names, config.indexes)

def __init__(self, names: List[str], indexes: List[int]):
super().__init__([])
self.names = names
self.indexes = indexes

@@ -1152,9 +1176,12 @@ def from_config(cls, config: Config):
return cls(config.column)

def __init__(self, column: str):
super().__init__([(column, float)])
self.column = column

@property
def column_schema(self):
return [(self.column, float)]

def numberize(self, row):
return row[self.column]

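The recurring pattern above, in which each tensorizer's super().__init__([(column, type)]) call is replaced by a column_schema property, follows from the docstring on the base Tensorizer: parameterized generics such as List[str] do not pickle before Python 3.7, so they must not be stored on instances that are pickled into snapshots. A minimal sketch of the pattern (SketchTensorizer is a hypothetical class):

import pickle
from typing import List

class SketchTensorizer:
    """Keeps only plain, picklable attributes on the instance."""

    def __init__(self, column):
        self.column = column

    @property
    def column_schema(self):
        # The generic type List[str] is rebuilt on each access rather than
        # stored, so pickling the instance never has to serialize it.
        return [(self.column, List[str])]

# Round-trips through pickle even on Python < 3.7, where pickling
# List[str] itself would fail.
restored = pickle.loads(pickle.dumps(SketchTensorizer("labels")))
assert restored.column_schema == [("labels", List[str])]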
…
