feat: adding openapi spect generation to the frontend client aas well as to fast api. Broke API out into different routes
This commit is contained in:
5
mta_api_client/__init__.py
Normal file
5
mta_api_client/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .mta import MTA
|
||||
|
||||
from .train import Train
|
||||
from .feed import Feed
|
||||
from .route import Route
|
||||
23
mta_api_client/feed.py
Normal file
23
mta_api_client/feed.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Feed(Enum):
|
||||
ACE = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-ace"
|
||||
BDFM = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-bdfm"
|
||||
G = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-bdfm"
|
||||
JZ = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-jz"
|
||||
NQRW = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-nqrw"
|
||||
L = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-l"
|
||||
N1234567 = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs"
|
||||
SIR = "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-si"
|
||||
|
||||
|
||||
ALL_FEEDS = [
|
||||
Feed.ACE,
|
||||
Feed.BDFM,
|
||||
Feed.G,
|
||||
Feed.NQRW,
|
||||
Feed.L,
|
||||
Feed.N1234567,
|
||||
Feed.SIR
|
||||
]
|
||||
50
mta_api_client/mta.py
Normal file
50
mta_api_client/mta.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import requests
|
||||
|
||||
from google.transit import gtfs_realtime_pb2
|
||||
from .train import Train
|
||||
from .feed import Feed, ALL_FEEDS
|
||||
from .route import Route
|
||||
|
||||
|
||||
class MTA(object):
|
||||
def __init__(self, api_key: str, feeds: [Feed] = ALL_FEEDS, stations: [str] = [],
|
||||
max_arrival_time: int = 30):
|
||||
self.header = {
|
||||
"x-api-key": api_key
|
||||
}
|
||||
self.feeds = feeds
|
||||
self.stations = stations
|
||||
self.max_arrival_time = max_arrival_time
|
||||
self.trains: [Train] = []
|
||||
|
||||
def stop_updates(self):
|
||||
self.is_running = False
|
||||
|
||||
def update_trains(self) -> [Train]:
|
||||
trains = []
|
||||
for feed in self.feeds:
|
||||
r = requests.get(feed.value, headers=self.header)
|
||||
feed = gtfs_realtime_pb2.FeedMessage()
|
||||
feed.ParseFromString(r.content)
|
||||
trains.extend([train for train in [Train(train) for train in feed.entity] if
|
||||
train.has_trips()])
|
||||
self.trains = trains
|
||||
return trains
|
||||
|
||||
def get_trains(self) -> [Train]:
|
||||
return self.trains
|
||||
|
||||
def get_arrival_times(self, route: Route, station: str) -> [int]:
|
||||
arrival_times = []
|
||||
for train in self.trains:
|
||||
if train.get_route() is route:
|
||||
arrival = train.get_arrival_at(station)
|
||||
if arrival is not None and arrival < self.max_arrival_time and arrival > 0:
|
||||
arrival_times.append(arrival)
|
||||
return sorted(arrival_times)
|
||||
|
||||
def add_station_id(self, station_id: str):
|
||||
self.stations.append(station_id)
|
||||
|
||||
def remove_station_id(self, station_id: str):
|
||||
self.stations.remove(station_id)
|
||||
40
mta_api_client/route.py
Normal file
40
mta_api_client/route.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Route(Enum):
|
||||
A = "A"
|
||||
C = "C"
|
||||
E = "E"
|
||||
|
||||
B = "B"
|
||||
D = "D"
|
||||
F = "F"
|
||||
M = "M"
|
||||
|
||||
G = "G"
|
||||
|
||||
J = "J"
|
||||
Z = "Z"
|
||||
|
||||
N = "N"
|
||||
Q = "Q"
|
||||
R = "R"
|
||||
W = "W"
|
||||
|
||||
N1 = "1"
|
||||
N2 = "2"
|
||||
N3 = "3"
|
||||
N4 = "4"
|
||||
N5 = "5"
|
||||
N6 = "6"
|
||||
N7 = "7"
|
||||
|
||||
L = "L"
|
||||
SIR = "SIR"
|
||||
|
||||
|
||||
_routes = set(item.value for item in Route)
|
||||
|
||||
|
||||
def is_valid_route(route: str) -> bool:
|
||||
return route in _routes
|
||||
7
mta_api_client/stop.py
Normal file
7
mta_api_client/stop.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from datetime import datetime
|
||||
from google.transit import gtfs_realtime_pb2
|
||||
from math import trunc
|
||||
|
||||
|
||||
def trip_arrival_in_minutes(stop_time_update: gtfs_realtime_pb2.TripUpdate):
|
||||
return trunc(((datetime.fromtimestamp(stop_time_update.arrival.time) - datetime.now()).total_seconds()) / 60)
|
||||
33
mta_api_client/train.py
Normal file
33
mta_api_client/train.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from google.transit import gtfs_realtime_pb2
|
||||
from .stop import trip_arrival_in_minutes
|
||||
from .route import Route, is_valid_route
|
||||
|
||||
|
||||
class Train(object):
|
||||
def __init__(self, train_proto: gtfs_realtime_pb2.FeedEntity):
|
||||
self.train_proto: gtfs_realtime_pb2.FeedEntity = train_proto
|
||||
|
||||
def get_arrival_at(self, stop_id) -> int | None:
|
||||
"""
|
||||
returns the routes stop time at a given stop ID in minutes
|
||||
if not found, returns None
|
||||
:param stop_id: stop ID of arrival station
|
||||
:return: arrival time in minutes
|
||||
"""
|
||||
for stop_time_update in self.train_proto.trip_update.stop_time_update:
|
||||
if stop_time_update.stop_id == stop_id:
|
||||
return trip_arrival_in_minutes(stop_time_update)
|
||||
return None
|
||||
|
||||
def _get_route(self) -> str:
|
||||
return self.train_proto.trip_update.trip.route_id
|
||||
|
||||
def get_route(self) -> Route:
|
||||
return Route(self.train_proto.trip_update.trip.route_id)
|
||||
|
||||
def has_trips(self) -> bool:
|
||||
return self.train_proto.trip_update is not None \
|
||||
and len(self.train_proto.trip_update.stop_time_update) > 0 and is_valid_route(self._get_route())
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.train_proto}"
|
||||
Reference in New Issue
Block a user