forked from kennethreitz/records
-
Notifications
You must be signed in to change notification settings - Fork 0
/
records.py
532 lines (409 loc) · 16.1 KB
/
records.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
# -*- coding: utf-8 -*-
import os
from sys import stdout
from collections import OrderedDict
from contextlib import contextmanager
from inspect import isclass
import tablib
from docopt import docopt
from sqlalchemy import create_engine, exc, inspect, text
def isexception(obj):
    """Report whether ``obj`` is an exception instance or an exception class.

    Only :py:class:`Exception` and its subclasses count; anything else
    (including classes that derive solely from ``BaseException``) yields
    ``False``.
    """
    # Instances are the cheap check, so try them first.
    if isinstance(obj, Exception):
        return True

    # issubclass() requires a class, hence the isclass() guard.
    return bool(isclass(obj) and issubclass(obj, Exception))
class Record(object):
    """A row, from a query, from a database.

    A Record pairs a sequence of column names with the matching sequence of
    values and exposes each value by position (``record[0]``), by column
    name (``record['name']``) or as an attribute (``record.name``).
    """

    __slots__ = ('_keys', '_values')

    def __init__(self, keys, values):
        """Store the column names and values; both must have equal length."""
        self._keys = keys
        self._values = values

        # Ensure that lengths match properly.
        assert len(self._keys) == len(self._values)

    def keys(self):
        """Returns the list of column names from the query."""
        return self._keys

    def values(self):
        """Returns the list of values from the query."""
        return self._values

    def __repr__(self):
        # The JSON export is sliced to drop its first and last character
        # (presumably the brackets of the one-element list).
        return '<Record {}>'.format(self.export('json')[1:-1])

    def __getitem__(self, key):
        """Look a value up by integer position or by column name.

        Raises ``KeyError`` when the name is unknown or when it occurs more
        than once (the lookup would be ambiguous).
        """
        # Support for index-based lookup.
        if isinstance(key, int):
            return self.values()[key]

        # Support for string-based lookup. Materialise the names first:
        # key views (e.g. ``dict.keys()``) support ``in`` but offer neither
        # ``index`` nor ``count``.
        names = list(self.keys())
        occurrences = names.count(key)
        if occurrences > 1:
            raise KeyError("Record contains multiple '{}' fields.".format(key))
        if occurrences == 1:
            return self.values()[names.index(key)]

        raise KeyError("Record contains no '{}' field.".format(key))

    def __getattr__(self, key):
        """Expose columns as attributes (only called when normal lookup fails)."""
        # If one of our own slots is missing, the instance was created
        # without __init__ (copy/pickle do this). Resolving the name through
        # keys() would read the same missing slot and re-enter __getattr__
        # forever, so fail fast instead.
        if key in Record.__slots__:
            raise AttributeError(key)
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(e)

    def __dir__(self):
        standard = dir(super(Record, self))
        # Merge standard attrs with generated ones (from column names).
        return sorted(standard + [str(k) for k in self.keys()])

    def get(self, key, default=None):
        """Returns the value for a given key, or default."""
        try:
            return self[key]
        except KeyError:
            return default

    def as_dict(self, ordered=False):
        """Returns the row as a dictionary.

        With ``ordered=True`` an ``OrderedDict`` in column order is returned
        instead of a plain ``dict``.
        """
        items = zip(self.keys(), self.values())
        return OrderedDict(items) if ordered else dict(items)

    @property
    def dataset(self):
        """A Tablib Dataset containing the row."""
        data = tablib.Dataset()
        data.headers = self.keys()

        # Values exposing ``isoformat`` are turned into strings first.
        row = _reduce_datetimes(self.values())
        data.append(row)

        return data

    def export(self, format, **kwargs):
        """Exports the row to the given format."""
        return self.dataset.export(format, **kwargs)
