This repository has been archived by the owner on Nov 10, 2022. It is now read-only.
forked from CityofSantaMonica/mds-provider
-
Notifications
You must be signed in to change notification settings - Fork 0
/
db.py
329 lines (240 loc) · 11.2 KB
/
db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""
Work with MDS Provider database backends.
"""
import json
from urllib.parse import quote_plus

import sqlalchemy

from ..db import loaders
from ..providers import Provider
from ..schemas import STATUS_CHANGES, TRIPS
from ..versions import UnsupportedVersionError, Version
def data_engine(uri=None, **kwargs):
    """
    Create an engine for connections to a database backend.

    Parameters:
        uri: str, optional
            A DBAPI-compatible connection URI.

            e.g. for a PostgreSQL backend: postgresql://user:password@host:port/db

            Required if any of (user, password, host, db) are not provided.

        backend: str, optional
            The type of the database backend. By default, postgresql.

        user: str, optional
            The user account for the database backend.
            Percent-encoded automatically when building the URI.

        password: str, optional
            The password for the user account.
            Percent-encoded automatically when building the URI.

        host: str, optional
            The host name of the database backend.

        port: int, optional
            The database backend connection port. By default, 5432 (postgres).

        db: str, optional
            The name of the database to connect to.

    Return:
        sqlalchemy.engine.Engine

    Raise:
        KeyError
            When uri is not given and any of (user, password, host, db) is missing.
    """
    if uri is None:
        required = ("user", "password", "host", "db")
        if not all(k in kwargs for k in required):
            raise KeyError("Provide either uri or ([backend], user, password, host, [port], db).")

        backend = kwargs.get("backend", "postgresql")
        # Credentials may contain URI-reserved characters (@ : / % # ...); without
        # percent-encoding they would corrupt the URI's structure.
        user = quote_plus(str(kwargs["user"]))
        password = quote_plus(str(kwargs["password"]))
        host, port, db = kwargs["host"], kwargs.get("port", 5432), kwargs["db"]
        uri = f"{backend}://{user}:{password}@{host}:{port}/{db}"

    return sqlalchemy.create_engine(uri)
class Database():
    """
    Work with MDS Provider data in a database backend.
    """

    def __init__(self, uri=None, **kwargs):
        """
        Initialize a new Database using a number of connection methods.

        Parameters:
            uri: str, optional
                A DBAPI-compatible connection URI.

                e.g. for a PostgreSQL backend: postgresql://user:password@host:port/db

                Required if engine or any of (user, password, host, db) are not provided.

            engine: sqlalchemy.engine.Engine, optional
                An existing engine to use. When given, no new engine is created.

            backend: str, optional
                The type of the database backend. By default, postgresql.

            user: str, optional
                The user account for the database backend.

            password: str, optional
                The password for the user account.

            host: str, optional
                The host name of the database backend.

            port: int, optional
                The database backend connection port. By default, 5432 (postgres).

            db: str, optional
                The name of the database to connect to.

            stage_first: bool, int, optional
                True (default) to stage data in a temp table before upserting to the final table.
                False to load directly into the target table.

                Given an int greater than 0, determines the degrees of randomness when creating the
                temp table, e.g.

                    stage_first=3

                stages to a random temp table with 26*26*26 possible naming choices.

            version: str, Version, optional
                The MDS version to target. By default, Version.mds_lower().

        Raise:
            UnsupportedVersionError
                When an unsupported MDS version is specified.

            KeyError
                When no engine is given and neither uri nor (user, password, host, db) are provided.
        """
        self.version = Version(kwargs.pop("version", Version.mds_lower()))
        if self.version.unsupported:
            raise UnsupportedVersionError(self.version)

        self.stage_first = kwargs.pop("stage_first", True)

        # Only build an engine when the caller did not supply one. Passing
        # data_engine(...) as the pop() default would evaluate it eagerly and
        # raise KeyError for engine-only construction.
        engine = kwargs.pop("engine", None)
        self.engine = engine if engine is not None else data_engine(uri=uri, **kwargs)

    def __repr__(self):
        return f"<mds.db.Database ('{self.version}')>"

    def load(self, source, record_type, table, **kwargs):
        """
        Load MDS data from a variety of file path or object sources.

        Parameters:
            source: dict, list, str, Path, pandas.DataFrame
                The data source to load, which could be any of:
                    * an MDS payload dict:
                        {
                            "version": "x.y.z",
                            "data": {
                                "record_type": [
                                    //records here
                                ]
                            }
                        }
                    * a list of MDS payload dicts
                    * one or more MDS data records, e.g. payload["data"][record_type]
                    * one or more file paths to MDS payload JSON files
                    * a pandas.DataFrame containing MDS data records

            record_type: str
                The type of MDS data ("status_changes" or "trips").

            table: str
                The name of the database table to insert this data into.

            before_load: callable(df=DataFrame, version=Version): DataFrame, optional
                Callback executed on an incoming DataFrame and Version.
                Should return the final DataFrame for loading.

            on_conflict_update: tuple (condition: str, actions: list), optional
                Generate an "ON CONFLICT condition DO UPDATE SET actions" statement.
                Only applies when stage_first evaluates True.

            stage_first: bool, int, optional
                True to stage data in a temp table before upserting to the final table.
                False to load directly into the target table.
                By default, this Database's stage_first setting.

                Given an int greater than 0, determines the degrees of randomness when creating the
                temp table, e.g.

                    stage_first=3

                stages to a random temp table with 26*26*26 possible naming choices.

            version: str, Version, optional
                The MDS version to target. By default, this Database's version.

            Any other keyword arguments are forwarded to the matching loader's load().

        Raise:
            TypeError
                When a loader for the type of source could not be found.

            UnsupportedVersionError
                When an unsupported MDS version is specified.

        Return:
            Database
                self
        """
        version = Version(kwargs.pop("version", self.version))
        if version.unsupported:
            raise UnsupportedVersionError(version)

        # fall back to the instance-wide staging preference
        kwargs.setdefault("stage_first", self.stage_first)

        # caller-supplied kwargs (e.g. engine) take precedence over these defaults
        loader_kwargs = {
            "record_type": record_type,
            "table": table,
            "engine": self.engine,
            "version": version,
            **kwargs,
        }

        for loader in loaders.data_loaders():
            if loader.can_load(source):
                loader().load(source, **loader_kwargs)
                return self

        raise TypeError(f"Unrecognized type for source: {type(source)}")

    def load_status_changes(self, source, **kwargs):
        """
        Load MDS status_changes data.

        Parameters:
            source: dict, list, str, Path, pandas.DataFrame
                See load() for supported source types.

            table: str, optional
                The name of the table to load data to. By default "status_changes".

            before_load: callable(df=DataFrame, version=Version): DataFrame, optional
                Callback executed on the incoming DataFrame and Version.
                Should return the final DataFrame for loading.

            drop_duplicates: list, optional
                List of column names used to drop duplicate records before load.
                By default, no de-duplication is performed.

            version: str, Version, optional
                The MDS version to target.

            Additional keyword arguments are passed-through to load().

        Return:
            Database
                self
        """
        table = kwargs.pop("table", STATUS_CHANGES)
        before_load = kwargs.pop("before_load", lambda df, v: df)
        drop_duplicates = kwargs.pop("drop_duplicates", None)

        def _before_load(df, version):
            """
            Helper converts JSON cols and ensures optional cols exist
            """
            if drop_duplicates:
                # keep the most recent occurrence of each duplicate key
                df.drop_duplicates(subset=drop_duplicates, keep="last", inplace=True)

            self._json_cols_tostring(df, ["event_location"])

            null_cols = ["battery_pct"]
            # version-dependent association column
            association_col = "associated_trips" if version < Version("0.3.0") else "associated_trip"
            null_cols.append(association_col)

            if version >= Version("0.3.0"):
                null_cols.append("publication_time")

            df = self._add_missing_cols(df, null_cols)

            # coerce to object column
            df[[association_col]] = df[[association_col]].astype("object")

            if version < Version("0.3.0"):
                # empty list by default
                df[association_col] = df[association_col].apply(lambda d: d if isinstance(d, list) else [])

            return before_load(df, version)

        return self.load(source, STATUS_CHANGES, table, before_load=_before_load, **kwargs)

    def load_trips(self, source, **kwargs):
        """
        Load MDS trips data.

        Parameters:
            source: dict, list, str, Path, pandas.DataFrame
                See load() for supported source types.

            table: str, optional
                The name of the table to load data to. By default "trips".

            before_load: callable(df=DataFrame, version=Version): DataFrame, optional
                Callback executed on the incoming DataFrame and Version.
                Should return the final DataFrame for loading.

            drop_duplicates: list, optional
                List of column names used to drop duplicate records before load.
                By default, ["provider_id", "trip_id"]

            Additional keyword arguments are passed-through to load().

        Return:
            Database
                self
        """
        table = kwargs.pop("table", TRIPS)
        before_load = kwargs.pop("before_load", lambda df, v: df)
        drop_duplicates = kwargs.pop("drop_duplicates", ["provider_id", "trip_id"])

        def _before_load(df, version):
            """
            Helper converts JSON cols and ensures optional cols exist
            """
            if drop_duplicates:
                # keep the most recent occurrence of each duplicate key
                df.drop_duplicates(subset=drop_duplicates, keep="last", inplace=True)

            self._json_cols_tostring(df, ["route"])

            null_cols = ["parking_verification_url", "standard_cost", "actual_cost"]
            if version >= Version("0.3.0"):
                null_cols.append("publication_time")

            df = self._add_missing_cols(df, null_cols)

            return before_load(df, version)

        return self.load(source, TRIPS, table, before_load=_before_load, **kwargs)

    @staticmethod
    def _json_cols_tostring(df, cols):
        """
        Serialize each of cols that exists in df to a JSON string, in place.

        Columns in cols that are absent from df are skipped. Returns df for convenience.
        NOTE(review): missing float NaN cells serialize as the string "NaN", which is not
        valid JSON; confirm whether the loaders rely on that.
        """
        for col in [c for c in cols if c in df]:
            df[col] = df[col].apply(json.dumps)
        return df

    @staticmethod
    def _add_missing_cols(df, cols):
        """
        Return a copy of df with each of cols not already present added as an empty (NaN) column.

        Existing columns keep their original order; new columns are appended in the order
        given (duplicates in cols are ignored), so the resulting layout is deterministic.
        """
        existing = set(df.columns)
        missing = [c for c in dict.fromkeys(cols) if c not in existing]
        return df.reindex(columns=[*df.columns, *missing])