Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Added PandasDataSource #1098

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ pytorch-pretrained-bert
requests
torchtext
tensorboard==1.14
pandas
9 changes: 8 additions & 1 deletion pytext/data/sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from .data_source import DataSource, RawExample
from .pandas import PandasDataSource
from .squad import SquadDataSource
from .tsv import TSVDataSource


# Public API of the data sources package. Declared once: an earlier,
# shorter assignment that omitted PandasDataSource was immediately
# overwritten by this one and has been dropped as dead code.
__all__ = [
    "DataSource",
    "RawExample",
    "SquadDataSource",
    "TSVDataSource",
    "PandasDataSource",
]
54 changes: 54 additions & 0 deletions pytext/data/sources/pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Optional

from pandas import DataFrame
from pytext.data.sources.data_source import RootDataSource


class PandasDataSource(RootDataSource):
    """
    DataSource which loads data from pandas DataFrames.

    Inputs:
        train_df: DataFrame for training

        eval_df: DataFrame for evaluation

        test_df: DataFrame for test

        schema: same as base DataSource, define the list of output values
            with their types

        column_mapping: maps the column names in DataFrame to the name
            defined in schema

    Any of the three DataFrames may be omitted (None); the matching raw
    generator then produces no rows. All other keyword arguments
    (e.g. schema, column_mapping) are forwarded unchanged to the parent
    constructor.
    """

    def __init__(
        self,
        train_df: Optional[DataFrame] = None,
        eval_df: Optional[DataFrame] = None,
        test_df: Optional[DataFrame] = None,
        **kwargs
    ):
        # The parent constructor consumes everything that is not a DataFrame.
        super().__init__(**kwargs)
        self.train_df, self.eval_df, self.test_df = train_df, eval_df, test_df

    @staticmethod
    def raw_generator(df: Optional[DataFrame]):
        """Lazily yield each row of ``df``; yield nothing when ``df`` is None."""
        if df is None:
            return
        # iterrows() produces (index, row) pairs; only the row is passed on.
        # NOTE(review): per the pandas docs each row arrives as a Series, so
        # values in mixed-dtype rows may be upcast to a common dtype.
        yield from (record for _index, record in df.iterrows())

    def raw_train_data_generator(self):
        """Return a generator over the rows of the training DataFrame."""
        return self.raw_generator(self.train_df)

    def raw_eval_data_generator(self):
        """Return a generator over the rows of the evaluation DataFrame."""
        return self.raw_generator(self.eval_df)

    def raw_test_data_generator(self):
        """Return a generator over the rows of the test DataFrame."""
        return self.raw_generator(self.test_df)
26 changes: 26 additions & 0 deletions pytext/data/test/pandas_data_source_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import pandas as pd
from pytext.data.sources import PandasDataSource


class PandasDataSourceTest(unittest.TestCase):
    """Unit tests for PandasDataSource."""

    def test_create_data_source(self):
        """The first example of each split uses the schema names from column_mapping."""
        train_frame = pd.DataFrame({"c1": [10, 20, 30], "c2": [40, 50, 60]})
        eval_frame = pd.DataFrame({"c1": [11, 21, 31], "c2": [41, 51, 61]})
        test_frame = pd.DataFrame({"c1": [12, 22, 32], "c2": [42, 52, 62]})
        source = PandasDataSource(
            train_df=train_frame,
            eval_df=eval_frame,
            test_df=test_frame,
            schema={"feature1": float, "feature2": float},
            column_mapping={"c1": "feature1", "c2": "feature2"},
        )

        first_train = next(iter(source.train))
        self.assertEqual({"feature1": 10, "feature2": 40}, first_train)

        first_eval = next(iter(source.eval))
        self.assertEqual({"feature1": 11, "feature2": 41}, first_eval)

        first_test = next(iter(source.test))
        self.assertEqual({"feature1": 12, "feature2": 42}, first_test)

        # Iterating the training split again must still produce every row.
        train_examples = list(source.train)
        self.assertEqual(3, len(train_examples))

    def test_empty_data(self):
        """A source built without DataFrames yields no training examples."""
        source = PandasDataSource(schema={"feature1": float, "feature2": float})
        train_examples = list(source.train)
        self.assertEqual(0, len(train_examples))