Skip to content

Commit

Permalink
simplify ProviderDatabase API
Browse files Browse the repository at this point in the history
refactor individual loaders out,
search for one that works based on input.

similar approach to mds.api.auth types
  • Loading branch information
thekaveman committed May 9, 2019
1 parent aab897a commit 11a3748
Show file tree
Hide file tree
Showing 4 changed files with 351 additions and 223 deletions.
10 changes: 6 additions & 4 deletions mds/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AuthorizationToken():
@classmethod
can_auth(cls, provider): bool
return True if the auth type can be used on the provider.
Return True if the auth type can be used on the provider.
See OAuthClientCredentialsAuth for an example implementation.
"""
Expand Down Expand Up @@ -119,7 +119,9 @@ def auth_types():
"""
Return a list of all supported authentication types.
"""
types = AuthorizationToken.__subclasses__()
types.append(AuthorizationToken)
def all_subs(cls):
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subs(c)]
).union([cls])

return types
return all_subs(AuthorizationToken)
262 changes: 47 additions & 215 deletions mds/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
import sqlalchemy

from ..db import sql
from ..db import loaders, sql
from ..fake import util
from ..files import ProviderDataFiles
from ..providers import Provider
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(self, uri=None, **kwargs):
version: str, Version, optional
The MDS version to target. By default, Version.mds_lower().
Raises:
Raise:
UnsupportedVersionError
When an unsupported MDS version is specified.
"""
Expand All @@ -124,24 +124,44 @@ def __init__(self, uri=None, **kwargs):
def __repr__(self):
return f"<mds.db.ProviderDatabase ('{self.version}')>"

def load_from_df(self, df, record_type, table, **kwargs):
def load(self, source, record_type, table, **kwargs):
"""
Inserts MDS data from a DataFrame.
Load MDS data from a variety of file path or object sources.
Parameters:
df: DataFrame
Data of type record_type to insert.
source: dict, list, str, Path, pandas.DataFrame
The data source to load, which could be any of:
* an MDS payload dict:
{
"version": "x.y.z",
"data": {
"record_type": [
//records here
]
}
}
* a list of MDS payload dicts
* one or more MDS data records, e.g. payload["data"][record_type]
* one or more file paths to MDS payload JSON files
* a pandas.DataFrame containing MDS data records
record_type: str
The type of MDS data, e.g. status_changes or trips
record_type: str
The type of MDS data ("status_changes" or "trips").
table: str
The name of the database table to insert this data into.
before_load: callable(df=DataFrame, version=Version): DataFrame, optional
Callback executed on the incoming DataFrame and Version.
Callback executed on an incoming DataFrame and Version.
Should return the final DataFrame for loading.
on_conflict_update: tuple (condition: str, actions: list), optional
Generate an "ON CONFLICT condition DO UPDATE SET actions" statement.
Only applies when stage_first evaluates True.
stage_first: bool, int, optional
True (default) to stage data in a temp table before upserting to the final table.
False to load directly into the target table.
Expand All @@ -153,225 +173,41 @@ def load_from_df(self, df, record_type, table, **kwargs):
stages to a random temp table with 26*26*26 possible naming choices.
on_conflict_update: tuple (condition: str, actions: list), optional
Generate an "ON CONFLICT condition DO UPDATE SET actions" statement.
Only applies when stage_first evaluates True.
version: str, Version, optional
The MDS version to target.
Raises:
UnsupportedVersionError
When an unsupported MDS version is specified.
Return:
ProviderDataLoader
self
"""
version = Version(kwargs.get("version", self.version))
if version.unsupported:
raise UnsupportedVersionError(version)

before_load = kwargs.get("before_load", None)
stage_first = kwargs.get("stage_first", self.stage_first)
on_conflict_update = kwargs.get("on_conflict_update", None)

# run any pre-processors to transform the df
if before_load is not None:
new_df = before_load(df, version)
df = new_df if new_df is not None else df

if not stage_first:
# append the data to an existing table
df.to_sql(table, self.engine, if_exists="append", index=False)
else:
# insert this DataFrame into a fresh temp table
factor = stage_first if isinstance(stage_first, int) else 1
temp = f"{table}_tmp_{util.random_string(factor, chars=string.ascii_lowercase)}"
df.to_sql(temp, self.engine, if_exists="replace", index=False)

# now insert from the temp table to the actual table
with self.engine.begin() as conn:
if record_type == STATUS_CHANGES:
query = sql.insert_status_changes_from(temp, table, on_conflict_update=on_conflict_update, version=version)
elif record_type == TRIPS:
query = sql.insert_trips_from(temp, table, on_conflict_update=on_conflict_update, version=version)
if query is not None:
conn.execute(query)
# delete temp table (not a true TEMPORARY table)
conn.execute(f"DROP TABLE {temp}")
return self

def load_from_file(self, src, record_type, table, **kwargs):
"""
Load MDS data from a file source.
Parameters:
src: str
An mds.json.files_to_df() compatible JSON file path.
record_type: str
The type of MDS data, e.g. status_changes or trips
table: str
The name of the table to load data to.
version: str, Version, optional
The MDS version to target.
Additional keyword arguments are passed-through to load_from_df().
Raises:
UnexpectedVersionError
When data is parsed with a version different from what was expected.
UnsupportedVersionError
When an unsupported MDS version is specified.
Return:
ProviderDataLoader
self
"""
version = Version(kwargs.get("version", self.version))
if version.unsupported:
raise UnsupportedVersionError(version)

# read the data file
_version, df = ProviderDataFiles(src).load_dataframe(record_type)

if _version != version:
raise UnexpectedVersionError(_version, version)

return self.load_from_df(df, record_type, table, **kwargs)

