Cleaning up the code - renaming routes to trains as it makes more sense (each isntance is an instance of an individual train)

Reworking the MTA to be ready for use by people other than me.
This commit is contained in:
Lucas
2022-02-25 18:39:17 -05:00
parent 9919eed55b
commit e12079fbde
7 changed files with 89 additions and 82 deletions

View File

@@ -1,3 +1,3 @@
from .mta import * from .mta import *
from .route import * from .train import *
from .stop import * from .stop import *

View File

@@ -4,20 +4,23 @@ import json
from google.transit import gtfs_realtime_pb2 from google.transit import gtfs_realtime_pb2
from protobuf_to_dict import protobuf_to_dict from protobuf_to_dict import protobuf_to_dict
from .route import get_route_from_dict from .train import get_train_from_dict
from time import time from time import time
class MTA(object): class MTA(object):
def __init__(self, api_key: str, train_lines, station_ids, timing_callbacks=None, alert_callbacks=None, # Create a data filter object.
# Then be able to update that object on the fly.
# This filter should return all possible trains and stations by default.
# If anyhting is added it gets filtered out.
def __init__(self, api_key: str, routes, station_ids, timing_callbacks=None, alert_callbacks=None,
endpoints_file="./endpoints.json", callback_frequency=10, max_arrival_time=30): endpoints_file="./endpoints.json", callback_frequency=10, max_arrival_time=30):
self.header = { self.header = {
"x-api-key": api_key "x-api-key": api_key
} }
self.train_lines = train_lines self.routes = routes
self.station_ids = station_ids self.station_ids = station_ids
self.timing_callbacks = timing_callbacks if timing_callbacks else [] self.timing_callbacks = timing_callbacks if timing_callbacks else []
# self.alert_callbacks = alert_callbacks if alert_callbacks else []
self.is_running = False self.is_running = False
self.callback_frequency = callback_frequency self.callback_frequency = callback_frequency
self.max_arrival_time = max_arrival_time self.max_arrival_time = max_arrival_time
@@ -28,9 +31,9 @@ class MTA(object):
def set_valid_endpoints(self): def set_valid_endpoints(self):
self.valid_endpoints = {} self.valid_endpoints = {}
for key, value in self.endpoints.items(): for key, value in self.endpoints.items():
valid_lines = [x for x in self.train_lines if x in key] valid_routes = [x for x in self.routes if x in key]
if valid_lines: if valid_routes:
self.valid_endpoints[value] = valid_lines self.valid_endpoints[value] = valid_routes
print(self.valid_endpoints) print(self.valid_endpoints)
def start_updates(self): def start_updates(self):
@@ -43,53 +46,50 @@ class MTA(object):
self.is_running = False self.is_running = False
async def get_data(self): async def get_data(self):
routes = [] trains = []
for endpoint, valid_lines in self.valid_endpoints.items(): for endpoint, valid_lines in self.valid_endpoints.items():
r = requests.get(endpoint, headers=self.header) r = requests.get(endpoint, headers=self.header)
feed = gtfs_realtime_pb2.FeedMessage() feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(r.content) feed.ParseFromString(r.content)
subway_feed = protobuf_to_dict(feed)['entity'] subway_feed = protobuf_to_dict(feed)['entity']
routes.extend([x for x in [get_route_from_dict(x) for x in subway_feed] if x is not None]) trains.extend([train for train in [get_train_from_dict(train_dict) for train_dict in subway_feed] if train is not None])
return routes return trains
@staticmethod @staticmethod
def valid_route(train_lines, station_ids, route, max_time): def get_trains_for_routes(routes, trains):
if route.route_id not in train_lines: return [train for train in trains if train.route in routes]
return False
stops = route.stop_times
for stop in stops:
minutes_to_arrival = stop.arrival_minutes()
if stop.stop_id in station_ids:
if minutes_to_arrival > 0 and minutes_to_arrival < max_time:
return True
return False
async def get_route_information(self): @staticmethod
# Filter routes def get_trains_for_route(route, trains):
valid_routes = [route for route in await self.get_data() if return MTA.get_trains_for_routes([route], trains)
MTA.valid_route(self.train_lines, self.station_ids, route, self.max_arrival_time)]
return valid_routes
async def get_train_information(self):
# Might need to not filter these trains.
valid_trains = [train for train in await self.get_data() if True]
# MTA.trains_arriving_at_stations(self.train_lines, self.station_ids, train, self.max_arrival_time)]
return valid_trains
async def _get_updates(self): async def _get_updates(self):
self.is_running = True self.is_running = True
while (self.is_running): while (self.is_running):
t = time() t = time()
data = self.get_route_information() data = self.get_train_information()
data = await data data = await data
await self.process_callbacks(data) await self.process_callbacks(data)
await asyncio.sleep(self.callback_frequency - (time() - t)) await asyncio.sleep(self.callback_frequency - (time() - t))
# self.is_running = False
async def process_callbacks(self, data): async def process_callbacks(self, data):
for callback in self.timing_callbacks: for callback in self.timing_callbacks:
await callback(data) await callback(data)
def add_train_line(self, train_line: str): def add_train_line(self, train_line: str):
self.train_lines.append(train_line) self.routes.append(train_line)
self.set_valid_endpoints() self.set_valid_endpoints()
def remove_train_line(self, train_line: str): def remove_train_line(self, train_line: str):
self.train_lines.remove(train_line) self.routes.remove(train_line)
self.set_valid_endpoints() self.set_valid_endpoints()
def add_station_id(self, station_id: str): def add_station_id(self, station_id: str):
@@ -104,15 +104,14 @@ class MTA(object):
def remove_callback(self, callback_func): def remove_callback(self, callback_func):
self.timing_callbacks.remove(callback_func) self.timing_callbacks.remove(callback_func)
def convert_routes_to_station_first(self, routes): def get_time_arriving_at_stations(self, trains):
station_first = {} station_first = {}
for station_id in self.station_ids: for station_id in self.station_ids:
line_first = {} line_first = {}
for train_line in self.train_lines: for route in self.routes:
valid_routes = [route.get_arrival_at(station_id) for route in routes if valid_trains = [train.get_arrival_at(station_id) for train in MTA.get_trains_for_route(route, trains) if train.arriving_at_station_in_time(station_id, self.max_arrival_time)]
self.valid_route([train_line], [station_id], route, self.max_arrival_time)] if valid_trains:
if valid_routes: line_first[route] = valid_trains
line_first[train_line] = valid_routes
if line_first: if line_first:
station_first[station_id] = line_first station_first[station_id] = line_first
return station_first return station_first

