diff --git a/endpoints.json b/endpoints.json deleted file mode 100644 index 0c351d1..0000000 --- a/endpoints.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "ACE": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-ace", - "BDFM": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-bdfm", - "G": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-g", - "JZ": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-jz", - "NQRW": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-nqrw", - "L": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-l", - "SIR": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-si", - "1234567": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs" -} \ No newline at end of file diff --git a/mta_manager/__init__.py b/mta_manager/__init__.py index e69de29..1236195 100644 --- a/mta_manager/__init__.py +++ b/mta_manager/__init__.py @@ -0,0 +1,5 @@ +from .mta import MTA + +from .train import Train +from .feed import Feed +from .route import Route diff --git a/mta_manager/feed.py b/mta_manager/feed.py new file mode 100644 index 0000000..29fc43f --- /dev/null +++ b/mta_manager/feed.py @@ -0,0 +1,21 @@ +from enum import Enum + +class Feed(Enum): + ACE = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-ace" + BDFM = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-bdfm" + G = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-bdfm" + JZ = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-jz" + NQRW = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-nqrw" + L = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-l" + N1234567 = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs" + SIR = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-si" + +ALL_FEEDS = [ + Feed.ACE, + Feed.BDFM, + Feed.G, + Feed.NQRW, + Feed.L, + Feed.N1234567, + Feed.SIR +] \ No newline at end of file diff --git a/mta_manager/mta.py b/mta_manager/mta.py index 0b4c6f2..816a4bf 100644 --- a/mta_manager/mta.py +++ b/mta_manager/mta.py @@ -1,115 +1,50 @@ -import asyncio import requests -import json from google.transit import gtfs_realtime_pb2 -from protobuf_to_dict import protobuf_to_dict -from time import time from .train import Train +from .feed import Feed, ALL_FEEDS +from .route import Route class MTA(object): - # Create a data filter object. - # Then be able to update that object on the fly. - # This filter should return all possible trains and stations by default. - # If anything is added it gets filtered out. - def __init__(self, api_key: str, routes, station_ids, timing_callbacks=None, alert_callbacks=None, - endpoints_file="./endpoints.json", callback_frequency=10, max_arrival_time=30): + def __init__(self, api_key: str, feeds: [Feed] = ALL_FEEDS, stations: [str] = [], + max_arrival_time: int = 30): self.header = { "x-api-key": api_key } - self.routes = routes - self.station_ids = station_ids - self.timing_callbacks = timing_callbacks if timing_callbacks else [] - self.is_running = False - self.callback_frequency = callback_frequency + self.feeds = feeds + self.stations = stations self.max_arrival_time = max_arrival_time - with open(endpoints_file, "r") as f: - self.endpoints = json.load(f) - self.set_valid_endpoints() - - def set_valid_endpoints(self): - self.valid_endpoints = {} - for key, value in self.endpoints.items(): - valid_routes = [x for x in self.routes if x in key] - if valid_routes: - self.valid_endpoints[value] = valid_routes - print(self.valid_endpoints) - - def start_updates(self): - print("starting updates") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._get_updates()) + self.trains: [Train] = [] def stop_updates(self): self.is_running = False - async def get_data(self): + def get_incoming_trains(self) -> [Train]: trains = [] - for endpoint, valid_lines in self.valid_endpoints.items(): - r = requests.get(endpoint, headers=self.header) + for feed in self.feeds: + r = requests.get(feed.value, headers=self.header) feed = gtfs_realtime_pb2.FeedMessage() feed.ParseFromString(r.content) - subway_feed = protobuf_to_dict(feed)['entity'] - trains.extend([train for train in [Train.get_train_from_dict(train_dict) for train_dict in subway_feed] if - train is not None]) + trains.extend([train for train in [Train(train) for train in feed.entity] if + train.has_trips()]) + self.trains = trains return trains - @staticmethod - def get_trains_for_routes(routes, trains): - return [train for train in trains if train.route in routes] + def get_trains(self) -> [Train]: + return self.trains - @staticmethod - def get_trains_for_route(route, trains): - return MTA.get_trains_for_routes([route], trains) - - async def get_train_information(self): - valid_trains = [train for train in await self.get_data() if True] - return valid_trains - - async def _get_updates(self): - self.is_running = True - while (self.is_running): - t = time() - data = self.get_train_information() - data = await data - await self.process_callbacks(data) - await asyncio.sleep(self.callback_frequency - (time() - t)) - - async def process_callbacks(self, data): - for callback in self.timing_callbacks: - await callback(data) - - def add_train_line(self, train_line: str): - self.routes.append(train_line) - self.set_valid_endpoints() - - def remove_train_line(self, train_line: str): - self.routes.remove(train_line) - self.set_valid_endpoints() + def get_arrival_times(self, route: Route, station: str) -> [int]: + arrival_times = [] + for train in self.trains: + if train.get_route() is route: + arrival = train.get_arrival_at(station) + if arrival is not None and arrival < self.max_arrival_time: + arrival_times.append(arrival) + return sorted(arrival_times) def add_station_id(self, station_id: str): - self.station_ids.append(station_id) + self.stations.append(station_id) def remove_station_id(self, station_id: str): - self.station_ids.remove(station_id) - - def add_callback(self, callback_func): - self.timing_callbacks.append(callback_func) - - def remove_callback(self, callback_func): - self.timing_callbacks.remove(callback_func) - - def get_time_arriving_at_stations(self, trains): - station_first = {} - for station_id in self.station_ids: - line_first = {} - for route in self.routes: - valid_trains = [train.get_arrival_at(station_id) for train in MTA.get_trains_for_route(route, trains) if - train.arriving_at_station_in_time(station_id, self.max_arrival_time)] - if valid_trains: - line_first[route] = valid_trains - if line_first: - station_first[station_id] = line_first - return station_first + self.stations.remove(station_id) diff --git a/mta_manager/route.py b/mta_manager/route.py new file mode 100644 index 0000000..71d69c9 --- /dev/null +++ b/mta_manager/route.py @@ -0,0 +1,36 @@ +from enum import Enum + +class Route(Enum): + A = "A" + C = "C" + E = "E" + + B = "B" + D = "D" + F = "F" + M = "M" + + G = "G" + + J = "J" + Z = "Z" + + N = "N" + Q = "Q" + R = "R" + W = "W" + + N1 = "1" + N2 = "2" + N3 = "3" + N4 = "4" + N5 = "5" + N6 = "6" + N7 = "7" + + L = "L" + SIR = "SIR" + +_routes = set(item.value for item in Route) +def is_valid_route(route: str) -> bool: + return route in _routes \ No newline at end of file diff --git a/mta_manager/stop.py b/mta_manager/stop.py index 634e5a3..d5273b2 100644 --- a/mta_manager/stop.py +++ b/mta_manager/stop.py @@ -1,24 +1,28 @@ from datetime import datetime +from google.transit import gtfs_realtime_pb2 from math import trunc -class Stop(object): - def __init__(self, id, arrival_time, departure_time, ): - self.id = id - self.arrival_time = arrival_time - self.departure_time = departure_time +def trip_arrival_in_minutes(stop_time_update: gtfs_realtime_pb2.TripUpdate): + return trunc(((datetime.fromtimestamp(stop_time_update.arrival.time) - datetime.now()).total_seconds()) / 60) - def arrival_minutes(self): - return trunc(((datetime.fromtimestamp(self.arrival_time) - datetime.now()).total_seconds()) / 60) - - def __str__(self): - now = datetime.now() - time = datetime.fromtimestamp(self.arrival_time) - time_minutes = trunc(((time - now).total_seconds()) / 60) - return f"stop_id:{self.id}| arr:{time_minutes}| dep:{self.departure_time}" - - @staticmethod - def get_stop_from_dict(obj): - if "arrival" in obj and "departure" in obj and "stop_id" in obj: - return Stop(obj["stop_id"], obj["arrival"]["time"], obj["departure"]["time"]) - return None +# class Stop(object): +# def __init__(self, id, arrival_time, departure_time, ): +# self.id = id +# self.arrival_time = arrival_time +# self.departure_time = departure_time +# +# def arrival_minutes(self): +# return trunc(((datetime.fromtimestamp(self.arrival_time) - datetime.now()).total_seconds()) / 60) +# +# def __str__(self): +# now = datetime.now() +# time = datetime.fromtimestamp(self.arrival_time) +# time_minutes = trunc(((time - now).total_seconds()) / 60) +# return f"stop_id:{self.id}| arr:{time_minutes}| dep:{self.departure_time}" +# +# @staticmethod +# def get_stop_from_dict(obj): +# if "arrival" in obj and "departure" in obj and "stop_id" in obj: +# return Stop(obj["stop_id"], obj["arrival"]["time"], obj["departure"]["time"]) +# return None diff --git a/mta_manager/train.py b/mta_manager/train.py index a53dfd2..4b5fb42 100644 --- a/mta_manager/train.py +++ b/mta_manager/train.py @@ -1,42 +1,36 @@ -from .stop import Stop + +from google.transit import gtfs_realtime_pb2 +from .stop import trip_arrival_in_minutes +from .route import Route, is_valid_route + class Train(object): - def __init__(self, id, route, stops): - self.id = id - self.route = route - self.stops = stops + def __init__(self, train_proto: gtfs_realtime_pb2.FeedEntity): + self.train_proto: gtfs_realtime_pb2.FeedEntity = train_proto - def get_arrival_at(self, stop_id): + def get_arrival_at(self, stop_id) -> int | None: """ returns the routes stop time at a given stop ID in minutes if not found, returns None :param stop_id: stop ID of arrival station :return: arrival time in minutes """ - for stop in self.stops: - if stop.id == stop_id: - return stop.arrival_minutes() + for stop_time_update in self.train_proto.trip_update.stop_time_update: + if stop_time_update.stop_id == stop_id: + return trip_arrival_in_minutes(stop_time_update) return None - def arriving_at_station_in_time(self, station_id, max_time): - for stop in self.stops: - minutes_to_arrival = stop.arrival_minutes() - if stop.id == station_id: - if minutes_to_arrival > 0 and minutes_to_arrival < max_time: - return True + + def _get_route(self) -> str: + return self.train_proto.trip_update.trip.route_id + def get_route(self) -> Route: + return Route(self.train_proto.trip_update.trip.route_id) + + def has_trips(self) -> bool: + return self.train_proto.trip_update is not None \ + and len(self.train_proto.trip_update.stop_time_update) > 0 and is_valid_route(self._get_route()) def __str__(self): - formatted_stops = '\n'.join([str(stop) for stop in self.stops]) + formatted_stops = '\n'.join([str(stop) for stop in self.stops]) return f"train_id:{self.id} | line_name:{self.route}| stops:\n {formatted_stops}" - @staticmethod - def get_train_from_dict(obj): - if "trip_update" in obj and "stop_time_update" in obj["trip_update"]: - # data we need is here create object - id = obj["id"] - route = obj["trip_update"]["trip"]["route_id"] - all_stops = [Stop.get_stop_from_dict(x) for x in obj["trip_update"]["stop_time_update"]] - valid_stops = [valid_stop for valid_stop in all_stops if valid_stop is not None] - return Train(id, route, valid_stops) - else: - return None \ No newline at end of file diff --git a/mta_test.py b/mta_test.py deleted file mode 100644 index a0cf048..0000000 --- a/mta_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -from dotenv import load_dotenv -from mta_manager import MTA -import threading - -from time import sleep -from pprint import pprint - -load_dotenv() - -api_key = os.getenv('MTA_API_KEY', '') -mtaController = MTA( - api_key, - ["A", "C", "E", "1", "2", "3"], - ["127S", "127N", "A27N", "A27S"] -) - - -async def mta_callback(trains): - print("We are inside of the call back now") - print(len(trains)) - pprint([str(route) for route in trains]) - pprint(mtaController.get_time_arriving_at_stations(trains)) - - -class Threadwrapper(threading.Thread): - def __init__(self, run): - threading.Thread.__init__(self) - self.run = run - - def run(self): - self.run() - - -def start_mta(): - mtaController.add_callback(mta_callback) - mtaController.start_updates() - - -def stop_mta(): - sleep(10) - mtaController.stop_updates() - - -threadLock = threading.Lock() -threads = [] - -# Create new threads -thread1 = Threadwrapper(start_mta) -thread2 = Threadwrapper(stop_mta) - -thread1.start() -thread2.start() - -# Add threads to thread list -threads.append(thread1) -threads.append(thread2) - -# Wait for all threads to complete -for t in threads: - t.join() -print("Exiting Main Thread") diff --git a/ruff.toml b/ruff.toml index 9471f8a..5e5f544 100644 --- a/ruff.toml +++ b/ruff.toml @@ -31,8 +31,12 @@ exclude = [ "venv", ] + line-length = 120 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" target-version = "py311" + +[per-file-ignores] +"__init__.py" = ["F401"] \ No newline at end of file diff --git a/server.py b/server.py index 775e408..7926cdf 100644 --- a/server.py +++ b/server.py @@ -1,12 +1,12 @@ import os -import threading -import pandas as pd - -from deepdiff import DeepDiff from datetime import datetime + +import pandas as pd +from flask_apscheduler import APScheduler from dotenv import load_dotenv -from flask import Flask, jsonify, render_template, request, abort -from mta_manager import MTA +from flask import Flask, jsonify, render_template, request + +from mta_manager import MTA, Feed, Route load_dotenv() @@ -14,11 +14,16 @@ app = Flask(__name__) app.secret_key = "SuperSecretDontEvenTryToGuessMeGGEZNoRe" app._static_folder = os.path.abspath("templates/static/") -stops = pd.read_csv("stops.txt") -stop_ids = ["127S", "127N", "A27N", "A27S"] +scheduler = APScheduler() +scheduler.init_app(app) + +stops = pd.read_csv("stops.txt") start_time = datetime.now().strftime("%d/%m/%Y %H:%M:%S") +ROUTES = [Route.A, Route.C, Route.E, Route.N1, Route.N2, Route.N3] +STATION_STOP_IDs = ["127S", "127N", "A27N", "A27S"] + def link_to_station(data) -> {}: linked_data = {} @@ -52,22 +57,17 @@ def get_start_time(): @app.route("/mta_data", methods=["POST"]) -def get_mta_data(): - global subway_data - station = request.json["station"] - if station in subway_data: - mta_data = subway_data[station] - mta_data["LastUpdated"] = subway_data["LastUpdated"] - return jsonify( - mta_data - ) - else: - abort(404) - - -@app.route("/stops", methods=["GET"]) -def get_routes(): - return jsonify() +async def get_mta_data(): + if len(mtaController.trains) == 0: + _ = update_trains() + arrival_by_station_and_route = {} + for stop_id in STATION_STOP_IDs: + arrival_by_station_and_route[stop_id] = {} + for route in ROUTES: + arrival_tiems = mtaController.get_arrival_times(route, stop_id) + if len(arrival_tiems) > 0: + arrival_by_station_and_route[stop_id][route.value] = arrival_tiems + return arrival_by_station_and_route @app.route("/get_stop_id", methods=["POST"]) @@ -77,60 +77,25 @@ def get_stop_id(): return jsonify({"station_changed": True}) + if __name__ == "__main__": api_key = os.getenv('MTA_API_KEY', '') old_data = None last_updated = datetime.now().strftime("%d/%m/%Y %H:%M:%S") - - async def mta_callback(trains): - global subway_data, old_data, last_updated - subway_data = link_to_station(mtaController.get_time_arriving_at_stations(trains)) - subway_data["LastUpdated"] = last_updated - if old_data is None: - old_data = subway_data - data_diff = DeepDiff(old_data, subway_data, ignore_order=True) - if data_diff != {}: - old_data = subway_data - last_updated = datetime.now().strftime("%d/%m/%Y %H:%M:%S") - app.logger.info(f"Updated Subway Data - {subway_data}") - - - class threadWrapper(threading.Thread): - def __init__(self, run): - threading.Thread.__init__(self) - self.run = run - - def run(self): - self.run() - - mtaController = MTA( api_key, - ["A", "C", "E", "1", "2", "3"], - ["127S", "127N", "A27N", "A27S"] + feeds=[Feed.ACE, Feed.N1234567] ) - mtaController.add_callback(mta_callback) + def update_trains(): + app.logger.debug("UPDATING TRAINS") + mtaController.get_incoming_trains() - - def start_mta(): - while True: - try: - mtaController.start_updates() - except Exception as e: - app.logger.info(f"Exception found in update function - {e}") - - - threadLock = threading.Lock() - threads = [threadWrapper(start_mta)] - - for t in threads: - t.start() + scheduler.add_job("train_updater", func=update_trains, trigger="interval", seconds=10) + scheduler.start() debug = os.getenv("DEBUG", 'False').lower() in ('true', '1', 't') - app.run(host="localhost", debug=True, port=5000) + app.run(host="localhost", debug=True, port=5000, use_reloader=False) - for t in threads: - t.join() print("Exiting Main Thread")