Skip to content

Commit a3ec504

Browse files
seayoung1112facebook-github-bot
authored andcommitted
Added PandasDataSource (facebookresearch#1098)
Summary: Pull Request resolved: facebookresearch#1098 DataSource implementation which is a simple wrapper of Pandas DataFrame Reviewed By: chenyangyu1988 Differential Revision: D18266156 fbshipit-source-id: 3a7bce83ec7aa07309074a44eb44d97dd613408a
1 parent 8008a28 commit a3ec504

File tree

4 files changed

+105
-1
lines changed

4 files changed

+105
-1
lines changed

docs_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ pytorch-pretrained-bert
1010
requests
1111
torchtext
1212
tensorboard==1.14
13+
pandas

pytext/data/sources/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,15 @@
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
33

44
from .data_source import DataSource, RawExample
5+
from .pandas import PandasDataSource
56
from .squad import SquadDataSource
67
from .tsv import TSVDataSource
78

89

9-
__all__ = ["DataSource", "RawExample", "SquadDataSource", "TSVDataSource"]
10+
__all__ = [
11+
"DataSource",
12+
"RawExample",
13+
"SquadDataSource",
14+
"TSVDataSource",
15+
"PandasDataSource",
16+
]

pytext/data/sources/pandas.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3+
4+
from typing import Optional
5+
6+
from pandas import DataFrame
7+
from pytext.data.sources.data_source import RootDataSource
8+
9+
10+
class PandasDataSource(RootDataSource):
11+
"""
12+
DataSource which loads data from a pandas DataFrame.
13+
14+
Inputs:
15+
16+
train_df: DataFrame for training
17+
eval_df: DataFrame for evalu
18+
test_df: DataFrame for test
19+
schema: same as base DataSource, define the list of output values with their types
20+
column_mapping: maps the column names in DataFrame to the name defined in schema
21+
22+
Example:
23+
24+
ds = PandasDataSource(
25+
train_df=pd.DataFrame({"c1": [10, 20, 30], "c2": [40, 50, 60]}),
26+
eval_df=pd.DataFrame({"c1": [11, 21, 31], "c2": [41, 51, 61]}),
27+
test_df=pd.DataFrame({"c1": [12, 22, 32], "c2": [42, 52, 62]}),
28+
schema={"feature1": float, "feature2": float},
29+
column_mapping={"c1": "feature1", "c2": "feature2"},
30+
)
31+
32+
for row in ds.train:
33+
print(row)
34+
35+
will print out:
36+
37+
{"feature1": 10, "feature2": 40}
38+
{"feature1": 20, "feature2": 50}
39+
{"feature1": 30, "feature2": 60}
40+
41+
"""
42+
43+
def __init__(
44+
self,
45+
train_df: Optional[DataFrame] = None,
46+
eval_df: Optional[DataFrame] = None,
47+
test_df: Optional[DataFrame] = None,
48+
**kwargs
49+
):
50+
super().__init__(**kwargs)
51+
self.train_df = train_df
52+
self.eval_df = eval_df
53+
self.test_df = test_df
54+
55+
@staticmethod
56+
def raw_generator(df: Optional[DataFrame]):
57+
if df is None:
58+
yield from ()
59+
else:
60+
for _, row in df.iterrows():
61+
yield row
62+
63+
def raw_train_data_generator(self):
64+
return self.raw_generator(self.train_df)
65+
66+
def raw_eval_data_generator(self):
67+
return self.raw_generator(self.eval_df)
68+
69+
def raw_test_data_generator(self):
70+
return self.raw_generator(self.test_df)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3+
4+
import unittest
5+
6+
import pandas as pd
7+
from pytext.data.sources import PandasDataSource
8+
9+
10+
class PandasDataSourceTest(unittest.TestCase):
11+
def test_create_data_source(self):
12+
ds = PandasDataSource(
13+
train_df=pd.DataFrame({"c1": [10, 20, 30], "c2": [40, 50, 60]}),
14+
eval_df=pd.DataFrame({"c1": [11, 21, 31], "c2": [41, 51, 61]}),
15+
test_df=pd.DataFrame({"c1": [12, 22, 32], "c2": [42, 52, 62]}),
16+
schema={"feature1": float, "feature2": float},
17+
column_mapping={"c1": "feature1", "c2": "feature2"},
18+
)
19+
self.assertEqual({"feature1": 10, "feature2": 40}, next(iter(ds.train)))
20+
self.assertEqual({"feature1": 11, "feature2": 41}, next(iter(ds.eval)))
21+
self.assertEqual({"feature1": 12, "feature2": 42}, next(iter(ds.test)))
22+
self.assertEqual(3, len(list(ds.train)))
23+
24+
def test_empty_data(self):
25+
ds = PandasDataSource(schema={"feature1": float, "feature2": float})
26+
self.assertEqual(0, len(list(ds.train)))

0 commit comments

Comments
 (0)