diff --git a/mta_manager/__init__.py b/mta_manager/__init__.py index 1e01856..bfda163 100644 --- a/mta_manager/__init__.py +++ b/mta_manager/__init__.py @@ -1,3 +1,3 @@ from .mta import * -from .route import * +from .train import * from .stop import * diff --git a/mta_manager/mta.py b/mta_manager/mta.py index 5557118..82104ee 100644 --- a/mta_manager/mta.py +++ b/mta_manager/mta.py @@ -4,20 +4,23 @@ import json from google.transit import gtfs_realtime_pb2 from protobuf_to_dict import protobuf_to_dict -from .route import get_route_from_dict +from .train import get_train_from_dict from time import time class MTA(object): - def __init__(self, api_key: str, train_lines, station_ids, timing_callbacks=None, alert_callbacks=None, + # Create a data filter object. + # Then be able to update that object on the fly. + # This filter should return all possible trains and stations by default. + # If anyhting is added it gets filtered out. + def __init__(self, api_key: str, routes, station_ids, timing_callbacks=None, alert_callbacks=None, endpoints_file="./endpoints.json", callback_frequency=10, max_arrival_time=30): self.header = { "x-api-key": api_key } - self.train_lines = train_lines + self.routes = routes self.station_ids = station_ids self.timing_callbacks = timing_callbacks if timing_callbacks else [] - # self.alert_callbacks = alert_callbacks if alert_callbacks else [] self.is_running = False self.callback_frequency = callback_frequency self.max_arrival_time = max_arrival_time @@ -28,9 +31,9 @@ class MTA(object): def set_valid_endpoints(self): self.valid_endpoints = {} for key, value in self.endpoints.items(): - valid_lines = [x for x in self.train_lines if x in key] - if valid_lines: - self.valid_endpoints[value] = valid_lines + valid_routes = [x for x in self.routes if x in key] + if valid_routes: + self.valid_endpoints[value] = valid_routes print(self.valid_endpoints) def start_updates(self): @@ -43,53 +46,50 @@ class MTA(object): self.is_running = False async def get_data(self): - routes = [] + trains = [] for endpoint, valid_lines in self.valid_endpoints.items(): r = requests.get(endpoint, headers=self.header) feed = gtfs_realtime_pb2.FeedMessage() feed.ParseFromString(r.content) subway_feed = protobuf_to_dict(feed)['entity'] - routes.extend([x for x in [get_route_from_dict(x) for x in subway_feed] if x is not None]) - return routes + trains.extend([train for train in [get_train_from_dict(train_dict) for train_dict in subway_feed] if train is not None]) + return trains + @staticmethod - def valid_route(train_lines, station_ids, route, max_time): - if route.route_id not in train_lines: - return False - stops = route.stop_times - for stop in stops: - minutes_to_arrival = stop.arrival_minutes() - if stop.stop_id in station_ids: - if minutes_to_arrival > 0 and minutes_to_arrival < max_time: - return True - return False + def get_trains_for_routes(routes, trains): + return [train for train in trains if train.route in routes] - async def get_route_information(self): - # Filter routes - valid_routes = [route for route in await self.get_data() if - MTA.valid_route(self.train_lines, self.station_ids, route, self.max_arrival_time)] - return valid_routes + @staticmethod + def get_trains_for_route(route, trains): + return MTA.get_trains_for_routes([route], trains) + + + async def get_train_information(self): + # Might need to not filter these trains. + valid_trains = [train for train in await self.get_data() if True] + # MTA.trains_arriving_at_stations(self.train_lines, self.station_ids, train, self.max_arrival_time)] + return valid_trains async def _get_updates(self): self.is_running = True while (self.is_running): t = time() - data = self.get_route_information() + data = self.get_train_information() data = await data await self.process_callbacks(data) await asyncio.sleep(self.callback_frequency - (time() - t)) - # self.is_running = False async def process_callbacks(self, data): for callback in self.timing_callbacks: await callback(data) def add_train_line(self, train_line: str): - self.train_lines.append(train_line) + self.routes.append(train_line) self.set_valid_endpoints() def remove_train_line(self, train_line: str): - self.train_lines.remove(train_line) + self.routes.remove(train_line) self.set_valid_endpoints() def add_station_id(self, station_id: str): @@ -104,15 +104,14 @@ class MTA(object): def remove_callback(self, callback_func): self.timing_callbacks.remove(callback_func) - def convert_routes_to_station_first(self, routes): + def get_time_arriving_at_stations(self, trains): station_first = {} for station_id in self.station_ids: line_first = {} - for train_line in self.train_lines: - valid_routes = [route.get_arrival_at(station_id) for route in routes if - self.valid_route([train_line], [station_id], route, self.max_arrival_time)] - if valid_routes: - line_first[train_line] = valid_routes + for route in self.routes: + valid_trains = [train.get_arrival_at(station_id) for train in MTA.get_trains_for_route(route, trains) if train.arriving_at_station_in_time(station_id, self.max_arrival_time)] + if valid_trains: + line_first[route] = valid_trains if line_first: station_first[station_id] = line_first return station_first diff --git a/mta_manager/route.py b/mta_manager/route.py deleted file mode 100644 index 8e0cbba..0000000 --- a/mta_manager/route.py +++ /dev/null @@ -1,36 +0,0 @@ -from .stop import get_stop_from_dict - - -def get_route_from_dict(obj): - if "trip_update" in obj and "stop_time_update" in obj["trip_update"]: - # data we need is here create object - id = obj["id"] - route_id = obj["trip_update"]["trip"]["route_id"] - stop_times = [valid_stop for valid_stop in - [get_stop_from_dict(x) for x in obj["trip_update"]["stop_time_update"]] - if valid_stop is not None] - return Route(id, route_id, stop_times) - else: - return None - - -class Route(object): - def __init__(self, id, route_id, stop_times): - self.id = id - self.route_id = route_id - self.stop_times = stop_times - - def get_arrival_at(self, stop_id): - """ - returns the routes stop time at a given stop ID in minutes - if not found, returns None - :param stop_id: stop ID of arrival station - :return: arrival time in minutes - """ - for stop in self.stop_times: - if stop.stop_id == stop_id: - return stop.arrival_minutes() - return None - - def __str__(self): - return f"id:{self.id} | route_id:{self.route_id}| stop_times:{self.stop_times}" diff --git a/mta_manager/stop.py b/mta_manager/stop.py index ca15489..f8e6ac0 100644 --- a/mta_manager/stop.py +++ b/mta_manager/stop.py @@ -1,18 +1,18 @@ -from time import time + from datetime import datetime from math import trunc def get_stop_from_dict(obj): if "arrival" in obj and "departure" in obj and "stop_id" in obj: - return Stop(obj["arrival"]["time"], obj["departure"]["time"], obj["stop_id"]) + return Stop( obj["stop_id"], obj["arrival"]["time"], obj["departure"]["time"]) return None class Stop(object): - def __init__(self, arrival_time, departure_time, stop_id): + def __init__(self, id, arrival_time, departure_time, ): + self.id = id self.arrival_time = arrival_time self.departure_time = departure_time - self.stop_id = stop_id def arrival_minutes(self): return trunc(((datetime.fromtimestamp(self.arrival_time) - datetime.now()).total_seconds()) / 60) @@ -21,4 +21,4 @@ class Stop(object): now = datetime.now() time = datetime.fromtimestamp(self.arrival_time) time_minutes = trunc(((time - now).total_seconds()) / 60) - return f"arr:{time_minutes}|dep:{self.departure_time}|stop_id:{self.stop_id}" + return f"stop_id:{self.id}| arr:{time_minutes}| dep:{self.departure_time}" diff --git a/mta_manager/train.py b/mta_manager/train.py new file mode 100644 index 0000000..9693f42 --- /dev/null +++ b/mta_manager/train.py @@ -0,0 +1,43 @@ +from .stop import get_stop_from_dict + + +def get_train_from_dict(obj): + if "trip_update" in obj and "stop_time_update" in obj["trip_update"]: + # data we need is here create object + id = obj["id"] + route = obj["trip_update"]["trip"]["route_id"] + all_stops = [get_stop_from_dict(x) for x in obj["trip_update"]["stop_time_update"]] + valid_stops = [valid_stop for valid_stop in all_stops if valid_stop is not None] + return Train(id, route, valid_stops) + else: + return None + + +class Train(object): + def __init__(self, id, route, stops): + self.id = id + self.route = route + self.stops = stops + + def get_arrival_at(self, stop_id): + """ + returns the routes stop time at a given stop ID in minutes + if not found, returns None + :param stop_id: stop ID of arrival station + :return: arrival time in minutes + """ + for stop in self.stops: + if stop.id == stop_id: + return stop.arrival_minutes() + return None + + def arriving_at_station_in_time(self, station_id, max_time): + for stop in self.stops: + minutes_to_arrival = stop.arrival_minutes() + if stop.id == station_id: + if minutes_to_arrival > 0 and minutes_to_arrival < max_time: + return True + + def __str__(self): + formatted_stops = '\n'.join([str(stop) for stop in self.stops]) + return f"train_id:{self.id} | line_name:{self.route}| stops:\n {formatted_stops}" diff --git a/mta_manager/mta_test.py b/mta_test.py similarity index 84% rename from mta_manager/mta_test.py rename to mta_test.py index 2fb01d5..5417074 100644 --- a/mta_manager/mta_test.py +++ b/mta_test.py @@ -1,6 +1,6 @@ import os from dotenv import load_dotenv -from mta import MTA +from mta_manager import MTA import threading import time from time import sleep @@ -17,10 +17,11 @@ mtaController = MTA( -async def mta_callback(routes): +async def mta_callback(trains): print("We are inside of the call back now") - print(len(routes)) - pprint(mtaController.convert_routes_to_station_first(routes)) + print(len(trains)) + pprint([str(route) for route in trains]) + pprint(mtaController.get_time_arriving_at_stations(trains)) class threadWrapper(threading.Thread): def __init__(self, run): diff --git a/server.py b/server.py index 1fecfd9..38553a8 100644 --- a/server.py +++ b/server.py @@ -86,7 +86,7 @@ if __name__ == "__main__": async def mta_callback(routes): global subway_data, old_data, last_updated - subway_data = link_to_station(mtaController.convert_routes_to_station_first(routes)) + subway_data = link_to_station(mtaController.station_info_from_routes(routes)) subway_data["LastUpdated"] = last_updated if old_data is None: old_data = subway_data