Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Add option to pad missing label in LabelListTensorizer (#1269)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1269

Unlike with the single-label LabelTensorizer, labels are much more likely to be missing in a sequence of labels. This change adds an option to the tensorizer to pad these missing labels.

Reviewed By: hudeven

Differential Revision: D20187694

fbshipit-source-id: a38f01caeff8a591ace39a3121d8c31326576e38
  • Loading branch information
seayoung1112 authored and facebook-github-bot committed Mar 6, 2020
1 parent 06c11e1 commit 8b99950
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
41 changes: 36 additions & 5 deletions pytext/data/tensorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,18 +815,49 @@ class LabelListTensorizer(LabelTensorizer):
"""

class Config(LabelTensorizer.Config):
    # Pad missing labels in the list; a label counts as missing when it is
    # None or the empty string.
    pad_missing: bool = True

@classmethod
def from_config(cls, config: Config):
    """Build a tensorizer from ``config``.

    The arguments are passed positionally, so their order must match
    ``__init__``: ``pad_missing`` first, followed by the values forwarded
    to the parent initializer.
    NOTE(review): the order of the forwarded values is assumed to match the
    parent initializer's parameters -- confirm against LabelTensorizer.
    """
    return cls(
        config.pad_missing,
        config.column,
        config.allow_unknown,
        config.pad_in_vocab,
        config.label_vocab,
        config.is_input,
    )

def __init__(self, pad_missing: bool = True, *args, **kwargs):
    """Initialize the tensorizer.

    Args:
        pad_missing: if true, missing labels (None or empty string) are
            numberized to the padding index; if false, they raise.
        *args: forwarded unchanged to the parent initializer.
        **kwargs: forwarded unchanged to the parent initializer.
    """
    if isinstance(pad_missing, str):
        # Backward compatibility: the label column name used to be the first
        # positional argument. A string here is therefore a legacy
        # ``LabelListTensorizer("column", ...)`` call, not a padding flag;
        # without this shift the column name would be silently dropped.
        args = (pad_missing, *args)
        pad_missing = True
    super().__init__(*args, **kwargs)
    self.pad_missing = pad_missing

def __setstate__(self, newstate):
    """Restore instance state from ``newstate``.

    State saved before ``pad_missing`` existed lacks that key, so it is
    filled in with True before the state is applied.
    """
    # for backward compatibility
    newstate.setdefault("pad_missing", True)
    self.__dict__.update(newstate)

@property
def column_schema(self):
    """Return the schema of the consumed column: the label column, typed List[str]."""
    label_list_type = List[str]
    return [(self.label_column, label_list_type)]

def numberize(self, row):
    """Convert the row's list of labels into a list of vocabulary indices.

    Args:
        row: mapping that holds the list of labels under
            ``self.label_column``.

    Returns:
        A tuple ``(label_idx_list, length)`` where ``length`` is the number
        of entries in ``label_idx_list``.

    Raises:
        ValueError: if a label is None or empty while ``pad_missing`` is
            disabled.
    """
    label_idx_list = []
    for label in row[self.label_column]:
        # Only None and the empty string count as missing data; values such
        # as "False" are legitimate labels and must go through the vocab.
        if label in (None, ""):
            if not self.pad_missing:
                raise ValueError(
                    "Found none or empty value in the list,"
                    + " while pad_missing is disabled"
                )
            label_idx_list.append(self.pad_idx)
        else:
            label_idx_list.append(self.vocab.lookup_all(label))
    return label_idx_list, len(label_idx_list)

def tensorize(self, batch):
labels, labels_len = zip(*batch)
Expand Down
29 changes: 29 additions & 0 deletions pytext/data/test/tensorizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List

import numpy as np
import pandas as pd
import torch
from pytext.data.bert_tensorizer import BERTTensorizer, BERTTensorizerScriptImpl
from pytext.data.roberta_tensorizer import (
Expand All @@ -13,6 +14,7 @@
)
from pytext.data.sources import SquadDataSource
from pytext.data.sources.data_source import Gazetteer, SafeFileWrapper, load_float_list
from pytext.data.sources.pandas import SessionPandasDataSource
from pytext.data.sources.tsv import SessionTSVDataSource, TSVDataSource
from pytext.data.squad_for_bert_tensorizer import (
SquadForBERTTensorizer,
Expand Down Expand Up @@ -127,6 +129,33 @@ def test_label_list_tensors_no_pad_in_vocab(self):
tensors.detach().numpy(),
)

def test_label_list_tensors_pad_missing(self):
    """Missing labels (None / empty) are padded, or rejected when padding is off."""
    ds = SessionPandasDataSource(
        test_df=pd.DataFrame(
            # test None and empty case
            {
                "session_id": [1, 1, 1, 1],
                "label": ["positive", "negative", None, ""],
            }
        ),
        schema={"label": List[str]},
        id_col="session_id",
    )
    tensorizers = {
        "label": LabelListTensorizer(
            label_column="label", pad_in_vocab=False, allow_unknown=False
        )
    }
    initialize_tensorizers(tensorizers, ds.test)
    # None and "" must not be counted as vocabulary entries.
    self.assertEqual(2, len(tensorizers["label"].vocab))
    # only one row in test data
    label_idx_list, lens = tensorizers["label"].numberize(next(iter(ds.test)))
    self.assertEqual([0, 1, -1, -1], label_idx_list)
    # The reported length covers the padded entries as well.
    self.assertEqual(4, lens)

    # With padding disabled, the same row must be rejected.
    tensorizers["label"].pad_missing = False
    with self.assertRaises(Exception):
        tensorizers["label"].numberize(next(iter(ds.test)))


# fmt: off
EXPECTED_ACTIONS = [
Expand Down
4 changes: 3 additions & 1 deletion pytext/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def add_all(self, values) -> None:
for value in values:
self.add_all(value)
else:
self.add(values)
# Don't add None or empty
if values:
self.add(values)

def add(self, value) -> None:
"""Count a single value in the vocabulary."""
Expand Down

0 comments on commit 8b99950

Please sign in to comment.