135 lines
4.7 KiB
Python
135 lines
4.7 KiB
Python
import csv
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi_utils.tasks import repeat_every
|
|
from starlette import status
|
|
|
|
from mta_api_client import Route, MTA
|
|
from mta_api_client.feed import ALL_FEEDS
|
|
from mta_sign_server.mta.schemas import StationResponse, RouteResponse, AllStationResponse
|
|
|
|
# Router for every MTA data endpoint in this module, tagged "mta-data".
router = APIRouter(tags=["mta-data"])
|
|
|
|
logger = logging.getLogger("mta")

# MTA API key read from the environment; an unset variable yields "".
api_key = os.environ.get("MTA_API_KEY", "")

# Shared MTA client used by every endpoint and by the refresh task,
# constructed with ALL_FEEDS.
mtaController = MTA(api_key, feeds=ALL_FEEDS)
|
|
|
|
# Every member of the Route enum; the station endpoints probe each one.
ROUTES = list(Route)

# Directional stop IDs that get_all() reports on.
STATION_STOP_IDs = ["127S", "127N", "A27N", "A27S"]

# Stops CSV (stop_id / stop_name / location_type columns), located three
# directory levels above this module.
STOPS_FILE = Path(__file__).parent.parent.parent / "stops.txt"

# Station lookup tables, filled lazily by _load_station_mapping():
#   _station_name_to_ids: stop_name -> set of parent stop IDs with that name
#   _station_id_to_name:  parent stop ID -> stop_name
_station_name_to_ids = None
_station_id_to_name = None
|
|
|
|
|
|
def _load_station_mapping():
    """Populate the module-level station lookup tables from STOPS_FILE, once.

    Only rows whose ``location_type`` is ``"1"`` (parent stations) are indexed:
    ``_station_id_to_name`` maps each parent ``stop_id`` to its ``stop_name``,
    and ``_station_name_to_ids`` maps each ``stop_name`` to the set of parent
    stop IDs that share it. If STOPS_FILE does not exist, both tables end up
    empty and a warning is logged.

    The tables are built in local variables and published only after the file
    has been read completely. A concurrent caller therefore never sees a
    half-filled mapping, and an exception while reading leaves the globals as
    ``None`` so a later call retries.
    """
    global _station_name_to_ids, _station_id_to_name

    if _station_name_to_ids is not None:
        return

    name_to_ids = defaultdict(set)
    id_to_name = {}

    if STOPS_FILE.exists():
        # utf-8-sig drops a leading BOM, which would otherwise be glued onto the
        # first header name; the csv module expects files opened with newline="".
        with open(STOPS_FILE, "r", encoding="utf-8-sig", newline="") as f:
            for row in csv.DictReader(f):
                if row.get("location_type") != "1":  # parent stations only
                    continue
                station_id = row["stop_id"]
                station_name = row["stop_name"]
                name_to_ids[station_name].add(station_id)
                id_to_name[station_id] = station_name
    else:
        logger.warning("Stops file %s not found; station IDs will not be resolved by name", STOPS_FILE)

    # Publish id_to_name first: the guard above tests _station_name_to_ids, so
    # once that is non-None both tables are complete.
    _station_id_to_name = id_to_name
    _station_name_to_ids = name_to_ids
|
|
|
|
|
|
def _resolve_base_station_ids(stop_id: str) -> list[str]:
    """Return every parent station ID that shares ``stop_id``'s station name.

    A single trailing direction letter ("N" or "S") is removed to obtain the
    base ID. If the base ID is a known parent station, all parent station IDs
    with the same station name are returned, sorted. Otherwise the base ID is
    returned on its own in a one-element list.
    """
    _load_station_mapping()

    # Strip exactly one direction suffix. str.rstrip("NS") treats "NS" as a
    # character set and would remove every trailing N/S, disagreeing with
    # get_station(), which inspects only the final character.
    base_id = stop_id[:-1] if stop_id.endswith(("N", "S")) else stop_id

    if base_id not in _station_id_to_name:
        return [base_id]

    station_name = _station_id_to_name[base_id]
    return sorted(_station_name_to_ids[station_name])
