Id column converter #63

Open · wants to merge 9 commits into master
58 changes: 50 additions & 8 deletions partridge/gtfs.py
@@ -1,6 +1,7 @@
import os
from threading import RLock
from typing import Dict, Optional, Union
from warnings import warn

import networkx as nx
import numpy as np
@@ -34,9 +35,13 @@ def __init__(
self._locks: Dict[str, RLock] = {}
if isinstance(source, self.__class__):
self._read = source.get
self._proxy_feed = bool(self._view)
Contributor: I would prefer the value of _proxy_feed not to depend on whether the feed is initialized from a path or another feed object. Is that possible?

Author: Yes, it could probably just be bool(self._view) outside of the if block.

Author: I tried that and it didn't work. Do you prefer passing proxy as a parameter to Feed.__init__?

elif isinstance(source, str) and os.path.isdir(source):
self._read = self._read_csv
self._bootstrap(source)
self._proxy_feed = True
# Validate the configuration and raise a warning if needed
self._validate_dependencies_conversion()
else:
raise ValueError("Invalid source")

@@ -46,11 +51,15 @@ def get(self, filename: str) -> pd.DataFrame:
df = self._cache.get(filename)
if df is None:
df = self._read(filename)
df = self._filter(filename, df)
df = self._prune(filename, df)
self._convert_types(filename, df)
df = df.reset_index(drop=True)
df = self._transform(filename, df)
if self._proxy_feed:
Contributor: I'm not sure the choice should be to filter+prune OR convert+transform. Can you tell me a bit about how you are thinking about this behavior? I will need to think through the logic.

Author: As I see it, for each table you filter you create a feed, and each feed is the source of the next one. Except for the last layer, the feeds (the proxy feeds) are only responsible for filtering the table according to the filter and the already-filtered tables below them (pruning). That's why in those proxy feeds you only need to filter and prune; before my change you did that by removing the transformations and converters from the configuration. The last feed layer doesn't need to deal with pruning and filtering, since it doesn't even get a view as a parameter and the lower-level feeds are doing the pruning already.

Tell me if I missed something. (A sketch of this layering appears just after the hunk below.)

# files feed responsible for file access
df = self._filter(filename, df)
df = self._prune(filename, df)
df = df.reset_index(drop=True)
else:
# proxy feed responsible for data conversion
self._convert_types(filename, df)
df = self._transform(filename, df)
self.set(filename, df)
return self._cache[filename]
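
Not part of the diff, just an illustration: a minimal sketch of the feed chain the thread above describes, mirroring the _load_feed loop in partridge/readers.py further down. The import paths and the "gtfs_dir" path are assumptions.

    from partridge.config import default_config   # import path assumed
    from partridge.gtfs import Feed
    from partridge.utilities import reroot_graph  # import path assumed

    config = default_config()

    # Layer 0: path-based feed; _proxy_feed is True, so get() only
    # filters and prunes (file access plus view filtering).
    files_feed = Feed("gtfs_dir", view={}, config=config)

    # Middle layers, one per filtered file: a non-empty view also makes
    # _proxy_feed True, and pruning walks the graph rerooted at that file.
    trips_layer = Feed(
        files_feed,
        view={"trips.txt": {"route_id": "1"}},
        config=reroot_graph(config, "trips.txt"),
    )

    # Final layer: no view, so _proxy_feed is False and get() only applies
    # the converters and transformations to the already-filtered tables.
    feed = Feed(trips_layer, config=config)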

@@ -95,7 +104,7 @@ def _read_csv(self, filename: str) -> pd.DataFrame:
# DataFrame containing any required columns.
return empty_df(columns)

# If the file isn't in the zip, return an empty DataFrame.
# Read file encoding
with open(path, "rb") as f:
encoding = detect_encoding(f)

@@ -121,7 +130,6 @@ def _filter(self, filename: str, df: pd.DataFrame) -> pd.DataFrame:
# If applicable, filter this dataframe by the given set of values
if col in df.columns:
df = df[df[col].isin(setwrap(values))]

return df

def _prune(self, filename: str, df: pd.DataFrame) -> pd.DataFrame:
@@ -147,10 +155,44 @@ def _prune(self, filename: str, df: pd.DataFrame) -> pd.DataFrame:
depcol = deps[depfile]
# If applicable, prune this dataframe by the other
if col in df.columns and depcol in depdf.columns:
df = df[df[col].isin(depdf[depcol])]
converter = self._get_convert_function(filename, col)
# Convert the column before pruning since depdf is already converted
col_series = converter(df[col]) if converter else df[col]
df = df[col_series.isin(depdf[depcol])]

return df
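
A quick standalone illustration of why the conversion must happen before the membership test (plain pandas, not partridge code): depdf comes from a feed whose columns were already converted, and pandas' isin matches nothing when the dtypes disagree.

    import pandas as pd

    raw = pd.Series(["1", "2", "3"])   # this file's column, still strings
    dep = pd.Series([1, 2, 3])         # dependency column, already numeric

    print(raw.isin(dep).tolist())                  # [False, False, False]
    print(pd.to_numeric(raw).isin(dep).tolist())   # [True, True, True]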

def _get_convert_function(self, filename, colname):
"""Return the converter function from the config
for a specific file and column."""
return self._config.nodes.get(filename, {}).get("converters", {}).get(colname)
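
For reference, the nested shape this helper reads: the config is a networkx DiGraph whose node attributes may carry a "converters" dict, as the test at the bottom of this PR also shows. A small illustrative example, not code from the diff:

    import networkx as nx
    import pandas as pd

    config = nx.DiGraph()
    config.add_node("trips.txt", converters={"route_id": pd.to_numeric})

    # Equivalent to _get_convert_function("trips.txt", "route_id"):
    func = config.nodes.get("trips.txt", {}).get("converters", {}).get("route_id")
    assert func is pd.to_numeric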

def _validate_dependencies_conversion(self):
"""Validate that dependent columns in different files
have the same converter function, if one exists.
"""

def check_column_pair(column_pair: dict) -> bool:
assert len(column_pair) == 2
convert_funcs = [
self._get_convert_function(filename, colname)
for filename, colname in column_pair.items()
]
return convert_funcs[0] == convert_funcs[1]

for file_a, file_b, data in self._config.edges(data=True):
dependencies = data.get("dependencies", [])
for column_pair in dependencies:
if check_column_pair(column_pair):
continue
warn(
f"Converters Mismatch: column `{column_pair[file_a]}` in {file_a} "
f"is dependent on column `{column_pair[file_b]}` in {file_b} "
f"but converted with different functions, which might cause merging problems."
)

Contributor: Why produce a warning here as opposed to raising an exception?

Author: I thought that a mismatch might be intentional, for example int8 and int16, or something like that.
)

def _convert_types(self, filename: str, df: pd.DataFrame) -> None:
"""
Apply type conversions
7 changes: 2 additions & 5 deletions partridge/readers.py
@@ -13,8 +13,6 @@
from .gtfs import Feed
from .parsers import vparse_date
from .types import View
from .utilities import remove_node_attributes


DAY_NAMES = (
"monday",
@@ -105,10 +103,9 @@ def finalize() -> None:

def _load_feed(path: str, view: View, config: nx.DiGraph) -> Feed:
"""Multi-file feed filtering"""
config_ = remove_node_attributes(config, ["converters", "transformations"])
feed_ = Feed(path, view={}, config=config_)
feed_ = Feed(path, view={}, config=config)
for filename, column_filters in view.items():
config_ = reroot_graph(config_, filename)
config_ = reroot_graph(config, filename)
view_ = {filename: column_filters}
feed_ = Feed(feed_, view=view_, config=config_)
return Feed(feed_, config=config)
17 changes: 17 additions & 0 deletions tests/test_feed.py
@@ -1,4 +1,6 @@
import datetime

import pandas as pd
import pytest

import partridge as ptg
@@ -225,3 +227,18 @@ def test_filtered_columns(path):

assert set(feed_full.trips.columns) == set(feed_view.trips.columns)
assert set(feed_full.trips.columns) == set(feed_null.trips.columns)


@pytest.mark.parametrize("path", [fixture("amazon-2017-08-06")])
def test_converted_id_column(path):
conf = default_config()
conf.nodes["routes.txt"]["converters"]["route_id"] = pd.to_numeric
with pytest.warns(UserWarning, match="Converters Mismatch:"):
ptg.load_feed(path, config=conf)
conf.nodes["trips.txt"]["converters"]["route_id"] = pd.to_numeric
# Just to prevent another warning
conf.nodes["fare_rules.txt"]["converters"] = {}
conf.nodes["fare_rules.txt"]["converters"]["route_id"] = pd.to_numeric
feed = ptg.load_feed(path, config=conf)
assert len(feed.trips) > 0
assert len(feed.routes) > 0