def load_from_records(self, records, record_type, table, **kwargs):
"""
Load MDS data from a list of records.
Parameters:
records: list
A list of dicts of type record_type.
record_type: str
The type of MDS data, e.g. status_changes or trips
table: str
The name of the table to load data to.
Additional keyword arguments are passed-through to load_from_df().
Raises:
TypeError
When records is not a list of dicts.
Return:
ProviderDataLoader
self
"""
if isinstance(records, list) and len(records) > 0 and all([isinstance(d, dict) for d in records]):
df = pd.DataFrame.from_records(records)
self.load_from_df(df, record_type, table, **kwargs)
return self

raise TypeError(f"Unknown type for records: {type(records)}")

def load_from_source(self, source, record_type, table, **kwargs):
"""
Load MDS data from a variety of file path or object sources.
Parameters:
source: dict, list, str, Path
The data source to load, which could be any of:
- an MDS payload dict, e.g.
{
"version": "x.y.z",
"data": {
"record_type": [{
"device_id": "1",
...
},
{
"device_id": "2",
...
}]
}
}
- a list of MDS payloads
- a list of MDS data records, e.g. payload["data"][record_type]
- a [list of] MDS payload JSON file paths
record_type: str
The type of MDS data, e.g. status_changes or trips
table: str
The name of the table to load data to.
version: str, Version, optional
The MDS version to target.
Additional keyword arguments are passed-through to load_from_df().
The MDS version to target. By default, Version.mds_lower().
Raises:
Raise:
TypeError
When the type of source is not recognized.
UnexpectedVersionError
When data is parsed with a version different from what was expected.
When a loader for the type of source could not be found.
UnsupportedVersionError
When an unsupported MDS version is specified.
Return:
ProviderDataLoader
ProviderDatabase
self
"""
version = Version(kwargs.get("version", self.version))
if version.unsupported:
raise UnsupportedVersionError(version)

def _valid_path(p):
"""
Check for a valid path reference
"""
return (isinstance(p, str) and os.path.exists(p)) or (isinstance(p, Path) and p.exists())

# source is a single data page
if isinstance(source, dict) and "data" in source and record_type in source["data"]:
_version, records = Version(source["version"]), source["data"][record_type]
if _version != version:
raise UnexpectedVersionError(_version, version)
self.load_from_records(records, record_type, table, **kwargs)

# source is a list of data pages
elif isinstance(source, list) and all([isinstance(s, dict) and "data" in s for s in source]):
for page in source:
self.load_from_source(page, record_type, table, **kwargs)

# source is a list of file paths, load only the valid paths
elif isinstance(source, list) and any([_valid_path(p) for p in source]):
for path in [p for p in source if _valid_path(p)]:
self.load_from_source(path, record_type, table, **kwargs)
for loader in loaders.data_loaders():
if loader.can_load(source):
loader().load(source, record_type=record_type, table=table, engine=self.engine, **kwargs)
return self

# source is a single (valid) file path
elif _valid_path(source):
self.load_from_file(source, record_type, table, **kwargs)

# source is something else we can't handle
else:
raise TypeError(f"Unrecognized type for source: {type(source)}")

return self
raise TypeError(f"Unrecognized type for source: {type(source)}")

def load_status_changes(self, source, **kwargs):
"""
Load MDS status_changes data.
Parameters:
source: dict, list, str, Path
See load_from_sources for supported source types.
source: dict, list, str, Path, pandas.DataFrame
See load() for supported source types.
table: str, optional
The name of the table to load data to, by default status_changes.
The name of the table to load data to. By default "status_changes".
before_load: callable(df=DataFrame, version=Version): DataFrame, optional
Callback executed on the incoming DataFrame and Version.
Expand All @@ -383,16 +219,12 @@ def load_status_changes(self, source, **kwargs):
version: str, Version, optional
The MDS version to target.
Additional keyword arguments are passed-through to load_from_df().
Additional keyword arguments are passed-through to load().
Return:
ProviderDataLoader
ProviderDatabase
self
"""
version = Version(kwargs.get("version", self.version))
if version.unsupported:
raise UnsupportedVersionError(version)

table = kwargs.pop("table", STATUS_CHANGES)
before_load = kwargs.pop("before_load", lambda df,v: df)
drop_duplicates = kwargs.pop("drop_duplicates", None)
Expand All @@ -408,15 +240,15 @@ def _before_load(df,v):
df = self._add_missing_cols(df, missing_cols)
return before_load(df,v)

return self.load_from_source(source, STATUS_CHANGES, table, before_load=_before_load, **kwargs)
return self.load(source, STATUS_CHANGES, table, before_load=_before_load, **kwargs)

def load_trips(self, source, **kwargs):
"""
Load MDS trips data.
Parameters:
source: dict, list, str, Path
See load_from_sources for supported source types.
source: dict, list, str, Path, pandas.DataFrame
See load() for supported source types.
table: str, optional
The name of the table to load data to, by default trips.
Expand All @@ -429,10 +261,10 @@ def load_trips(self, source, **kwargs):
List of column names used to drop duplicate records before load.
By default, ["provider_id", "trip_id"]
Additional keyword arguments are passed-through to load_from_df().
Additional keyword arguments are passed-through to load().
Return:
ProviderDataLoader
ProviderDatabase
self
"""
table = kwargs.pop("table", TRIPS)
Expand All @@ -448,7 +280,7 @@ def _before_load(df,v):
df = self._add_missing_cols(df, ["parking_verification_url", "standard_cost", "actual_cost"])
return before_load(df,v)

return self.load_from_source(source, TRIPS, table, before_load=_before_load, **kwargs)
return self.load(source, TRIPS, table, before_load=_before_load, **kwargs)

@staticmethod
def _json_cols_tostring(df, cols):
Expand Down
Loading

0 comments on commit 11a3748

Please sign in to comment.