Cleaning up the code - renaming routes to trains as it makes more sense (each isntance is an instance of an individual train)

Reworking the MTA to be ready for use by people other than me.
This commit is contained in:
Lucas
2022-02-25 18:39:17 -05:00
parent 9919eed55b
commit e12079fbde
7 changed files with 89 additions and 82 deletions

View File

@@ -4,20 +4,23 @@ import json
from google.transit import gtfs_realtime_pb2
from protobuf_to_dict import protobuf_to_dict
from .route import get_route_from_dict
from .train import get_train_from_dict
from time import time
class MTA(object):
def __init__(self, api_key: str, train_lines, station_ids, timing_callbacks=None, alert_callbacks=None,
# Create a data filter object.
# Then be able to update that object on the fly.
# This filter should return all possible trains and stations by default.
# If anyhting is added it gets filtered out.
def __init__(self, api_key: str, routes, station_ids, timing_callbacks=None, alert_callbacks=None,
endpoints_file="./endpoints.json", callback_frequency=10, max_arrival_time=30):
self.header = {
"x-api-key": api_key
}
self.train_lines = train_lines
self.routes = routes
self.station_ids = station_ids
self.timing_callbacks = timing_callbacks if timing_callbacks else []
# self.alert_callbacks = alert_callbacks if alert_callbacks else []
self.is_running = False
self.callback_frequency = callback_frequency
self.max_arrival_time = max_arrival_time
@@ -28,9 +31,9 @@ class MTA(object):
def set_valid_endpoints(self):
self.valid_endpoints = {}
for key, value in self.endpoints.items():
valid_lines = [x for x in self.train_lines if x in key]
if valid_lines:
self.valid_endpoints[value] = valid_lines
valid_routes = [x for x in self.routes if x in key]
if valid_routes:
self.valid_endpoints[value] = valid_routes
print(self.valid_endpoints)
def start_updates(self):
@@ -43,53 +46,50 @@ class MTA(object):
self.is_running = False
async def get_data(self):
routes = []
trains = []
for endpoint, valid_lines in self.valid_endpoints.items():
r = requests.get(endpoint, headers=self.header)
feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(r.content)
subway_feed = protobuf_to_dict(feed)['entity']
routes.extend([x for x in [get_route_from_dict(x) for x in subway_feed] if x is not None])
return routes
trains.extend([train for train in [get_train_from_dict(train_dict) for train_dict in subway_feed] if train is not None])
return trains
@staticmethod
def valid_route(train_lines, station_ids, route, max_time):
if route.route_id not in train_lines:
return False
stops = route.stop_times
for stop in stops:
minutes_to_arrival = stop.arrival_minutes()
if stop.stop_id in station_ids:
if minutes_to_arrival > 0 and minutes_to_arrival < max_time:
return True
return False
def get_trains_for_routes(routes, trains):
return [train for train in trains if train.route in routes]
async def get_route_information(self):
# Filter routes
valid_routes = [route for route in await self.get_data() if
MTA.valid_route(self.train_lines, self.station_ids, route, self.max_arrival_time)]
return valid_routes
@staticmethod
def get_trains_for_route(route, trains):
return MTA.get_trains_for_routes([route], trains)
async def get_train_information(self):
# Might need to not filter these trains.
valid_trains = [train for train in await self.get_data() if True]
# MTA.trains_arriving_at_stations(self.train_lines, self.station_ids, train, self.max_arrival_time)]
return valid_trains
async def _get_updates(self):
self.is_running = True
while (self.is_running):
t = time()
data = self.get_route_information()
data = self.get_train_information()
data = await data
await self.process_callbacks(data)
await asyncio.sleep(self.callback_frequency - (time() - t))
# self.is_running = False
async def process_callbacks(self, data):
for callback in self.timing_callbacks:
await callback(data)
def add_train_line(self, train_line: str):
self.train_lines.append(train_line)
self.routes.append(train_line)
self.set_valid_endpoints()
def remove_train_line(self, train_line: str):
self.train_lines.remove(train_line)
self.routes.remove(train_line)
self.set_valid_endpoints()
def add_station_id(self, station_id: str):
@@ -104,15 +104,14 @@ class MTA(object):
def remove_callback(self, callback_func):
self.timing_callbacks.remove(callback_func)
def convert_routes_to_station_first(self, routes):
def get_time_arriving_at_stations(self, trains):
station_first = {}
for station_id in self.station_ids:
line_first = {}
for train_line in self.train_lines:
valid_routes = [route.get_arrival_at(station_id) for route in routes if
self.valid_route([train_line], [station_id], route, self.max_arrival_time)]
if valid_routes:
line_first[train_line] = valid_routes
for route in self.routes:
valid_trains = [train.get_arrival_at(station_id) for train in MTA.get_trains_for_route(route, trains) if train.arriving_at_station_in_time(station_id, self.max_arrival_time)]
if valid_trains:
line_first[route] = valid_trains
if line_first:
station_first[station_id] = line_first
return station_first