class RecordCollection(object):
    """A set of excellent Records from a query.

    Wraps a row iterator and caches every row pulled from it, so the
    collection can be indexed, sliced and iterated repeatedly while the
    underlying iterator is only advanced as far as a request requires.
    """

    def __init__(self, rows):
        self._rows = rows      # source iterator, advanced lazily
        self._all_rows = []    # cache of the rows fetched so far
        self.pending = True    # set to False once the iterator is exhausted

    def __repr__(self):
        return '<RecordCollection size={} pending={}>'.format(len(self), self.pending)

    def __iter__(self):
        """Iterate over all rows, consuming the underlying generator
        only when necessary."""
        i = 0
        while True:
            # Other code may have iterated between yields,
            # so always check the cache.
            if i < len(self):
                yield self[i]
            else:
                # Prevent StopIteration bubbling out of this generator,
                # following https://www.python.org/dev/peps/pep-0479/
                try:
                    yield next(self)
                except StopIteration:
                    return
            i += 1

    def next(self):
        # Python 2 spelling of the iterator protocol.
        return self.__next__()

    def __next__(self):
        """Fetch one more row from the source, cache it and return it."""
        try:
            nextrow = next(self._rows)
            self._all_rows.append(nextrow)
            return nextrow
        except StopIteration:
            self.pending = False
            raise StopIteration('RecordCollection contains no more rows.')

    def _consume(self, upto=None):
        """Fetch rows until ``upto`` are cached, or all of them if ``None``."""
        while upto is None or len(self) < upto:
            try:
                next(self)
            except StopIteration:
                break

    def __getitem__(self, key):
        """Return the row at an integer index, or a new RecordCollection
        for a slice.

        Only as many rows as needed are fetched. Requests that depend on the
        total length (negative positions, open-ended or backwards slices)
        fetch everything first. Raises ``IndexError`` for an integer index
        that is out of range.
        """
        if isinstance(key, int):
            # A negative index counts from the end, which is only known
            # once every row has been fetched.
            self._consume(None if key < 0 else key + 1)
            return self._all_rows[key]

        start, stop, step = key.start, key.stop, key.step
        # Test ``stop is None`` before comparing: ``int < None`` is a
        # TypeError on Python 3.
        needs_everything = (
            stop is None
            or stop < 0
            or (start is not None and start < 0)
            or (step is not None and step < 0)
        )
        self._consume(None if needs_everything else stop)
        return RecordCollection(iter(self._all_rows[key]))

    def __len__(self):
        # Number of rows fetched so far, not the size of the full result.
        return len(self._all_rows)

    def export(self, format, **kwargs):
        """Export the RecordCollection to a given format (courtesy of Tablib)."""
        return self.dataset.export(format, **kwargs)

    @property
    def dataset(self):
        """A Tablib Dataset representation of the RecordCollection."""
        # Create a new Tablib Dataset.
        data = tablib.Dataset()

        # If the RecordCollection is empty, just return the empty set.
        # Casting to list drains the source, so the count is the real one.
        if len(list(self)) == 0:
            return data

        # Set the column names as headers on Tablib Dataset.
        first = self[0]

        data.headers = first.keys()
        for row in self.all():
            row = _reduce_datetimes(row.values())
            data.append(row)

        return data

    def all(self, as_dict=False, as_ordereddict=False):
        """Returns a list of all rows for the RecordCollection. If they haven't
        been fetched yet, consume the iterator and cache the results."""

        # By calling list it calls the __iter__ method
        rows = list(self)

        if as_dict:
            return [r.as_dict() for r in rows]
        elif as_ordereddict:
            return [r.as_dict(ordered=True) for r in rows]

        return rows

    def as_dict(self, ordered=False):
        """Returns all rows as dicts, or as OrderedDicts when ``ordered``."""
        return self.all(as_dict=not(ordered), as_ordereddict=ordered)

    def first(self, default=None, as_dict=False, as_ordereddict=False):
        """Returns a single record for the RecordCollection, or `default`. If
        `default` is an instance or subclass of Exception, then raise it
        instead of returning it."""

        # Try to get a record, or return/raise default.
        try:
            record = self[0]
        except IndexError:
            if isexception(default):
                raise default
            return default

        # Cast and return.
        if as_dict:
            return record.as_dict()
        elif as_ordereddict:
            return record.as_dict(ordered=True)
        else:
            return record

    def one(self, default=None, as_dict=False, as_ordereddict=False):
        """Returns a single record for the RecordCollection, ensuring that it
        is the only record, or returns `default`. If `default` is an instance
        or subclass of Exception, then raise it instead of returning it.

        Raises ``ValueError`` when the collection holds more than one row.
        """

        # Ensure that we don't have more than one row.
        try:
            self[1]
        except IndexError:
            return self.first(default=default, as_dict=as_dict, as_ordereddict=as_ordereddict)
        else:
            raise ValueError('RecordCollection contained more than one row. '
                             'Expects only one row when using '
                             'RecordCollection.one')

    def scalar(self, default=None):
        """Returns the first column of the first row, or `default`.

        Goes through :meth:`one`, so a collection with more than one row
        raises ``ValueError``.
        """
        row = self.one()
        return row[0] if row else default