|
|
|
|
|
|
@router.post("/api/mta/{stop_id}/{route}", response_model=RouteResponse, status_code=status.HTTP_200_OK)
def get_route(stop_id: str, route: Route):
    """Return the upcoming arrival times of ``route`` at ``stop_id``.

    Raises:
        HTTPException: 404 when the MTA client returns no arrival times.
    """
    times = mtaController.get_arrival_times(route, stop_id)
    if not len(times):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="no stops found for route and stop id")
    return RouteResponse(routeId=route, arrival_times=times)
|
|
|
|
|
|
@router.post("/api/mta/{stop_id}", response_model=StationResponse, status_code=status.HTTP_200_OK)
def get_station(stop_id: str):
    """Return arrivals for every route at the station identified by ``stop_id``.

    ``stop_id`` may end in an "N"/"S" direction letter. Every parent station ID
    sharing this stop's station name is queried in that same direction. When
    one route has arrivals under several IDs, its arrival lists are merged
    into a single deduplicated, sorted list.

    Raises:
        HTTPException: 404 when no route has arrivals at any queried ID.
    """
    base_ids = _resolve_base_station_ids(stop_id)
    direction = stop_id[-1] if stop_id.endswith(("N", "S")) else ""

    # route.value -> (route, arrival times): one entry per route.
    by_route = {}
    for base_id in base_ids:
        platform_id = f"{base_id}{direction}"
        for route in ROUTES:
            times = mtaController.get_arrival_times(route, platform_id)
            if not len(times):
                continue
            key = route.value
            if key in by_route:
                # Route already seen under another parent ID: union, then sort.
                _, earlier = by_route[key]
                by_route[key] = (route, sorted(set(earlier + times)))
            else:
                by_route[key] = (route, times)

    routes = [RouteResponse(routeId=r, arrival_times=t) for r, t in by_route.values()]
    if not routes:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="no trains or routes found for stop id")
    return StationResponse(stationId=stop_id, routes=routes)
|
|
|
|
|
|
@router.post("/api/mta", response_model=AllStationResponse, status_code=status.HTTP_200_OK)
def get_all():
    """Return arrivals for every stop ID listed in STATION_STOP_IDs.

    Every configured stop gets an entry in the response. Its ``routes`` list
    is empty when nothing is arriving there.

    Raises:
        HTTPException: 404 when none of the configured stops has any arriving
            train on any route.
    """
    all_stations = []
    any_arrivals = False
    for stop_id in STATION_STOP_IDs:
        routes = []
        for route in ROUTES:
            arrival_times = mtaController.get_arrival_times(route, stop_id)
            if arrival_times:
                routes.append(RouteResponse(routeId=route, arrival_times=arrival_times))
        any_arrivals = any_arrivals or bool(routes)
        all_stations.append(StationResponse(stationId=stop_id, routes=routes))

    logger.debug("get_all stations: %s", all_stations)

    # all_stations always has one entry per configured stop, so test for actual
    # arrivals; a non-empty-list check here could never reach the 404.
    if not any_arrivals:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="no arriving trains found for all configured routes",
        )
    return AllStationResponse(stations=all_stations)
|
|
|
|
|
|
@router.on_event("startup")
@repeat_every(seconds=10)
def update_trains():
    """Refresh the shared MTA client's train data; repeats every 10 seconds.

    The task is registered to start with the application. A failed refresh is
    logged with its traceback and then swallowed, so this call returns
    normally and the next scheduled refresh still runs.
    """
    logger.info("UPDATING TRAINS")
    try:
        mtaController.update_trains()
    except Exception:
        # Background-task boundary: log here so refresh failures are visible
        # regardless of how repeat_every handles exceptions.
        logger.exception("Failed to update trains")
|