feat: refactoring code to use protos directly instead of translating them.
This commit is contained in:
@@ -1,42 +1,36 @@
|
||||
from .stop import Stop
|
||||
|
||||
from google.transit import gtfs_realtime_pb2
|
||||
from .stop import trip_arrival_in_minutes
|
||||
from .route import Route, is_valid_route
|
||||
|
||||
|
||||
class Train(object):
|
||||
def __init__(self, id, route, stops):
|
||||
self.id = id
|
||||
self.route = route
|
||||
self.stops = stops
|
||||
def __init__(self, train_proto: gtfs_realtime_pb2.FeedEntity):
|
||||
self.train_proto: gtfs_realtime_pb2.FeedEntity = train_proto
|
||||
|
||||
def get_arrival_at(self, stop_id):
|
||||
def get_arrival_at(self, stop_id) -> int | None:
|
||||
"""
|
||||
returns the routes stop time at a given stop ID in minutes
|
||||
if not found, returns None
|
||||
:param stop_id: stop ID of arrival station
|
||||
:return: arrival time in minutes
|
||||
"""
|
||||
for stop in self.stops:
|
||||
if stop.id == stop_id:
|
||||
return stop.arrival_minutes()
|
||||
for stop_time_update in self.train_proto.trip_update.stop_time_update:
|
||||
if stop_time_update.stop_id == stop_id:
|
||||
return trip_arrival_in_minutes(stop_time_update)
|
||||
return None
|
||||
|
||||
def arriving_at_station_in_time(self, station_id, max_time):
|
||||
for stop in self.stops:
|
||||
minutes_to_arrival = stop.arrival_minutes()
|
||||
if stop.id == station_id:
|
||||
if minutes_to_arrival > 0 and minutes_to_arrival < max_time:
|
||||
return True
|
||||
|
||||
def _get_route(self) -> str:
|
||||
return self.train_proto.trip_update.trip.route_id
|
||||
def get_route(self) -> Route:
|
||||
return Route(self.train_proto.trip_update.trip.route_id)
|
||||
|
||||
def has_trips(self) -> bool:
|
||||
return self.train_proto.trip_update is not None \
|
||||
and len(self.train_proto.trip_update.stop_time_update) > 0 and is_valid_route(self._get_route())
|
||||
|
||||
def __str__(self):
|
||||
formatted_stops = '\n'.join([str(stop) for stop in self.stops])
|
||||
formatted_stops = '\n'.join([str(stop) for stop in self.stops])
|
||||
return f"train_id:{self.id} | line_name:{self.route}| stops:\n {formatted_stops}"
|
||||
|
||||
@staticmethod
|
||||
def get_train_from_dict(obj):
|
||||
if "trip_update" in obj and "stop_time_update" in obj["trip_update"]:
|
||||
# data we need is here create object
|
||||
id = obj["id"]
|
||||
route = obj["trip_update"]["trip"]["route_id"]
|
||||
all_stops = [Stop.get_stop_from_dict(x) for x in obj["trip_update"]["stop_time_update"]]
|
||||
valid_stops = [valid_stop for valid_stop in all_stops if valid_stop is not None]
|
||||
return Train(id, route, valid_stops)
|
||||
else:
|
||||
return None
|
||||
Reference in New Issue
Block a user