View File

@@ -1,36 +0,0 @@
from .stop import get_stop_from_dict
def get_route_from_dict(obj):
if "trip_update" in obj and "stop_time_update" in obj["trip_update"]:
# data we need is here create object
id = obj["id"]
route_id = obj["trip_update"]["trip"]["route_id"]
stop_times = [valid_stop for valid_stop in
[get_stop_from_dict(x) for x in obj["trip_update"]["stop_time_update"]]
if valid_stop is not None]
return Route(id, route_id, stop_times)
else:
return None
class Route(object):
def __init__(self, id, route_id, stop_times):
self.id = id
self.route_id = route_id
self.stop_times = stop_times
def get_arrival_at(self, stop_id):
"""
returns the routes stop time at a given stop ID in minutes
if not found, returns None
:param stop_id: stop ID of arrival station
:return: arrival time in minutes
"""
for stop in self.stop_times:
if stop.stop_id == stop_id:
return stop.arrival_minutes()
return None
def __str__(self):
return f"id:{self.id} | route_id:{self.route_id}| stop_times:{self.stop_times}"

View File

@@ -1,18 +1,18 @@
from time import time
from datetime import datetime from datetime import datetime
from math import trunc from math import trunc
def get_stop_from_dict(obj): def get_stop_from_dict(obj):
if "arrival" in obj and "departure" in obj and "stop_id" in obj: if "arrival" in obj and "departure" in obj and "stop_id" in obj:
return Stop(obj["arrival"]["time"], obj["departure"]["time"], obj["stop_id"]) return Stop( obj["stop_id"], obj["arrival"]["time"], obj["departure"]["time"])
return None return None
class Stop(object): class Stop(object):
def __init__(self, arrival_time, departure_time, stop_id): def __init__(self, id, arrival_time, departure_time, ):
self.id = id
self.arrival_time = arrival_time self.arrival_time = arrival_time
self.departure_time = departure_time self.departure_time = departure_time
self.stop_id = stop_id
def arrival_minutes(self): def arrival_minutes(self):
return trunc(((datetime.fromtimestamp(self.arrival_time) - datetime.now()).total_seconds()) / 60) return trunc(((datetime.fromtimestamp(self.arrival_time) - datetime.now()).total_seconds()) / 60)
@@ -21,4 +21,4 @@ class Stop(object):
now = datetime.now() now = datetime.now()
time = datetime.fromtimestamp(self.arrival_time) time = datetime.fromtimestamp(self.arrival_time)
time_minutes = trunc(((time - now).total_seconds()) / 60) time_minutes = trunc(((time - now).total_seconds()) / 60)
return f"arr:{time_minutes}|dep:{self.departure_time}|stop_id:{self.stop_id}" return f"stop_id:{self.id}| arr:{time_minutes}| dep:{self.departure_time}"

