feat: rework to use next and host from single dockerfile
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
import csv
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from mta_api_client import Route
|
||||
from mta_sign_server.config.schemas import Station, StationsResponse, LinesResponse
|
||||
|
||||
router = APIRouter(
|
||||
tags=["config"],
|
||||
@@ -9,7 +13,44 @@ router = APIRouter(
|
||||
|
||||
logger = logging.getLogger("config_router")
|
||||
|
||||
STOPS_FILE = Path(__file__).parent.parent.parent / "stops.txt"
|
||||
|
||||
|
||||
@router.get("/api/config")
|
||||
def get_all():
|
||||
return JSONResponse({"config": "goes here"})
|
||||
return {"config": "goes here"}
|
||||
|
||||
|
||||
@router.get("/api/stations", response_model=StationsResponse)
|
||||
def get_stations(search: str | None = None):
|
||||
"""Get list of all stations, optionally filtered by search term. Deduplicates by station name."""
|
||||
stations_dict = {}
|
||||
|
||||
if STOPS_FILE.exists():
|
||||
with open(STOPS_FILE, "r") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
# Only include parent stations (location_type == 1)
|
||||
if row.get("location_type") == "1":
|
||||
station_name = row["stop_name"]
|
||||
station_id = row["stop_id"]
|
||||
|
||||
# Only add first occurrence of each station name to deduplicate
|
||||
if station_name not in stations_dict:
|
||||
station = Station(id=station_id, name=station_name)
|
||||
|
||||
# Filter by search term if provided
|
||||
if search:
|
||||
if search.lower() in station.name.lower() or search in station.id:
|
||||
stations_dict[station_name] = station
|
||||
else:
|
||||
stations_dict[station_name] = station
|
||||
|
||||
return StationsResponse(stations=list(stations_dict.values()))
|
||||
|
||||
|
||||
@router.get("/api/lines", response_model=LinesResponse)
|
||||
def get_lines():
|
||||
"""Get list of all available train lines."""
|
||||
lines = [route.value for route in Route]
|
||||
return LinesResponse(lines=sorted(lines))
|
||||
|
||||
14
mta_sign_server/config/schemas.py
Normal file
14
mta_sign_server/config/schemas.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Station(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class StationsResponse(BaseModel):
|
||||
stations: list[Station]
|
||||
|
||||
|
||||
class LinesResponse(BaseModel):
|
||||
lines: list[str]
|
||||
@@ -1,11 +1,15 @@
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from starlette import status
|
||||
|
||||
from mta_api_client import Route, MTA, Feed
|
||||
from mta_api_client import Route, MTA
|
||||
from mta_api_client.feed import ALL_FEEDS
|
||||
from mta_sign_server.mta.schemas import StationResponse, RouteResponse, AllStationResponse
|
||||
|
||||
router = APIRouter(
|
||||
@@ -18,30 +22,91 @@ api_key = os.getenv('MTA_API_KEY', '')
|
||||
|
||||
mtaController = MTA(
|
||||
api_key,
|
||||
feeds=[Feed.ACE, Feed.N1234567]
|
||||
feeds=ALL_FEEDS
|
||||
)
|
||||
|
||||
ROUTES = [Route.A, Route.C, Route.E, Route.N1, Route.N2, Route.N3]
|
||||
ROUTES = [member for member in Route]
|
||||
STATION_STOP_IDs = ["127S", "127N", "A27N", "A27S"]
|
||||
STOPS_FILE = Path(__file__).parent.parent.parent / "stops.txt"
|
||||
|
||||
# Build mappings for station resolution
|
||||
_station_name_to_ids = None
|
||||
_station_id_to_name = None
|
||||
|
||||
|
||||
def _load_station_mapping():
|
||||
global _station_name_to_ids, _station_id_to_name
|
||||
if _station_name_to_ids is not None:
|
||||
return
|
||||
|
||||
_station_name_to_ids = defaultdict(set)
|
||||
_station_id_to_name = {}
|
||||
|
||||
if STOPS_FILE.exists():
|
||||
with open(STOPS_FILE, "r") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
if row.get("location_type") == "1": # Parent stations only
|
||||
station_name = row["stop_name"]
|
||||
station_id = row["stop_id"]
|
||||
_station_name_to_ids[station_name].add(station_id)
|
||||
_station_id_to_name[station_id] = station_name
|
||||
|
||||
|
||||
def _resolve_base_station_ids(stop_id: str) -> list[str]:
|
||||
"""Resolve a stop_id to a list of all base parent station IDs (without direction) that share the same station name."""
|
||||
_load_station_mapping()
|
||||
|
||||
# Remove direction suffix to get base ID
|
||||
base_id = stop_id.rstrip("NS")
|
||||
|
||||
# Look up the station name for this ID
|
||||
if base_id in _station_id_to_name:
|
||||
station_name = _station_id_to_name[base_id]
|
||||
# Return all base station IDs with this name
|
||||
return sorted(list(_station_name_to_ids[station_name]))
|
||||
|
||||
# If not found, just return the base ID
|
||||
return [base_id]
|
||||
|
||||
|
||||
@router.post("/api/mta/{stop_id}/{route}", response_model=RouteResponse, status_code=status.HTTP_200_OK)
|
||||
def get_route(stop_id: str, route: Route):
|
||||
arrival_times = mtaController.get_arrival_times(route, stop_id)
|
||||
if len(arrival_times) > 0:
|
||||
return RouteResponse(arrival_times=arrival_times)
|
||||
return RouteResponse(routeId=route, arrival_times=arrival_times)
|
||||
raise HTTPException(status_code=404, detail="no stops found for route and stop id")
|
||||
|
||||
|
||||
@router.post("/api/mta/{stop_id}", response_model=StationResponse, status_code=status.HTTP_200_OK)
|
||||
def get_station(stop_id: str):
|
||||
routes = {}
|
||||
for route in ROUTES:
|
||||
arrival_times = mtaController.get_arrival_times(route, stop_id)
|
||||
if len(arrival_times) > 0:
|
||||
routes[route] = RouteResponse(arrival_times=arrival_times)
|
||||
routes_dict = {} # Use dict to avoid duplicates
|
||||
base_station_ids = _resolve_base_station_ids(stop_id)
|
||||
|
||||
# Extract the requested direction
|
||||
requested_direction = ""
|
||||
if stop_id.endswith("N"):
|
||||
requested_direction = "N"
|
||||
elif stop_id.endswith("S"):
|
||||
requested_direction = "S"
|
||||
|
||||
# Only query the requested direction for each base station ID
|
||||
for base_id in base_station_ids:
|
||||
station_id_with_direction = f"{base_id}{requested_direction}"
|
||||
for route in ROUTES:
|
||||
arrival_times = mtaController.get_arrival_times(route, station_id_with_direction)
|
||||
if len(arrival_times) > 0:
|
||||
route_key = route.value
|
||||
if route_key not in routes_dict:
|
||||
routes_dict[route_key] = (route, arrival_times)
|
||||
else:
|
||||
# Combine arrival times from different station IDs
|
||||
routes_dict[route_key] = (route, sorted(list(set(routes_dict[route_key][1] + arrival_times))))
|
||||
|
||||
routes = [RouteResponse(routeId=route, arrival_times=times) for route, times in routes_dict.values()]
|
||||
|
||||
if routes:
|
||||
return StationResponse(routes=routes)
|
||||
return StationResponse(stationId=stop_id, routes=routes)
|
||||
raise HTTPException(status_code=404, detail="no trains or routes found for stop id")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user