diff --git a/homeassistant/components/sensor/gtfs.py b/homeassistant/components/sensor/gtfs.py index eec08be093f038..07b80d8288d301 100644 --- a/homeassistant/components/sensor/gtfs.py +++ b/homeassistant/components/sensor/gtfs.py @@ -8,6 +8,7 @@ import logging import datetime import threading +from typing import Optional import voluptuous as vol @@ -24,6 +25,7 @@ CONF_DESTINATION = 'destination' CONF_ORIGIN = 'origin' CONF_OFFSET = 'offset' +CONF_TOMORROW = 'include_tomorrow' DEFAULT_NAME = 'GTFS Sensor' DEFAULT_PATH = 'gtfs' @@ -49,65 +51,135 @@ vol.Required(CONF_DATA): cv.string, vol.Optional(CONF_NAME): cv.string, vol.Optional(CONF_OFFSET, default=0): cv.time_period, + vol.Optional(CONF_TOMORROW, default=False): cv.boolean, }) -def get_next_departure(sched, start_station_id, end_station_id, offset): +def get_next_departure(sched, start_station_id, end_station_id, offset, + include_tomorrow=False) -> Optional[dict]: """Get the next departure for the given schedule.""" origin_station = sched.stops_by_id(start_station_id)[0] destination_station = sched.stops_by_id(end_station_id)[0] now = datetime.datetime.now() + offset - day_name = now.strftime('%A').lower() - now_str = now.strftime('%H:%M:%S') - today = now.strftime(DATE_FORMAT) + now_date = now.strftime(DATE_FORMAT) + yesterday = now - datetime.timedelta(days=1) + yesterday_date = yesterday.strftime(DATE_FORMAT) + tomorrow = now + datetime.timedelta(days=1) + tomorrow_date = tomorrow.strftime(DATE_FORMAT) from sqlalchemy.sql import text - sql_query = text(""" - SELECT trip.trip_id, trip.route_id, - time(origin_stop_time.arrival_time) AS origin_arrival_time, - time(origin_stop_time.departure_time) AS origin_depart_time, - origin_stop_time.drop_off_type AS origin_drop_off_type, - origin_stop_time.pickup_type AS origin_pickup_type, - origin_stop_time.shape_dist_traveled AS origin_dist_traveled, - origin_stop_time.stop_headsign AS origin_stop_headsign, - origin_stop_time.stop_sequence AS origin_stop_sequence, - time(destination_stop_time.arrival_time) AS dest_arrival_time, - time(destination_stop_time.departure_time) AS dest_depart_time, - destination_stop_time.drop_off_type AS dest_drop_off_type, - destination_stop_time.pickup_type AS dest_pickup_type, - destination_stop_time.shape_dist_traveled AS dest_dist_traveled, - destination_stop_time.stop_headsign AS dest_stop_headsign, - destination_stop_time.stop_sequence AS dest_stop_sequence - FROM trips trip - INNER JOIN calendar calendar - ON trip.service_id = calendar.service_id - INNER JOIN stop_times origin_stop_time - ON trip.trip_id = origin_stop_time.trip_id - INNER JOIN stops start_station - ON origin_stop_time.stop_id = start_station.stop_id - INNER JOIN stop_times destination_stop_time - ON trip.trip_id = destination_stop_time.trip_id - INNER JOIN stops end_station - ON destination_stop_time.stop_id = end_station.stop_id - WHERE calendar.{day_name} = 1 - AND origin_depart_time > time(:now_str) - AND start_station.stop_id = :origin_station_id - AND end_station.stop_id = :end_station_id - AND origin_stop_sequence < dest_stop_sequence - AND calendar.start_date <= :today - AND calendar.end_date >= :today - ORDER BY origin_stop_time.departure_time - LIMIT 1 - """.format(day_name=day_name)) - result = sched.engine.execute(sql_query, now_str=now_str, + # Fetch all departures for yesterday, today and optionally tomorrow, + # up to an overkill maximum in case of a departure every minute for those + # days. + limit = 24 * 60 * 60 * 2 + tomorrow_select = tomorrow_where = tomorrow_order = '' + if include_tomorrow: + limit = limit / 2 * 3 + tomorrow_name = tomorrow.strftime('%A').lower() + tomorrow_select = "calendar.{} AS tomorrow,".format(tomorrow_name) + tomorrow_where = "OR calendar.{} = 1".format(tomorrow_name) + tomorrow_order = "calendar.{} DESC,".format(tomorrow_name) + + sql_query = """ + SELECT trip.trip_id, trip.route_id, + time(origin_stop_time.arrival_time) AS origin_arrival_time, + time(origin_stop_time.departure_time) AS origin_depart_time, + date(origin_stop_time.departure_time) AS origin_departure_date, + origin_stop_time.drop_off_type AS origin_drop_off_type, + origin_stop_time.pickup_type AS origin_pickup_type, + origin_stop_time.shape_dist_traveled AS origin_dist_traveled, + origin_stop_time.stop_headsign AS origin_stop_headsign, + origin_stop_time.stop_sequence AS origin_stop_sequence, + time(destination_stop_time.arrival_time) AS dest_arrival_time, + time(destination_stop_time.departure_time) AS dest_depart_time, + destination_stop_time.drop_off_type AS dest_drop_off_type, + destination_stop_time.pickup_type AS dest_pickup_type, + destination_stop_time.shape_dist_traveled AS dest_dist_traveled, + destination_stop_time.stop_headsign AS dest_stop_headsign, + destination_stop_time.stop_sequence AS dest_stop_sequence, + calendar.{yesterday_name} AS yesterday, + calendar.{today_name} AS today, + {tomorrow_select} + calendar.start_date AS start_date, + calendar.end_date AS end_date + FROM trips trip + INNER JOIN calendar calendar + ON trip.service_id = calendar.service_id + INNER JOIN stop_times origin_stop_time + ON trip.trip_id = origin_stop_time.trip_id + INNER JOIN stops start_station + ON origin_stop_time.stop_id = start_station.stop_id + INNER JOIN stop_times destination_stop_time + ON trip.trip_id = destination_stop_time.trip_id + INNER JOIN stops end_station + ON destination_stop_time.stop_id = end_station.stop_id + WHERE (calendar.{yesterday_name} = 1 + OR calendar.{today_name} = 1 + {tomorrow_where} + ) + AND start_station.stop_id = :origin_station_id + AND end_station.stop_id = :end_station_id + AND origin_stop_sequence < dest_stop_sequence + AND calendar.start_date <= :today + AND calendar.end_date >= :today + ORDER BY calendar.{yesterday_name} DESC, + calendar.{today_name} DESC, + {tomorrow_order} + origin_stop_time.departure_time + LIMIT :limit + """.format(yesterday_name=yesterday.strftime('%A').lower(), + today_name=now.strftime('%A').lower(), + tomorrow_select=tomorrow_select, + tomorrow_where=tomorrow_where, + tomorrow_order=tomorrow_order) + result = sched.engine.execute(text(sql_query), origin_station_id=origin_station.id, end_station_id=destination_station.id, - today=today) - item = {} + today=now_date, + limit=limit) + + # Create lookup timetable for today and possibly tomorrow, taking into + # account any departures from yesterday scheduled after midnight, + # as long as all departures are within the calendar date range. + timetable = {} + yesterday_first = today_first = tomorrow_first = None for row in result: - item = row + if row['yesterday'] == 1 and yesterday_date >= row['start_date']: + if yesterday_first is None: + yesterday_first = row['origin_departure_date'] + if yesterday_first != row['origin_departure_date']: + idx = '{} {}'.format(now_date, + row['origin_depart_time']) + timetable[idx] = {**row, **{'day': 'yesterday'}} + if row['today'] == 1: + if today_first is None: + today_first = row['origin_departure_date'] + if today_first == row['origin_departure_date']: + idx_prefix = now_date + else: + idx_prefix = tomorrow_date + idx = '{} {}'.format(idx_prefix, row['origin_depart_time']) + timetable[idx] = {**row, **{'day': 'today'}} + if 'tomorrow' in row and row['tomorrow'] == 1 and tomorrow_date <= \ + row['end_date']: + if tomorrow_first is None: + tomorrow_first = row['origin_departure_date'] + if tomorrow_first == row['origin_departure_date']: + idx = '{} {}'.format(tomorrow_date, + row['origin_depart_time']) + timetable[idx] = {**row, **{'day': 'tomorrow'}} + + _LOGGER.debug("Timetable: %s", sorted(timetable.keys())) + + item = {} + for key in sorted(timetable.keys()): + if datetime.datetime.strptime(key, TIME_FORMAT) > now: + item = timetable[key] + _LOGGER.debug("Departure found for station %s @ %s -> %s", + start_station_id, key, item) + break if item == {}: return None @@ -120,7 +192,7 @@ def get_next_departure(sched, start_station_id, end_station_id, offset): origin_arrival_time = '{} {}'.format(origin_arrival.strftime(DATE_FORMAT), item['origin_arrival_time']) - origin_depart_time = '{} {}'.format(today, item['origin_depart_time']) + origin_depart_time = '{} {}'.format(now_date, item['origin_depart_time']) dest_arrival = now if item['dest_arrival_time'] < item['origin_depart_time']: @@ -164,6 +236,7 @@ def get_next_departure(sched, start_station_id, end_station_id, offset): return { 'trip_id': item['trip_id'], + 'day': item['day'], 'trip': sched.trips_by_id(item['trip_id'])[0], 'route': route, 'agency': sched.agencies_by_id(route.agency_id)[0], @@ -186,6 +259,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None): destination = config.get(CONF_DESTINATION) name = config.get(CONF_NAME) offset = config.get(CONF_OFFSET) + include_tomorrow = config.get(CONF_TOMORROW) if not os.path.exists(gtfs_dir): os.makedirs(gtfs_dir) @@ -207,17 +281,20 @@ def setup_platform(hass, config, add_entities, discovery_info=None): pygtfs.append_feed(gtfs, os.path.join(gtfs_dir, data)) add_entities([ - GTFSDepartureSensor(gtfs, name, origin, destination, offset)]) + GTFSDepartureSensor(gtfs, name, origin, destination, offset, + include_tomorrow)]) class GTFSDepartureSensor(Entity): """Implementation of an GTFS departures sensor.""" - def __init__(self, pygtfs, name, origin, destination, offset): + def __init__(self, pygtfs, name, origin, destination, offset, + include_tomorrow) -> None: """Initialize the sensor.""" self._pygtfs = pygtfs self.origin = origin self.destination = destination + self._include_tomorrow = include_tomorrow self._offset = offset self._custom_name = name self._icon = ICON @@ -257,10 +334,13 @@ def update(self): """Get the latest data from GTFS and update the states.""" with self.lock: self._departure = get_next_departure( - self._pygtfs, self.origin, self.destination, self._offset) + self._pygtfs, self.origin, self.destination, self._offset, + self._include_tomorrow) if not self._departure: self._state = None - self._attributes = {'Info': 'No more departures today'} + self._attributes = {} + self._attributes['Info'] = "No more departures" if \ + self._include_tomorrow else "No more departures today" if self._name == '': self._name = (self._custom_name or DEFAULT_NAME) return @@ -281,12 +361,12 @@ def update(self): origin_station.stop_id, destination_station.stop_id)) + self._icon = ICONS.get(route.route_type, ICON) + # Build attributes - self._attributes = {} + self._attributes['day'] = self._departure['day'] self._attributes['offset'] = self._offset.seconds / 60 - self._icon = ICONS.get(route.route_type, ICON) - def dict_for_table(resource): """Return a dict for the SQLAlchemy resource given.""" return dict((col, getattr(resource, col))