Files
mta-sign/mta_sign_server/mta/router.py
2026-01-20 21:48:33 -05:00

135 lines
4.7 KiB
Python

import csv
import logging
import os
from pathlib import Path
from collections import defaultdict
from fastapi import APIRouter, HTTPException
from fastapi_utils.tasks import repeat_every
from starlette import status
from mta_api_client import Route, MTA
from mta_api_client.feed import ALL_FEEDS
from mta_sign_server.mta.schemas import StationResponse, RouteResponse, AllStationResponse
# Router that groups every MTA data endpoint under a single OpenAPI tag.
router = APIRouter(tags=["mta-data"])

logger = logging.getLogger("mta")

# The feed API key is read from the environment; when unset it falls back to "".
api_key = os.environ.get('MTA_API_KEY', '')

# One shared client instance, used by the endpoints and the periodic refresh.
mtaController = MTA(api_key, feeds=ALL_FEEDS)

# Every member of Route, materialised once so handlers can loop over it.
ROUTES = list(Route)

# Directional stop IDs served by the "all stations" endpoint.
STATION_STOP_IDs = ["127S", "127N", "A27N", "A27S"]

# stops.txt is looked up three directory levels above this module.
STOPS_FILE = Path(__file__).parent.parent.parent / "stops.txt"
# Lazily-built lookup tables for station resolution.
# Both stay None until _load_station_mapping() has completed once.
_station_name_to_ids = None  # station name -> set of parent station IDs
_station_id_to_name = None  # parent station ID -> station name


def _load_station_mapping():
    """Populate the station lookup tables from STOPS_FILE on first use.

    Only rows whose ``location_type`` column equals ``"1"`` (parent stations)
    are indexed. If STOPS_FILE does not exist the tables are left empty.
    Subsequent calls return immediately once the tables have been published.

    Raises:
        KeyError: if an indexed row lacks a ``stop_name`` or ``stop_id``
            column. In that case nothing is cached, so the next call retries.
    """
    global _station_name_to_ids, _station_id_to_name
    if _station_name_to_ids is not None:
        return

    # Build into locals and publish only when complete, so neither a failure
    # half-way through nor a concurrent caller can observe a partial mapping.
    name_to_ids = defaultdict(set)
    id_to_name = {}
    if STOPS_FILE.exists():
        # "utf-8-sig" transparently drops a leading BOM, which would otherwise
        # corrupt the first header name; newline="" is what the csv module
        # expects so that it can handle line endings itself.
        # NOTE(review): stops.txt looks like a GTFS stops file - confirm.
        with open(STOPS_FILE, "r", encoding="utf-8-sig", newline="") as f:
            for row in csv.DictReader(f):
                if row.get("location_type") != "1":  # Parent stations only
                    continue
                station_name = row["stop_name"]
                station_id = row["stop_id"]
                name_to_ids[station_name].add(station_id)
                id_to_name[station_id] = station_name

    # Publish the guard variable last: once it is non-None the other table
    # is guaranteed to be in place as well.
    _station_id_to_name = id_to_name
    _station_name_to_ids = name_to_ids
def _resolve_base_station_ids(stop_id: str) -> list[str]:
    """Resolve a stop_id to every parent station ID sharing its station name.

    Args:
        stop_id: A stop ID, optionally ending in a single direction
            character ("N" or "S").

    Returns:
        The sorted parent station IDs (without direction) registered under
        the same station name as ``stop_id``. If the base ID is not in the
        station mapping, a one-element list holding just the base ID.
    """
    _load_station_mapping()
    # Drop at most ONE trailing direction character. str.rstrip("NS") would
    # strip every trailing N/S and mangle IDs whose base part ends in one.
    base_id = stop_id[:-1] if stop_id.endswith(("N", "S")) else stop_id
    if base_id not in _station_id_to_name:
        # Unknown station: fall back to the base ID on its own.
        return [base_id]
    station_name = _station_id_to_name[base_id]
    return sorted(_station_name_to_ids[station_name])