class Database(object):
    """Entry point to a database: holds the URL and the SQLAlchemy engine
    built from it, and hands out :class:`Connection` objects on demand.
    """

    def __init__(self, db_url=None, **kwargs):
        """Build the engine for ``db_url`` (or ``$DATABASE_URL`` when omitted).

        Extra keyword arguments go to ``create_engine`` untouched. Raises
        ``ValueError`` when no URL can be determined.
        """
        # An explicit URL wins; the environment is only the fallback.
        url = db_url or os.environ.get('DATABASE_URL')
        self.db_url = url

        if not url:
            raise ValueError('You must provide a db_url.')

        self._engine = create_engine(url, **kwargs)
        self.open = True

    def close(self):
        """Dispose of the engine and mark this Database as closed."""
        self._engine.dispose()
        self.open = False

    def __enter__(self):
        return self

    def __exit__(self, exc, val, traceback):
        # Leaving the ``with`` block always closes the Database.
        self.close()

    def __repr__(self):
        return '<Database open={}>'.format(self.open)

    def get_table_names(self, internal=False):
        """Return the table names reported by SQLAlchemy's inspector.

        ``internal`` is accepted but not used by this implementation.
        """
        inspector = inspect(self._engine)
        return inspector.get_table_names()

    def get_connection(self):
        """Return a new :class:`Connection` wrapping ``engine.connect()``.

        Raises ``exc.ResourceClosedError`` once the Database is closed.
        """
        if self.open:
            return Connection(self._engine.connect())

        raise exc.ResourceClosedError('Database closed.')

    def query(self, query, fetchall=False, **params):
        """Run ``query`` on a fresh connection and return the resulting
        RecordCollection; ``params`` are forwarded as query parameters.

        NOTE(review): the connection is closed when the ``with`` block
        exits, before a lazily fetched result (``fetchall=False``) has been
        iterated -- confirm the rows stay readable with the driver in use.
        """
        with self.get_connection() as connection:
            results = connection.query(query, fetchall, **params)
            return results

    def bulk_query(self, query, *multiparams):
        """Bulk insert or update, executed on a fresh connection."""
        with self.get_connection() as connection:
            connection.bulk_query(query, *multiparams)

    def query_file(self, path, fetchall=False, **params):
        """Same as :meth:`query`, with the SQL read from the file at ``path``."""
        with self.get_connection() as connection:
            results = connection.query_file(path, fetchall, **params)
            return results

    def bulk_query_file(self, path, *multiparams):
        """Same as :meth:`bulk_query`, with the SQL read from the file at
        ``path``."""
        with self.get_connection() as connection:
            connection.bulk_query_file(path, *multiparams)

    @contextmanager
    def transaction(self):
        """Context manager yielding a connection with a transaction begun.

        The transaction is committed when the block finishes normally. If
        anything is raised, it is rolled back and the exception is not
        re-raised by this method. The connection is closed in every case.
        """
        connection = self.get_connection()
        txn = connection.transaction()
        try:
            yield connection
            txn.commit()
        except:
            # Deliberately broad: whatever went wrong, undo the work.
            txn.rollback()
        finally:
            connection.close()
