51 lines
1.7 KiB
Python
51 lines
1.7 KiB
Python
from typing import List, Optional

import requests
from google.transit import gtfs_realtime_pb2

from .feed import Feed, ALL_FEEDS
from .route import Route
from .train import Train


class MTA(object):
    """Client that fetches MTA GTFS-realtime feeds and answers arrival queries.

    Call :meth:`get_incoming_trains` to refresh the cached trains, then query
    them with :meth:`get_trains` or :meth:`get_arrival_times`.
    """

    def __init__(self, api_key: str, feeds: List[Feed] = ALL_FEEDS,
                 stations: Optional[List[str]] = None,
                 max_arrival_time: int = 30):
        """Store the API key, the feeds to fetch and query settings.

        Args:
            api_key: Value sent in the ``x-api-key`` header of every feed
                request.
            feeds: Feeds fetched by :meth:`get_incoming_trains`; each item's
                ``.value`` is used as the request URL. Defaults to
                ``ALL_FEEDS``.
            stations: Station ids tracked by this instance. When omitted, a
                fresh empty list is created for this instance. A list passed
                in is stored as-is (not copied), so add/remove calls mutate it.
            max_arrival_time: Exclusive upper bound applied by
                :meth:`get_arrival_times` (presumably in the same units that
                ``Train.get_arrival_at`` returns — confirm).
        """
        self.header = {
            "x-api-key": api_key
        }
        self.feeds = feeds
        # A literal ``[]`` default would be created once and shared by every
        # instance, so add_station_id() on one client would leak into others.
        self.stations = [] if stations is None else stations
        self.max_arrival_time = max_arrival_time
        # Trains from the most recent completed get_incoming_trains() call.
        self.trains: List[Train] = []

    def stop_updates(self):
        """Set ``self.is_running`` to ``False``.

        NOTE(review): nothing in this class initialises ``is_running`` or
        reads it; presumably an external update loop checks it — confirm.
        """
        self.is_running = False

    def get_incoming_trains(self, timeout: Optional[float] = None) -> List[Train]:
        """Fetch every configured feed and cache the trains that have trips.

        Args:
            timeout: Optional per-request timeout (seconds) forwarded to
                ``requests.get``. The default ``None`` waits indefinitely,
                matching the previous behaviour.

        Returns:
            The freshly fetched trains; the same list is stored in
            ``self.trains``.

        NOTE(review): the HTTP status is not checked, so an error response
        body is handed straight to ``ParseFromString``.
        """
        trains = []
        for feed in self.feeds:
            response = requests.get(feed.value, headers=self.header,
                                    timeout=timeout)
            # Use a separate name so the loop variable ``feed`` is not shadowed.
            message = gtfs_realtime_pb2.FeedMessage()
            message.ParseFromString(response.content)
            candidates = (Train(entity) for entity in message.entity)
            trains.extend(train for train in candidates if train.has_trips())
        # The cache is only replaced after every feed was fetched and parsed;
        # an exception part-way through leaves the previous trains in place.
        self.trains = trains
        return trains

    def get_trains(self) -> List[Train]:
        """Return the trains cached by the last :meth:`get_incoming_trains`.

        The list is empty until :meth:`get_incoming_trains` has completed.
        """
        return self.trains

    def get_arrival_times(self, route: Route, station: str) -> List[int]:
        """Return sorted arrival values of ``route`` trains at ``station``.

        Only cached trains whose route is ``route`` (identity comparison) are
        considered, and arrivals that are ``None`` or not strictly below
        ``self.max_arrival_time`` are dropped.
        """
        arrival_times = []
        for train in self.trains:
            if train.get_route() is not route:
                continue
            arrival = train.get_arrival_at(station)
            if arrival is not None and arrival < self.max_arrival_time:
                arrival_times.append(arrival)
        return sorted(arrival_times)

    def add_station_id(self, station_id: str):
        """Append ``station_id`` to ``self.stations`` (duplicates allowed)."""
        self.stations.append(station_id)

    def remove_station_id(self, station_id: str):
        """Remove the first occurrence of ``station_id`` from ``self.stations``.

        Raises:
            ValueError: if ``station_id`` is not in ``self.stations``.
        """
        self.stations.remove(station_id)