Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Added PandasDataSource (#1098)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1098

A DataSource implementation that is a simple wrapper around a pandas DataFrame.

Reviewed By: chenyangyu1988

Differential Revision: D18266156

fbshipit-source-id: 9a6d0cb7663993e23e86bd9694da7382d76cc02f
  • Loading branch information
seayoung1112 authored and facebook-github-bot committed Nov 2, 2019
1 parent 8008a28 commit 7d78090
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ pytorch-pretrained-bert
requests
torchtext
tensorboard==1.14
pandas
9 changes: 8 additions & 1 deletion pytext/data/sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from .data_source import DataSource, RawExample
from .pandas import PandasDataSource
from .squad import SquadDataSource
from .tsv import TSVDataSource


__all__ = ["DataSource", "RawExample", "SquadDataSource", "TSVDataSource"]
# Names this package exports, including via `from pytext.data.sources import *`.
__all__ = [
"DataSource",
"RawExample",
"SquadDataSource",
"TSVDataSource",
"PandasDataSource",
]
54 changes: 54 additions & 0 deletions pytext/data/sources/pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Optional

from pandas import DataFrame
from pytext.data.sources.data_source import RootDataSource


class PandasDataSource(RootDataSource):
    """
    Root data source backed by in-memory pandas DataFrames.

    Each split (train / eval / test) is read from its own optional DataFrame.
    A split whose DataFrame is None produces no rows.

    Inputs:
        train_df: DataFrame for training
        eval_df: DataFrame for evaluation
        test_df: DataFrame for test
        schema: same as base DataSource, define the list of output values
            with their types
        column_mapping: maps the column names in DataFrame to the name
            defined in schema
    """

    def __init__(
        self,
        train_df: Optional[DataFrame] = None,
        eval_df: Optional[DataFrame] = None,
        test_df: Optional[DataFrame] = None,
        **kwargs
    ):
        # All remaining keyword arguments are forwarded unchanged to the
        # base class constructor.
        super().__init__(**kwargs)
        self.train_df, self.eval_df, self.test_df = train_df, eval_df, test_df

    @staticmethod
    def raw_generator(df: Optional[DataFrame]):
        """Lazily yield each row of ``df``; yield nothing when ``df`` is None."""
        if df is not None:
            # iterrows() produces (index, row) pairs; only the row is passed on.
            yield from (record for _index, record in df.iterrows())

    def raw_train_data_generator(self):
        """Return a generator over the rows of the training DataFrame."""
        return self.raw_generator(self.train_df)

    def raw_eval_data_generator(self):
        """Return a generator over the rows of the evaluation DataFrame."""
        return self.raw_generator(self.eval_df)

    def raw_test_data_generator(self):
        """Return a generator over the rows of the test DataFrame."""
        return self.raw_generator(self.test_df)
26 changes: 26 additions & 0 deletions pytext/data/test/pandas_data_source_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import pandas as pd
from pytext.data.sources import PandasDataSource


class PandasDataSourceTest(unittest.TestCase):
    """Unit tests for PandasDataSource."""

    def test_create_data_source(self):
        """Rows from each split come back keyed by the mapped schema names."""
        frames = {
            "train_df": pd.DataFrame({"c1": [10, 20, 30], "c2": [40, 50, 60]}),
            "eval_df": pd.DataFrame({"c1": [11, 21, 31], "c2": [41, 51, 61]}),
            "test_df": pd.DataFrame({"c1": [12, 22, 32], "c2": [42, 52, 62]}),
        }
        ds = PandasDataSource(
            schema={"feature1": float, "feature2": float},
            column_mapping={"c1": "feature1", "c2": "feature2"},
            **frames
        )

        # The first example of every split must carry the renamed columns.
        first_train = next(iter(ds.train))
        self.assertEqual({"feature1": 10, "feature2": 40}, first_train)
        first_eval = next(iter(ds.eval))
        self.assertEqual({"feature1": 11, "feature2": 41}, first_eval)
        first_test = next(iter(ds.test))
        self.assertEqual({"feature1": 12, "feature2": 42}, first_test)

        # Iterating the training split again yields all three rows.
        train_rows = list(ds.train)
        self.assertEqual(3, len(train_rows))

    def test_empty_data(self):
        """A data source built without DataFrames yields no training rows."""
        ds = PandasDataSource(schema={"feature1": float, "feature2": float})
        train_rows = list(ds.train)
        self.assertEqual(0, len(train_rows))

0 comments on commit 7d78090

Please sign in to comment.