diff --git a/multigtfs/models/feed.py b/multigtfs/models/feed.py index 0aa7d95..1f45bd6 100644 --- a/multigtfs/models/feed.py +++ b/multigtfs/models/feed.py @@ -58,6 +58,20 @@ class Meta: db_table = 'feed' app_label = 'multigtfs' + _gtfs_order = [ + Agency, Stop, Route, Service, ServiceDate, ShapePoint, Trip, + StopTime, Frequency, Fare, FareRule, Transfer, FeedInfo, + ] + + @classmethod + def register_model(cls, klass): + cls._gtfs_order.append(klass) + return klass + + @classmethod + def unregister_model(cls, klass): + cls._gtfs_order.remove(klass) + def __str__(self): if self.name: return "%d %s" % (self.id, self.name) @@ -88,14 +102,10 @@ def import_gtfs(self, gtfs_obj): opener = opener_from_zipfile(zfile) filelist = zfile.namelist() - gtfs_order = ( - Agency, Stop, Route, Service, ServiceDate, ShapePoint, Trip, - StopTime, Frequency, Fare, FareRule, Transfer, FeedInfo, - ) post_save.disconnect(dispatch_uid='post_save_shapepoint') post_save.disconnect(dispatch_uid='post_save_stop') try: - for klass in gtfs_order: + for klass in self._gtfs_order: for f in filelist: if f.endswith(klass._filename): start_time = time.time() @@ -155,12 +165,7 @@ def export_gtfs(self, gtfs_file): total_start = time.time() z = open_writable_zipfile(gtfs_file) - gtfs_order = ( - Agency, Service, ServiceDate, Fare, FareRule, FeedInfo, Frequency, - Route, ShapePoint, StopTime, Stop, Transfer, Trip, - ) - - for klass in gtfs_order: + for klass in sorted(self._gtfs_order, key=lambda k: k._filename): start_time = time.time() content = klass.export_txt(self) if content: