Cleaning up the code - renaming routes to trains as it makes more sense (each isntance is an instance of an individual train)
Reworking the MTA to be ready for use by people other than me.
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .mta import *
|
||||
from .route import *
|
||||
from .train import *
|
||||
from .stop import *
|
||||
|
||||
@@ -4,20 +4,23 @@ import json
|
||||
|
||||
from google.transit import gtfs_realtime_pb2
|
||||
from protobuf_to_dict import protobuf_to_dict
|
||||
from .route import get_route_from_dict
|
||||
from .train import get_train_from_dict
|
||||
from time import time
|
||||
|
||||
|
||||
class MTA(object):
|
||||
def __init__(self, api_key: str, train_lines, station_ids, timing_callbacks=None, alert_callbacks=None,
|
||||
# Create a data filter object.
|
||||
# Then be able to update that object on the fly.
|
||||
# This filter should return all possible trains and stations by default.
|
||||
# If anyhting is added it gets filtered out.
|
||||
def __init__(self, api_key: str, routes, station_ids, timing_callbacks=None, alert_callbacks=None,
|
||||
endpoints_file="./endpoints.json", callback_frequency=10, max_arrival_time=30):
|
||||
self.header = {
|
||||
"x-api-key": api_key
|
||||
}
|
||||
self.train_lines = train_lines
|
||||
self.routes = routes
|
||||
self.station_ids = station_ids
|
||||
self.timing_callbacks = timing_callbacks if timing_callbacks else []
|
||||
# self.alert_callbacks = alert_callbacks if alert_callbacks else []
|
||||
self.is_running = False
|
||||
self.callback_frequency = callback_frequency
|
||||
self.max_arrival_time = max_arrival_time
|
||||
@@ -28,9 +31,9 @@ class MTA(object):
|
||||
def set_valid_endpoints(self):
|
||||
self.valid_endpoints = {}
|
||||
for key, value in self.endpoints.items():
|
||||
valid_lines = [x for x in self.train_lines if x in key]
|
||||
if valid_lines:
|
||||
self.valid_endpoints[value] = valid_lines
|
||||
valid_routes = [x for x in self.routes if x in key]
|
||||
if valid_routes:
|
||||
self.valid_endpoints[value] = valid_routes
|
||||
print(self.valid_endpoints)
|
||||
|
||||
def start_updates(self):
|
||||
@@ -43,53 +46,50 @@ class MTA(object):
|
||||
self.is_running = False
|
||||
|
||||
async def get_data(self):
|
||||
routes = []
|
||||
trains = []
|
||||
for endpoint, valid_lines in self.valid_endpoints.items():
|
||||
r = requests.get(endpoint, headers=self.header)
|
||||
feed = gtfs_realtime_pb2.FeedMessage()
|
||||
feed.ParseFromString(r.content)
|
||||
subway_feed = protobuf_to_dict(feed)['entity']
|
||||
routes.extend([x for x in [get_route_from_dict(x) for x in subway_feed] if x is not None])
|
||||
return routes
|
||||
trains.extend([train for train in [get_train_from_dict(train_dict) for train_dict in subway_feed] if train is not None])
|
||||
return trains
|
||||
|
||||
|
||||
@staticmethod
|
||||
def valid_route(train_lines, station_ids, route, max_time):
|
||||
if route.route_id not in train_lines:
|
||||
return False
|
||||
stops = route.stop_times
|
||||
for stop in stops:
|
||||
minutes_to_arrival = stop.arrival_minutes()
|
||||
if stop.stop_id in station_ids:
|
||||
if minutes_to_arrival > 0 and minutes_to_arrival < max_time:
|
||||
return True
|
||||
return False
|
||||
def get_trains_for_routes(routes, trains):
|
||||
return [train for train in trains if train.route in routes]
|
||||
|
||||
async def get_route_information(self):
|
||||
# Filter routes
|
||||
valid_routes = [route for route in await self.get_data() if
|
||||
MTA.valid_route(self.train_lines, self.station_ids, route, self.max_arrival_time)]
|
||||
return valid_routes
|
||||
@staticmethod
|
||||
def get_trains_for_route(route, trains):
|
||||
return MTA.get_trains_for_routes([route], trains)
|
||||
|
||||
|
||||
async def get_train_information(self):
|
||||
# Might need to not filter these trains.
|
||||
valid_trains = [train for train in await self.get_data() if True]
|
||||
# MTA.trains_arriving_at_stations(self.train_lines, self.station_ids, train, self.max_arrival_time)]
|
||||
return valid_trains
|
||||
|
||||
async def _get_updates(self):
|
||||
self.is_running = True
|
||||
while (self.is_running):
|
||||
t = time()
|
||||
data = self.get_route_information()
|
||||
data = self.get_train_information()
|
||||
data = await data
|
||||
await self.process_callbacks(data)
|
||||
await asyncio.sleep(self.callback_frequency - (time() - t))
|
||||
# self.is_running = False
|
||||
|
||||
async def process_callbacks(self, data):
|
||||
for callback in self.timing_callbacks:
|
||||
await callback(data)
|
||||
|
||||
def add_train_line(self, train_line: str):
|
||||
self.train_lines.append(train_line)
|
||||
self.routes.append(train_line)
|
||||
self.set_valid_endpoints()
|
||||
|
||||
def remove_train_line(self, train_line: str):
|
||||
self.train_lines.remove(train_line)
|
||||
self.routes.remove(train_line)
|
||||
self.set_valid_endpoints()
|
||||
|
||||
def add_station_id(self, station_id: str):
|
||||
@@ -104,15 +104,14 @@ class MTA(object):
|
||||
def remove_callback(self, callback_func):
|
||||
self.timing_callbacks.remove(callback_func)
|
||||
|
||||
def convert_routes_to_station_first(self, routes):
|
||||
def get_time_arriving_at_stations(self, trains):
|
||||
station_first = {}
|
||||
for station_id in self.station_ids:
|
||||
line_first = {}
|
||||
for train_line in self.train_lines:
|
||||
valid_routes = [route.get_arrival_at(station_id) for route in routes if
|
||||
self.valid_route([train_line], [station_id], route, self.max_arrival_time)]
|
||||
if valid_routes:
|
||||
line_first[train_line] = valid_routes
|
||||
for route in self.routes:
|
||||
valid_trains = [train.get_arrival_at(station_id) for train in MTA.get_trains_for_route(route, trains) if train.arriving_at_station_in_time(station_id, self.max_arrival_time)]
|
||||
if valid_trains:
|
||||
line_first[route] = valid_trains
|
||||
if line_first:
|
||||
station_first[station_id] = line_first
|
||||
return station_first
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from .stop import get_stop_from_dict
|
||||
|
||||
|
||||
def get_route_from_dict(obj):
|
||||
if "trip_update" in obj and "stop_time_update" in obj["trip_update"]:
|
||||
# data we need is here create object
|
||||
id = obj["id"]
|
||||
route_id = obj["trip_update"]["trip"]["route_id"]
|
||||
stop_times = [valid_stop for valid_stop in
|
||||
[get_stop_from_dict(x) for x in obj["trip_update"]["stop_time_update"]]
|
||||
if valid_stop is not None]
|
||||
return Route(id, route_id, stop_times)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class Route(object):
|
||||
def __init__(self, id, route_id, stop_times):
|
||||
self.id = id
|
||||
self.route_id = route_id
|
||||
self.stop_times = stop_times
|
||||
|
||||
def get_arrival_at(self, stop_id):
|
||||
"""
|
||||
returns the routes stop time at a given stop ID in minutes
|
||||
if not found, returns None
|
||||
:param stop_id: stop ID of arrival station
|
||||
:return: arrival time in minutes
|
||||
"""
|
||||
for stop in self.stop_times:
|
||||
if stop.stop_id == stop_id:
|
||||
return stop.arrival_minutes()
|
||||
return None
|
||||
|
||||
def __str__(self):
|
||||
return f"id:{self.id} | route_id:{self.route_id}| stop_times:{self.stop_times}"
|
||||
@@ -1,18 +1,18 @@
|
||||
from time import time
|
||||
|
||||
from datetime import datetime
|
||||
from math import trunc
|
||||
|
||||
def get_stop_from_dict(obj):
|
||||
if "arrival" in obj and "departure" in obj and "stop_id" in obj:
|
||||
return Stop(obj["arrival"]["time"], obj["departure"]["time"], obj["stop_id"])
|
||||
return Stop( obj["stop_id"], obj["arrival"]["time"], obj["departure"]["time"])
|
||||
return None
|
||||
|
||||
|
||||
class Stop(object):
|
||||
def __init__(self, arrival_time, departure_time, stop_id):
|
||||
def __init__(self, id, arrival_time, departure_time, ):
|
||||
self.id = id
|
||||
self.arrival_time = arrival_time
|
||||
self.departure_time = departure_time
|
||||
self.stop_id = stop_id
|
||||
|
||||
def arrival_minutes(self):
|
||||
return trunc(((datetime.fromtimestamp(self.arrival_time) - datetime.now()).total_seconds()) / 60)
|
||||
@@ -21,4 +21,4 @@ class Stop(object):
|
||||
now = datetime.now()
|
||||
time = datetime.fromtimestamp(self.arrival_time)
|
||||
time_minutes = trunc(((time - now).total_seconds()) / 60)
|
||||
return f"arr:{time_minutes}|dep:{self.departure_time}|stop_id:{self.stop_id}"
|
||||
return f"stop_id:{self.id}| arr:{time_minutes}| dep:{self.departure_time}"
|
||||
|
||||
43
mta_manager/train.py
Normal file
43
mta_manager/train.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from .stop import get_stop_from_dict
|
||||
|
||||
|
||||
def get_train_from_dict(obj):
|
||||
if "trip_update" in obj and "stop_time_update" in obj["trip_update"]:
|
||||
# data we need is here create object
|
||||
id = obj["id"]
|
||||
route = obj["trip_update"]["trip"]["route_id"]
|
||||
all_stops = [get_stop_from_dict(x) for x in obj["trip_update"]["stop_time_update"]]
|
||||
valid_stops = [valid_stop for valid_stop in all_stops if valid_stop is not None]
|
||||
return Train(id, route, valid_stops)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class Train(object):
|
||||
def __init__(self, id, route, stops):
|
||||
self.id = id
|
||||
self.route = route
|
||||
self.stops = stops
|
||||
|
||||
def get_arrival_at(self, stop_id):
|
||||
"""
|
||||
returns the routes stop time at a given stop ID in minutes
|
||||
if not found, returns None
|
||||
:param stop_id: stop ID of arrival station
|
||||
:return: arrival time in minutes
|
||||
"""
|
||||
for stop in self.stops:
|
||||
if stop.id == stop_id:
|
||||
return stop.arrival_minutes()
|
||||
return None
|
||||
|
||||
def arriving_at_station_in_time(self, station_id, max_time):
|
||||
for stop in self.stops:
|
||||
minutes_to_arrival = stop.arrival_minutes()
|
||||
if stop.id == station_id:
|
||||
if minutes_to_arrival > 0 and minutes_to_arrival < max_time:
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
formatted_stops = '\n'.join([str(stop) for stop in self.stops])
|
||||
return f"train_id:{self.id} | line_name:{self.route}| stops:\n {formatted_stops}"
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from mta import MTA
|
||||
from mta_manager import MTA
|
||||
import threading
|
||||
import time
|
||||
from time import sleep
|
||||
@@ -17,10 +17,11 @@ mtaController = MTA(
|
||||
|
||||
|
||||
|
||||
async def mta_callback(routes):
|
||||
async def mta_callback(trains):
|
||||
print("We are inside of the call back now")
|
||||
print(len(routes))
|
||||
pprint(mtaController.convert_routes_to_station_first(routes))
|
||||
print(len(trains))
|
||||
pprint([str(route) for route in trains])
|
||||
pprint(mtaController.get_time_arriving_at_stations(trains))
|
||||
|
||||
class threadWrapper(threading.Thread):
|
||||
def __init__(self, run):
|
||||
@@ -86,7 +86,7 @@ if __name__ == "__main__":
|
||||
|
||||
async def mta_callback(routes):
|
||||
global subway_data, old_data, last_updated
|
||||
subway_data = link_to_station(mtaController.convert_routes_to_station_first(routes))
|
||||
subway_data = link_to_station(mtaController.station_info_from_routes(routes))
|
||||
subway_data["LastUpdated"] = last_updated
|
||||
if old_data is None:
|
||||
old_data = subway_data
|
||||
|
||||
Reference in New Issue
Block a user