This repository has been archived by the owner on Nov 22, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 799
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #1098. DataSource implementation which is a simple wrapper of a Pandas DataFrame. Reviewed By: chenyangyu1988. Differential Revision: D18266156. fbshipit-source-id: 9a6d0cb7663993e23e86bd9694da7382d76cc02f
- Loading branch information
1 parent
8008a28
commit 7d78090
Showing
4 changed files
with
89 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,4 @@ pytorch-pretrained-bert | |
requests | ||
torchtext | ||
tensorboard==1.14 | ||
pandas |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
|
||
from typing import Optional | ||
|
||
from pandas import DataFrame | ||
from pytext.data.sources.data_source import RootDataSource | ||
|
||
|
||
class PandasDataSource(RootDataSource):
    """
    DataSource backed by in-memory pandas DataFrames.

    Inputs:
        train_df: DataFrame holding the training rows (optional)
        eval_df: DataFrame holding the evaluation rows (optional)
        test_df: DataFrame holding the test rows (optional)
        schema: same as base DataSource, defines the list of output values
            with their types
        column_mapping: maps the column names in the DataFrame to the names
            defined in schema

    Every keyword argument other than the three DataFrames (schema and
    column_mapping included) is forwarded unchanged to RootDataSource.
    """

    def __init__(
        self,
        train_df: Optional[DataFrame] = None,
        eval_df: Optional[DataFrame] = None,
        test_df: Optional[DataFrame] = None,
        **kwargs
    ):
        # The base class consumes everything that is not a DataFrame argument.
        super().__init__(**kwargs)
        self.train_df, self.eval_df, self.test_df = train_df, eval_df, test_df

    @staticmethod
    def raw_generator(df: Optional[DataFrame]):
        """
        Yield the rows of ``df`` one at a time, as produced by
        ``DataFrame.iterrows`` (the index label is discarded).
        Yields nothing when ``df`` is None.
        """
        if df is None:
            # A bare return inside a generator gives an empty iterator.
            return
        for _index, record in df.iterrows():
            yield record

    def raw_train_data_generator(self):
        """Return a generator over the rows of the training DataFrame."""
        frame = self.train_df
        return self.raw_generator(frame)

    def raw_eval_data_generator(self):
        """Return a generator over the rows of the evaluation DataFrame."""
        frame = self.eval_df
        return self.raw_generator(frame)

    def raw_test_data_generator(self):
        """Return a generator over the rows of the test DataFrame."""
        frame = self.test_df
        return self.raw_generator(frame)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
|
||
import unittest | ||
|
||
import pandas as pd | ||
from pytext.data.sources import PandasDataSource | ||
|
||
|
||
class PandasDataSourceTest(unittest.TestCase):
    """Unit tests for PandasDataSource."""

    def test_create_data_source(self):
        """Rows of each split come back renamed through column_mapping."""
        train_frame = pd.DataFrame({"c1": [10, 20, 30], "c2": [40, 50, 60]})
        eval_frame = pd.DataFrame({"c1": [11, 21, 31], "c2": [41, 51, 61]})
        test_frame = pd.DataFrame({"c1": [12, 22, 32], "c2": [42, 52, 62]})
        ds = PandasDataSource(
            train_df=train_frame,
            eval_df=eval_frame,
            test_df=test_frame,
            schema={"feature1": float, "feature2": float},
            column_mapping={"c1": "feature1", "c2": "feature2"},
        )

        # First row of every split, keyed by the schema names.
        first_train = next(iter(ds.train))
        self.assertEqual({"feature1": 10, "feature2": 40}, first_train)
        first_eval = next(iter(ds.eval))
        self.assertEqual({"feature1": 11, "feature2": 41}, first_eval)
        first_test = next(iter(ds.test))
        self.assertEqual({"feature1": 12, "feature2": 42}, first_test)

        # The training split exposes all three rows.
        train_rows = list(ds.train)
        self.assertEqual(3, len(train_rows))

    def test_empty_data(self):
        """A data source built without DataFrames yields no training rows."""
        ds = PandasDataSource(schema={"feature1": float, "feature2": float})
        train_rows = list(ds.train)
        self.assertEqual(0, len(train_rows))