@router.post("/api/mta/{stop_id}/{route}", response_model=RouteResponse, status_code=status.HTTP_200_OK)
def get_route(stop_id: str, route: Route):
    """Return the arrival times of a single route at a single stop.

    Args:
        stop_id: Stop ID taken from the URL path and passed to the MTA client
            unchanged.
        route: Route taken from the URL path.

    Returns:
        A RouteResponse carrying the route and its arrival times.

    Raises:
        HTTPException: 404 when the client reports no arrival times.
    """
    times = mtaController.get_arrival_times(route, stop_id)
    if len(times) == 0:
        raise HTTPException(status_code=404, detail="no stops found for route and stop id")
    return RouteResponse(routeId=route, arrival_times=times)
@router.post("/api/mta/{stop_id}", response_model=StationResponse, status_code=status.HTTP_200_OK)
def get_station(stop_id: str):
    """Return every route with arrivals at a station, in the requested direction.

    The stop ID is expanded to all parent station IDs that share its station
    name. Each of them is queried with the direction suffix of the original
    ``stop_id`` ("N", "S", or none), and the results are merged per route.

    Args:
        stop_id: Stop ID taken from the URL path, optionally ending in
            "N" or "S".

    Returns:
        A StationResponse for ``stop_id`` with one RouteResponse per route
        that has at least one arrival time.

    Raises:
        HTTPException: 404 when no route has any arrival time.
    """
    parent_ids = _resolve_base_station_ids(stop_id)

    # The direction is the final character when it is "N" or "S", else empty.
    direction = stop_id[-1] if stop_id.endswith(("N", "S")) else ""

    # Keyed by route value so a route served by several parent IDs appears once.
    merged = {}
    for parent_id in parent_ids:
        directed_id = f"{parent_id}{direction}"
        for route in ROUTES:
            times = mtaController.get_arrival_times(route, directed_id)
            if len(times) == 0:
                continue
            key = route.value
            if key in merged:
                # Same route seen under another parent ID: union the times,
                # dropping duplicates and keeping them sorted.
                earlier = merged[key][1]
                merged[key] = (route, sorted(set(earlier + times)))
            else:
                merged[key] = (route, times)

    routes = [
        RouteResponse(routeId=route, arrival_times=times)
        for route, times in merged.values()
    ]
    if not routes:
        raise HTTPException(status_code=404, detail="no trains or routes found for stop id")
    return StationResponse(stationId=stop_id, routes=routes)
@router.post("/api/mta", response_model=AllStationResponse, status_code=status.HTTP_200_OK)
def get_all():
    """Return the arrivals for every stop listed in STATION_STOP_IDs.

    Each configured stop ID is queried against every route in ROUTES. A stop
    is included in the response even when none of its routes has arrivals
    (its ``routes`` list is then empty).

    Returns:
        An AllStationResponse with one StationResponse per configured stop.

    Raises:
        HTTPException: 404 when the list of stations is empty.
    """
    all_stations = []
    for stop_id in STATION_STOP_IDs:
        routes = []
        for route in ROUTES:
            arrival_times = mtaController.get_arrival_times(route, stop_id)
            if len(arrival_times) > 0:
                routes.append(RouteResponse(routeId=route, arrival_times=arrival_times))
        all_stations.append(StationResponse(stationId=stop_id, routes=routes))

    # Diagnostic output goes through the module logger instead of stdout.
    logger.debug("get_all stations: %s", all_stations)

    # NOTE(review): a station is appended for every configured stop ID, so
    # this 404 can only fire if STATION_STOP_IDs is empty. If the intent is
    # "404 when no trains are arriving anywhere", the check should look at
    # the routes instead - confirm with API consumers before changing it.
    if all_stations:
        return AllStationResponse(stations=all_stations)
    raise HTTPException(status_code=404, detail="no arriving trains found for all configured routes")
@router.on_event("startup")
@repeat_every(seconds=10)
def update_trains():
    """Refresh the MTA client's train data; scheduled every 10 seconds from startup.

    Raises:
        Exception: whatever ``mtaController.update_trains()`` raises is logged
            with its traceback and then re-raised unchanged.
    """
    logger.info("UPDATING TRAINS")
    try:
        mtaController.update_trains()
    except Exception:
        # NOTE(review): repeat_every appears to suppress exceptions by default
        # (raise_exceptions=False) - confirm. Logging here makes a failed
        # refresh visible; re-raising leaves the scheduler's handling as-is.
        logger.exception("failed to update trains")
        raise