119 lines
4.3 KiB
Python
119 lines
4.3 KiB
Python
import asyncio
|
|
import requests
|
|
import json
|
|
|
|
from google.transit import gtfs_realtime_pb2
|
|
from protobuf_to_dict import protobuf_to_dict
|
|
from .route import get_route_from_dict
|
|
from time import time
|
|
|
|
|
|
class MTA(object):
    """Poll MTA GTFS-realtime feeds and hand upcoming trains to callbacks.

    On construction the endpoints JSON file is loaded and reduced to the feed
    URLs whose key mentions one of the watched train lines.  ``start_updates``
    then runs a polling loop that, roughly every ``callback_frequency``
    seconds, downloads those feeds, keeps the routes that reach a watched
    station within ``max_arrival_time`` minutes, and awaits every timing
    callback with that list.
    """

    def __init__(self, api_key: str, train_lines, station_ids, timing_callbacks=None, alert_callbacks=None,
                 endpoints_file="./endpoints.json", callback_frequency=5, max_arrival_time=30,
                 request_timeout=10):
        """Load the endpoints file and prepare the watch lists.

        Args:
            api_key: Sent as the ``x-api-key`` header on every feed request.
            train_lines: Iterable of route ids to watch; copied into a new list
                so ``add_train_line``/``remove_train_line`` never mutate the
                caller's object.
            station_ids: Iterable of stop ids to watch; copied likewise.
            timing_callbacks: Iterable of coroutine functions, each awaited with
                the list of matching routes after every poll.
            alert_callbacks: Accepted for interface compatibility; currently unused.
            endpoints_file: Path to a JSON object whose values are feed URLs; a
                feed is used when its key contains a watched train line.
            callback_frequency: Target number of seconds between poll starts.
            max_arrival_time: Only arrivals strictly between 0 and this many
                minutes away are reported.
            request_timeout: Seconds passed as ``timeout`` to each feed request,
                so a stalled feed cannot hang the polling loop forever.
        """
        self.header = {
            "x-api-key": api_key
        }
        self.train_lines = list(train_lines)
        self.station_ids = list(station_ids)
        self.timing_callbacks = list(timing_callbacks) if timing_callbacks else []
        # alert_callbacks is accepted but not wired up yet.
        self.is_running = False
        self.callback_frequency = callback_frequency
        self.max_arrival_time = max_arrival_time
        self.request_timeout = request_timeout
        with open(endpoints_file, "r", encoding="utf-8") as f:
            self.endpoints = json.load(f)
        self.set_valid_endpoints()

    def set_valid_endpoints(self):
        """Rebuild ``valid_endpoints`` as ``{feed_url: [watched lines it serves]}``.

        A feed qualifies when its key in ``self.endpoints`` contains a watched
        train line as a substring.  When several keys point at the same URL
        their lines are merged rather than the last key overwriting the rest.
        """
        self.valid_endpoints = {}
        for key, url in self.endpoints.items():
            matching = [line for line in self.train_lines if line in key]
            if not matching:
                continue
            served = self.valid_endpoints.setdefault(url, [])
            for line in matching:
                if line not in served:
                    served.append(line)
        print(self.valid_endpoints)

    def start_updates(self):
        """Run the polling loop on a fresh event loop; blocks until it stops."""
        print("starting updates")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._get_updates())
        finally:
            # The loop was created here, so release it here (as asyncio.run does).
            asyncio.set_event_loop(None)
            loop.close()

    def stop_updates(self):
        """Ask the polling loop to exit after its current iteration."""
        self.is_running = False

    def _fetch_entities(self, endpoint):
        """Download one feed and return its ``entity`` entries as dicts.

        Blocking (network I/O and protobuf parsing); ``get_data`` runs it in
        the event loop's default executor.
        """
        response = requests.get(endpoint, headers=self.header, timeout=self.request_timeout)
        feed = gtfs_realtime_pb2.FeedMessage()
        feed.ParseFromString(response.content)
        # protobuf_to_dict may omit an empty repeated field, so fall back to [].
        return protobuf_to_dict(feed).get('entity', [])

    async def get_data(self):
        """Fetch every watched feed and return the routes parsed from it.

        The blocking downloads run concurrently in the default executor, so
        the event loop stays free while feeds are in flight; results keep the
        order of ``valid_endpoints``.  Entities for which
        ``get_route_from_dict`` returns ``None`` are dropped.
        """
        loop = asyncio.get_event_loop()
        feeds = await asyncio.gather(*(
            loop.run_in_executor(None, self._fetch_entities, endpoint)
            for endpoint in self.valid_endpoints
        ))
        routes = []
        for entities in feeds:
            for entity in entities:
                route = get_route_from_dict(entity)
                if route is not None:
                    routes.append(route)
        return routes

    @staticmethod
    def valid_route(train_lines, station_ids, route, max_time):
        """Return True if *route* is on a watched line and reaches a watched station soon.

        "Soon" means ``0 < stop.arrival_minutes() < max_time`` for at least one
        stop in ``route.stop_times`` whose ``stop_id`` is in *station_ids*.
        """
        if route.route_id not in train_lines:
            return False
        # Only watched stops need their arrival time computed.
        return any(0 < stop.arrival_minutes() < max_time
                   for stop in route.stop_times
                   if stop.stop_id in station_ids)

    async def get_route_information(self):
        """Fetch all feeds and keep the routes that pass ``valid_route``."""
        routes = await self.get_data()
        return [route for route in routes
                if MTA.valid_route(self.train_lines, self.station_ids, route, self.max_arrival_time)]

    async def _get_updates(self):
        """Poll and dispatch callbacks until ``stop_updates`` is called.

        Each iteration sleeps for whatever remains of ``callback_frequency``
        after fetching and dispatching, measured on a monotonic clock so wall
        clock adjustments cannot skew the pacing.
        """
        from time import monotonic

        self.is_running = True
        try:
            while self.is_running:
                started = monotonic()
                routes = await self.get_route_information()
                await self.process_callbacks(routes)
                remaining = self.callback_frequency - (monotonic() - started)
                await asyncio.sleep(max(0, remaining))
        finally:
            # Also clear the flag when a fetch or callback raises, so
            # is_running never claims the loop is alive after it has died.
            self.is_running = False

    async def process_callbacks(self, data):
        """Await each timing callback in order with *data*."""
        for callback in self.timing_callbacks:
            await callback(data)

    def add_train_line(self, train_line: str):
        """Start watching *train_line* and refresh the feed selection."""
        self.train_lines.append(train_line)
        self.set_valid_endpoints()

    def remove_train_line(self, train_line: str):
        """Stop watching *train_line* (ValueError if absent) and refresh feeds."""
        self.train_lines.remove(train_line)
        self.set_valid_endpoints()

    def add_station_id(self, station_id: str):
        """Start watching *station_id*."""
        self.station_ids.append(station_id)

    def remove_station_id(self, station_id: str):
        """Stop watching *station_id* (ValueError if absent)."""
        self.station_ids.remove(station_id)

    def add_callback(self, callback_func):
        """Register a coroutine function to receive each poll's routes."""
        self.timing_callbacks.append(callback_func)

    def remove_callback(self, callback_func):
        """Unregister *callback_func* (ValueError if absent)."""
        self.timing_callbacks.remove(callback_func)

    def convert_routes_to_station_first(self, routes):
        """Group arrivals as ``{station_id: {train_line: [arrival, ...]}}``.

        Each arrival is ``route.get_arrival_at(station_id)`` for every route
        that passes ``valid_route`` for that single line and station.  Stations
        and lines with no qualifying arrivals are omitted.  *routes* may be any
        iterable; it is materialised once because it is scanned repeatedly.
        """
        routes = list(routes)
        station_first = {}
        for station_id in self.station_ids:
            by_line = {}
            for train_line in self.train_lines:
                arrivals = [route.get_arrival_at(station_id) for route in routes
                            if self.valid_route([train_line], [station_id], route, self.max_arrival_time)]
                if arrivals:
                    by_line[train_line] = arrivals
            if by_line:
                station_first[station_id] = by_line
        return station_first
|