# Source code for flatland.integrations.interactiveai.interactiveai

"""
FlatlandCallbacks for InteractiveAI (https://github.com/AI4REALNET/InteractiveAI).

The callbacks create context and events during a scenario run.
If an InteractiveAI instance is up and running, the callbacks send out HTTP POST requests to the InteractiveAI contexts and events REST API endpoints.
With `collect_only=True`, no requests are sent; the contexts and events that would be sent out are only collected in the callbacks object.

- The agent positions are sent as context, with geo-coordinates for display on a map.
- Agent malfunctions are sent as events.
"""
import datetime
import logging
import time
import warnings
from pathlib import Path
from typing import Optional, Dict, Tuple

import requests

from flatland.callbacks.callbacks import FlatlandCallbacks
from flatland.envs.rail_env import RailEnv
from flatland.integrations.interactiveai import event_api, context_api, historic_api
from flatland.integrations.interactiveai.context_api import ContextApiApi, ContextIn
from flatland.integrations.interactiveai.event_api import EventApiApi
from flatland.integrations.interactiveai.event_api import EventIn
from flatland.integrations.interactiveai.historic_api import HistoricApiApi

logger = logging.getLogger(__name__)


class FlatlandInteractiveAICallbacks(FlatlandCallbacks):
    """Flatland callbacks that publish agent positions and malfunctions to InteractiveAI.

    After every `env.step()`, the positions of all agents that are on the map and have a
    geo-coordinate mapping are turned into one context, and every malfunction that newly
    started is turned into one event. Contexts and events are always collected in
    `self.contexts` and `self.events`; unless `collect_only` is set, they are additionally
    POSTed to the InteractiveAI context and event REST APIs.
    """

    # Upper bound (seconds) for the token request, so that an unreachable auth server
    # cannot block the simulation loop forever.
    TOKEN_REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self,
                 coordinate_map: Dict[Tuple[int, int], Tuple[float, float]],
                 client_id: str = "opfab-client",
                 username: str = "admin",
                 password: str = "test",
                 token_url: str = "http://frontend/auth/token",
                 event_api_host: str = "http://localhost:5001",
                 context_api_host: str = "http://localhost:5100",
                 history_api_host: str = "http://localhost:5200",
                 step_to_millis: int = 1000,
                 collect_only: bool = False,
                 now: Optional[datetime.datetime] = None,
                 ):
        """
        Parameters
        ----------
        coordinate_map : Dict[Tuple[int, int], Tuple[float, float]]
            map (row,column) to (lat,lon)
        client_id : str
            InteractiveAI client_id
        username : str
            InteractiveAI username
        password : str
            InteractiveAI password
        token_url : str
            InteractiveAI token_url
        event_api_host : str
            Defaults to "http://localhost:5001"
        context_api_host : str
            Defaults to "http://localhost:5100"
        history_api_host : str
            Defaults to "http://localhost:5200"
        step_to_millis : int
            how many millis delay till next `env.step()`
        collect_only : bool
            Do not send out requests but collect the events and context dicts in the callbacks object.
        now : datetime.datetime
            Real-world time for `env._elapsed_steps == 0`. Defaults to `datetime.datetime.now()`.
        """
        # Credentials and endpoint used by `connect()` to obtain the bearer token.
        self.token_url = token_url
        self.client_id = client_id
        self.username = username
        self.password = password
        # Stays None until `connect()` succeeds; `on_episode_step` uses this to connect lazily.
        self.access_token = None

        self.event_api_host = event_api_host
        self.context_api_host = context_api_host
        self.history_api_host = history_api_host
        # API clients are created by `connect()`.
        self.context_api: Optional[ContextApiApi] = None
        self.historic_api: Optional[HistoricApiApi] = None
        self.events_api: Optional[EventApiApi] = None

        # Last seen malfunction state per agent handle, used to detect the start of a malfunction.
        self.in_malfunction = {}
        self.collect_only = collect_only
        self.step_to_millis = step_to_millis
        # All event / context dicts produced so far, whether or not they were sent.
        self.events = []
        self.contexts = []
        self.coordinate_map = coordinate_map
        self.now = now
        if self.now is None:
            self.now = datetime.datetime.now()

    def connect(self):
        """Fetch an access token and create the event, context and historic API clients.

        Every client gets the token as default `Authorization: Bearer ...` header.
        `self.access_token` is only set once all clients have been created.
        """
        access_token = self._get_access_token_password_grant(
            token_url=self.token_url,
            client_id=self.client_id,
            username=self.username,
            password=self.password,
        )
        # The bearer token is a credential: do not print or log its value.
        logger.debug("Obtained InteractiveAI access token from %s", self.token_url)
        authorization = f"Bearer {access_token}"

        self.events_api = EventApiApi(event_api.ApiClient(event_api.Configuration(host=self.event_api_host)))
        self.events_api.api_client.set_default_header("Authorization", authorization)

        self.context_api = ContextApiApi(context_api.ApiClient(context_api.Configuration(host=self.context_api_host)))
        self.context_api.api_client.set_default_header("Authorization", authorization)

        self.historic_api = HistoricApiApi(historic_api.ApiClient(historic_api.Configuration(host=self.history_api_host)))
        self.historic_api.api_client.set_default_header("Authorization", authorization)

        self.access_token = access_token

    def _get_access_token_password_grant(self, token_url, client_id, username, password) -> str:
        """Request an access token from `token_url` with a password grant.

        Parameters
        ----------
        token_url : str
            URL the form-encoded POST request is sent to.
        client_id : str
            sent as form field `clientId`
        username : str
            sent as form field `username`
        password : str
            sent as form field `password`

        Returns
        -------
        str
            The `access_token` field of the JSON response.

        Raises
        ------
        requests.exceptions.RequestException
            If the request fails, times out or returns an HTTP error status.
        ValueError
            If the response does not contain an `access_token`.
        """
        # Pass the fields as a dict so that `requests` form-encodes them; interpolating them
        # into a string breaks for values containing characters such as `&`, `=`, `+` or `%`.
        payload = {
            "username": username,
            "password": password,
            "grant_type": "password",
            "clientId": client_id,
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        response = requests.post(token_url, headers=headers, data=payload, timeout=self.TOKEN_REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        access_token = response.json().get("access_token")
        if not access_token:
            # Fail here instead of sending "Bearer None" with every later request.
            raise ValueError(f"No access_token in response from {token_url}")
        return access_token

    def on_episode_step(
        self,
        *,
        env: Optional[RailEnv] = None,
        data_dir: Path = None,
        **kwargs,
    ) -> None:
        """Called on each episode step (after the action(s) has/have been logged).

        This callback is also called after the final step of an episode,
        meaning when terminated/truncated are returned as True
        from the `env.step()` call.

        The exact time of the call of this callback is after `env.step([action])` and
        also after the results of this step (observation, reward, terminated, truncated,
        infos) have been logged to the given `episode` object.

        For every agent whose malfunction started since the previous call, one event is
        appended to `self.events`. One context with the geo-coordinates of all agents that
        have a position and a mapping in `coordinate_map` is appended to `self.contexts`;
        a warning is issued for positions without mapping. Unless `collect_only` is set,
        events and context are also posted to InteractiveAI (connecting first if necessary),
        and the call then sleeps for `step_to_millis` milliseconds.
        """
        if self.access_token is None and not self.collect_only:
            self.connect()

        for agent in env.agents:
            was_in_malfunction = self.in_malfunction.get(agent.handle, False)
            is_in_malfunction = agent.malfunction_handler.in_malfunction
            self.in_malfunction[agent.handle] = is_in_malfunction
            # Only the transition "not in malfunction" -> "in malfunction" produces an event.
            if was_in_malfunction or not is_in_malfunction:
                continue
            remaining_millis = agent.malfunction_handler.malfunction_down_counter * self.step_to_millis
            event = {
                "use_case": "Railway",
                "criticality": "LOW",
                "title": f"Malfunction Train {agent.handle}",
                "description": f"Malfunction Train {agent.handle}",
                # NOTE(review): `self.now` is never advanced in this class, so every event starts at
                # the same timestamp regardless of the step it occurs in - confirm whether the start
                # should be offset by the elapsed steps.
                "start_date": self.now,
                "end_date": self.now + datetime.timedelta(milliseconds=remaining_millis),
                "data": {
                    "event_type": "Malfunction",
                    # NOTE(review): fixed value, not derived from the malfunction duration.
                    "delay": 300,
                    "agent_id": f"{agent.handle}",
                    "id_train": f"{agent.handle}",
                }
            }
            self.events.append(event)
            if not self.collect_only:
                self.events_api.api_v1_events_post(EventIn.from_dict(event))  # TODO use non-blocking calls or queue?

        trains = []
        for agent in env.agents:
            if agent.position is None:
                # Agent is not on the map, nothing to display.
                continue
            if agent.position not in self.coordinate_map:
                # Warn and skip instead of failing with a KeyError on the lookup below.
                warnings.warn(f"Missing mapping for {agent.position}")
                continue
            coordinates = self.coordinate_map[agent.position]
            trains.append({
                'id_train': f'Train {agent.handle}',
                'latitude': f"{coordinates[0]}",
                'longitude': f'{coordinates[1]}',
            })
        context = {
            "use_case": "Railway",
            "data": {
                "trains": trains
            }
        }
        self.contexts.append(context)
        if not self.collect_only:
            self.context_api.api_v1_contexts_post(ContextIn.from_dict(context))
            # Pace the run only when the data is actually sent out.
            time.sleep(self.step_to_millis * 0.001)