class Connection(object):
    """A Database connection.

    Thin wrapper around the connection object it is given: runs SQL through
    it and turns results into RecordCollections.
    """

    def __init__(self, connection):
        self._conn = connection
        self.open = not connection.closed

    def close(self):
        """Close the wrapped connection and mark this object as closed."""
        self._conn.close()
        self.open = False

    def __enter__(self):
        return self

    def __exit__(self, exc, val, traceback):
        # Leaving the ``with`` block always closes the connection.
        self.close()

    def __repr__(self):
        return '<Connection open={}>'.format(self.open)

    @staticmethod
    def _load_query(path):
        """Return the text of the SQL file at ``path``.

        Raises ``IOError`` when the path does not exist or is a directory.
        """
        # If path doesn't exist
        if not os.path.exists(path):
            raise IOError("File '{}' not found!".format(path))

        # If it's a directory
        if os.path.isdir(path):
            raise IOError("'{}' is a directory!".format(path))

        # Read the given .sql file into memory.
        with open(path) as f:
            return f.read()

    def query(self, query, fetchall=False, **params):
        """Executes the given SQL query against the connected Database.
        Parameters can, optionally, be provided. Returns a RecordCollection,
        which can be iterated over to get result rows as Records.

        With ``fetchall=True`` every row is fetched before returning;
        otherwise rows are pulled from the cursor on demand.
        """
        # Execute the given query.
        # NOTE(review): parameters are forwarded as keyword arguments; newer
        # SQLAlchemy releases expect a single parameters mapping instead --
        # confirm against the pinned SQLAlchemy version.
        cursor = self._conn.execute(text(query), **params)

        if cursor.returns_rows:
            # Column names are the same for every row: resolve them once.
            keys = list(cursor.keys())
            # Row-by-row Record generator.
            row_gen = (Record(keys, row) for row in cursor)
        else:
            # Statements without a result set (INSERT, DDL, ...) cannot be
            # iterated; expose them as an empty collection instead.
            row_gen = iter(())

        # Convert the cursor results to a RecordCollection.
        results = RecordCollection(row_gen)

        # Fetch all results if desired.
        if fetchall:
            results.all()

        return results

    def bulk_query(self, query, *multiparams):
        """Bulk insert or update."""
        self._conn.execute(text(query), *multiparams)

    def query_file(self, path, fetchall=False, **params):
        """Like Connection.query, but takes a filename to load a query from.

        Raises ``IOError`` when ``path`` is missing or is a directory.
        """
        query = self._load_query(path)

        # Defer processing to self.query method.
        return self.query(query=query, fetchall=fetchall, **params)

    def bulk_query_file(self, path, *multiparams):
        """Like Connection.bulk_query, but takes a filename to load a query
        from.

        Raises ``IOError`` when ``path`` is missing or is a directory.
        """
        query = self._load_query(path)

        # Defer processing to self.bulk_query method.
        self.bulk_query(query, *multiparams)

    def transaction(self):
        """Returns a transaction object. Call ``commit`` or ``rollback``
        on the returned object as appropriate."""
        return self._conn.begin()
