Merge pull request #80 from bileto/iterator
Use .iterator() to save RAM during export
jwhitlock authored Mar 9, 2018
2 parents 7d30696 + b493b1d commit ca19ee6
Showing 1 changed file with 36 additions and 61 deletions.
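
The change replaces the manual chunking of large querysets (the removed large_queryset_size branch below) with Django's QuerySet.iterator(), which streams results without populating the queryset's result cache, so memory use stays roughly flat however large the table is. A minimal sketch of the difference, using multigtfs's Trip model only as an illustration (handle() is a hypothetical placeholder, not part of this commit):

from multigtfs.models import Trip

# Plain iteration caches every fetched row on the queryset object,
# so exporting a huge table holds all of it in memory at once.
for trip in Trip.objects.order_by('route__route_id', 'trip_id'):
    handle(trip)  # hypothetical per-row work

# .iterator() streams results without filling the result cache,
# which is what "save RAM during export" refers to.
for trip in Trip.objects.order_by('route__route_id', 'trip_id').iterator():
    handle(trip)  # hypothetical per-row work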
multigtfs/models/base.py (36 additions, 61 deletions)
@@ -30,7 +30,6 @@
 logger = getLogger(__name__)
 re_point = re.compile(r'(?P<name>point)\[(?P<index>\d)\]')
 batch_size = 1000
-large_queryset_size = 100000
 CSV_BOM = BOM_UTF8.decode('utf-8') if PY3 else BOM_UTF8
 
 
@@ -365,70 +364,46 @@ def export_txt(cls, feed):
                     cache[field_name][None] = u''
                     model_to_field_name[model_name] = field_name
 
-        # For large querysets, break up by the first field
-        if total < large_queryset_size:
-            querysets = [objects.order_by(*sort_fields)]
-        else:  # pragma: no cover
-            field1_raw = sort_fields[0]
-            assert '__' in field1_raw
-            assert field1_raw in cache
-            field1 = field1_raw.split('__', 1)[0]
-            field1_id = field1 + '_id'
-
-            # Sort field1 ids by field1 values
-            val_to_id = dict((v, k) for k, v in cache[field1_raw].items())
-            assert len(val_to_id) == len(cache[field1_raw])
-            sorted_vals = sorted(val_to_id.keys())
-
-            querysets = []
-            for val in sorted_vals:
-                fid = val_to_id[val]
-                if fid:
-                    qs = objects.filter(
-                        **{field1_id: fid}).order_by(*sort_fields[1:])
-                    querysets.append(qs)
-
         # Assemble the rows, writing when we hit batch size
         count = 0
         rows = []
-        for queryset in querysets:
-            for item in queryset.order_by(*sort_fields):
-                row = []
-                for csv_name, field_name in column_map:
-                    obj = item
-                    point_match = re_point.match(field_name)
-                    if '__' in field_name:
-                        # Return relations from cache
-                        local_field_name = field_name.split('__', 1)[0]
-                        field_id = getattr(obj, local_field_name + '_id')
-                        row.append(cache[field_name][field_id])
-                    elif point_match:
-                        # Get the lat or long from the point
-                        name, index = point_match.groups()
-                        field = getattr(obj, name)
-                        row.append(field.coords[int(index)])
+        for item in objects.order_by(*sort_fields).iterator():
+            row = []
+            for csv_name, field_name in column_map:
+                obj = item
+                point_match = re_point.match(field_name)
+                if '__' in field_name:
+                    # Return relations from cache
+                    local_field_name = field_name.split('__', 1)[0]
+                    field_id = getattr(obj, local_field_name + '_id')
+                    row.append(cache[field_name][field_id])
+                elif point_match:
+                    # Get the lat or long from the point
+                    name, index = point_match.groups()
+                    field = getattr(obj, name)
+                    row.append(field.coords[int(index)])
+                else:
+                    # Handle other field types
+                    field = getattr(obj, field_name) if obj else ''
+                    if isinstance(field, date):
+                        formatted = field.strftime(u'%Y%m%d')
+                        row.append(text_type(formatted))
+                    elif isinstance(field, bool):
+                        row.append(1 if field else 0)
+                    elif field is None:
+                        row.append(u'')
                     else:
-                        # Handle other field types
-                        field = getattr(obj, field_name) if obj else ''
-                        if isinstance(field, date):
-                            formatted = field.strftime(u'%Y%m%d')
-                            row.append(text_type(formatted))
-                        elif isinstance(field, bool):
-                            row.append(1 if field else 0)
-                        elif field is None:
-                            row.append(u'')
-                        else:
-                            row.append(text_type(field))
-                for col in extra_columns:
-                    row.append(obj.extra_data.get(col, u''))
-                rows.append(row)
-                if len(rows) % batch_size == 0:  # pragma: no cover
-                    write_text_rows(csv_writer, rows)
-                    count += len(rows)
-                    logger.info(
-                        "Exported %d %s",
-                        count, cls._meta.verbose_name_plural)
-                    rows = []
+                        row.append(text_type(field))
+            for col in extra_columns:
+                row.append(obj.extra_data.get(col, u''))
+            rows.append(row)
+            if len(rows) % batch_size == 0:  # pragma: no cover
+                write_text_rows(csv_writer, rows)
+                count += len(rows)
+                logger.info(
+                    "Exported %d %s",
+                    count, cls._meta.verbose_name_plural)
+                rows = []
 
         # Write rows smaller than batch size
        write_text_rows(csv_writer, rows)
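
The rewritten loop keeps the existing batching behaviour: rows accumulate only until they reach batch_size, are flushed through write_text_rows, and the final partial batch is written after the loop. A standalone sketch of that pattern (make_row is a hypothetical helper and writer is any csv.writer; neither is multigtfs API):

batch_size = 1000

def make_row(item):
    # Placeholder: the real code maps model fields through column_map.
    return [str(item)]

def export_rows(queryset, writer):
    """Stream a queryset to a csv.writer, flushing every batch_size rows."""
    count = 0
    rows = []
    for item in queryset.iterator():
        rows.append(make_row(item))
        if len(rows) == batch_size:
            writer.writerows(rows)
            count += len(rows)
            rows = []
    # Write rows smaller than batch size, mirroring the code above.
    writer.writerows(rows)
    return count + len(rows)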