43
mta_manager/train.py Normal file
View File

@@ -0,0 +1,43 @@
from .stop import get_stop_from_dict
def get_train_from_dict(obj):
if "trip_update" in obj and "stop_time_update" in obj["trip_update"]:
# data we need is here create object
id = obj["id"]
route = obj["trip_update"]["trip"]["route_id"]
all_stops = [get_stop_from_dict(x) for x in obj["trip_update"]["stop_time_update"]]
valid_stops = [valid_stop for valid_stop in all_stops if valid_stop is not None]
return Train(id, route, valid_stops)
else:
return None
class Train(object):
def __init__(self, id, route, stops):
self.id = id
self.route = route
self.stops = stops
def get_arrival_at(self, stop_id):
"""
returns the routes stop time at a given stop ID in minutes
if not found, returns None
:param stop_id: stop ID of arrival station
:return: arrival time in minutes
"""
for stop in self.stops:
if stop.id == stop_id:
return stop.arrival_minutes()
return None
def arriving_at_station_in_time(self, station_id, max_time):
for stop in self.stops:
minutes_to_arrival = stop.arrival_minutes()
if stop.id == station_id:
if minutes_to_arrival > 0 and minutes_to_arrival < max_time:
return True
def __str__(self):
formatted_stops = '\n'.join([str(stop) for stop in self.stops])
return f"train_id:{self.id} | line_name:{self.route}| stops:\n {formatted_stops}"

View File

@@ -1,6 +1,6 @@
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from mta import MTA from mta_manager import MTA
import threading import threading
import time import time
from time import sleep from time import sleep
@@ -17,10 +17,11 @@ mtaController = MTA(
async def mta_callback(routes): async def mta_callback(trains):
print("We are inside of the call back now") print("We are inside of the call back now")
print(len(routes)) print(len(trains))
pprint(mtaController.convert_routes_to_station_first(routes)) pprint([str(route) for route in trains])
pprint(mtaController.get_time_arriving_at_stations(trains))
class threadWrapper(threading.Thread): class threadWrapper(threading.Thread):
def __init__(self, run): def __init__(self, run):

View File

@@ -86,7 +86,7 @@ if __name__ == "__main__":
async def mta_callback(routes): async def mta_callback(routes):
global subway_data, old_data, last_updated global subway_data, old_data, last_updated
subway_data = link_to_station(mtaController.convert_routes_to_station_first(routes)) subway_data = link_to_station(mtaController.station_info_from_routes(routes))
subway_data["LastUpdated"] = last_updated subway_data["LastUpdated"] = last_updated
if old_data is None: if old_data is None:
old_data = subway_data old_data = subway_data