def _reduce_datetimes(row):
"""Receives a row, converts datetimes to strings."""
row = list(row)
for i in range(len(row)):
if hasattr(row[i], 'isoformat'):
row[i] = row[i].isoformat()
return tuple(row)
def cli():
    """Command-line entry point: run a query and print its result.

    Parses the arguments with docopt, runs the query (a SQL file path or an
    inline query string) against the database given by ``--url`` or
    ``$DATABASE_URL``, and prints the rows, optionally exported to one of
    the supported formats.

    Exits with status 62 for an unsupported format, 64 for malformed
    parameters, 66 when the query cannot be found and 60 when a required
    package is missing.
    """
    supported_formats = 'csv tsv json yaml html xls xlsx dbf latex ods'.split()
    formats_lst = ", ".join(supported_formats)
    # docopt derives the parser from this text: usage lines must be
    # indented, sections separated by blank lines, and each option
    # separated from its description by at least two spaces.
    cli_docs = """Records: SQL for Humans™
A Kenneth Reitz project.

Usage:
  records <query> [<format>] [<params>...] [--url=<url>]
  records (-h | --help)

Options:
  -h --help     Show this screen.
  --url=<url>   The database URL to use. Defaults to $DATABASE_URL.

Supported Formats:
   %(formats_lst)s

   Note: xls, xlsx, dbf, and ods formats are binary, and should only be
         used with redirected output e.g. '$ records sql xls > sql.xls'.

Query Parameters:
    Query parameters can be specified in key=value format, and injected
    into your query in :key format e.g.:

    $ records 'select * from repos where language ~= :lang' lang=python

Notes:
  - While you may specify a database connection string with --url, records
    will automatically default to the value of $DATABASE_URL, if available.
  - Query is intended to be the path of a SQL file, however a query string
    can be provided instead. Use this feature discernfully; it's dangerous.
  - Records is intended for report-style exports of database queries, and
    has not yet been optimized for extremely large data dumps.
    """ % dict(formats_lst=formats_lst)

    # Parse the command-line arguments.
    arguments = docopt(cli_docs)

    query = arguments['<query>']
    params = arguments['<params>']
    fmt = arguments.get('<format>')

    # A "format" containing '=' is really the first key=value parameter.
    if fmt and "=" in fmt:
        del arguments['<format>']
        params.append(fmt)
        fmt = None

    if fmt and fmt not in supported_formats:
        print('%s format not supported.' % fmt)
        print('Supported formats are %s.' % formats_lst)
        raise SystemExit(62)

    # Can't send an empty list if params aren't expected.
    try:
        # Split on the first '=' only, so values may contain '=' themselves.
        params = dict(item.split('=', 1) for item in params)
    except ValueError:
        print('Parameters must be given in key=value format.')
        raise SystemExit(64)

    # Be ready to fail on missing packages
    try:
        # Create the Database; the ``with`` block closes it on every exit
        # path, including the SystemExit raised below.
        with Database(arguments['--url']) as db:

            # Execute the query, if it is a found file.
            if os.path.isfile(query):
                rows = db.query_file(query, **params)

            # Execute the query, if it appears to be a query string.
            elif len(query.split()) > 2:
                rows = db.query(query, **params)

            # Otherwise, say the file wasn't found.
            else:
                print('The given query could not be found.')
                raise SystemExit(66)

            # Print results in desired format.
            if fmt:
                content = rows.export(fmt)
                if isinstance(content, bytes):
                    print_bytes(content)
                else:
                    print(content)
            else:
                print(rows.dataset)
    except ImportError as impexc:
        # str() works on every Python version; ``.msg`` is Python 3 only.
        print(str(impexc))
        print("Used database or format require a package, which is missing.")
        print("Try to install missing packages.")
        raise SystemExit(60)
def print_bytes(content):
    """Write ``content`` (binary export data) to standard output.

    The underlying binary ``buffer`` of stdout is preferred; when that
    attribute is unavailable the data is written to stdout directly.
    """
    try:
        # Binary-safe path: bypass the text layer.
        stdout.buffer.write(content)
    except AttributeError:
        # No ``buffer`` attribute on this stream: write to it as-is.
        stdout.write(content)
# Script entry point: only start the command-line interface when this
# module is executed directly, never on import.
if __name__ == '__main__':
    cli()