|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
| 3 | + |
| 4 | +from typing import Optional |
| 5 | + |
| 6 | +from pandas import DataFrame |
| 7 | +from pytext.data.sources.data_source import RootDataSource |
| 8 | + |
| 9 | + |
| 10 | +class PandasDataSource(RootDataSource): |
| 11 | + """ |
| 12 | + DataSource which loads data from a pandas DataFrame. |
| 13 | +
|
| 14 | + Inputs: |
| 15 | +
|
| 16 | + train_df: DataFrame for training |
| 17 | + eval_df: DataFrame for evalu |
| 18 | + test_df: DataFrame for test |
| 19 | + schema: same as base DataSource, define the list of output values with their types |
| 20 | + column_mapping: maps the column names in DataFrame to the name defined in schema |
| 21 | +
|
| 22 | + Example: |
| 23 | +
|
| 24 | + ds = PandasDataSource( |
| 25 | + train_df=pd.DataFrame({"c1": [10, 20, 30], "c2": [40, 50, 60]}), |
| 26 | + eval_df=pd.DataFrame({"c1": [11, 21, 31], "c2": [41, 51, 61]}), |
| 27 | + test_df=pd.DataFrame({"c1": [12, 22, 32], "c2": [42, 52, 62]}), |
| 28 | + schema={"feature1": float, "feature2": float}, |
| 29 | + column_mapping={"c1": "feature1", "c2": "feature2"}, |
| 30 | + ) |
| 31 | +
|
| 32 | + for row in ds.train: |
| 33 | + print(row) |
| 34 | +
|
| 35 | + will print out: |
| 36 | +
|
| 37 | + {"feature1": 10, "feature2": 40} |
| 38 | + {"feature1": 20, "feature2": 50} |
| 39 | + {"feature1": 30, "feature2": 60} |
| 40 | +
|
| 41 | + """ |
| 42 | + |
| 43 | + def __init__( |
| 44 | + self, |
| 45 | + train_df: Optional[DataFrame] = None, |
| 46 | + eval_df: Optional[DataFrame] = None, |
| 47 | + test_df: Optional[DataFrame] = None, |
| 48 | + **kwargs |
| 49 | + ): |
| 50 | + super().__init__(**kwargs) |
| 51 | + self.train_df = train_df |
| 52 | + self.eval_df = eval_df |
| 53 | + self.test_df = test_df |
| 54 | + |
| 55 | + @staticmethod |
| 56 | + def raw_generator(df: Optional[DataFrame]): |
| 57 | + if df is None: |
| 58 | + yield from () |
| 59 | + else: |
| 60 | + for _, row in df.iterrows(): |
| 61 | + yield row |
| 62 | + |
| 63 | + def raw_train_data_generator(self): |
| 64 | + return self.raw_generator(self.train_df) |
| 65 | + |
| 66 | + def raw_eval_data_generator(self): |
| 67 | + return self.raw_generator(self.eval_df) |
| 68 | + |
| 69 | + def raw_test_data_generator(self): |
| 70 | + return self.raw_generator(self.test_df) |
0 commit comments