Source code for flatland.utils.jupyter_utils
from typing import List, NamedTuple
import numpy as np
from IPython import display
from ipycanvas import canvas
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool
[docs]
class Behaviour:
    """Base policy that supplies per-agent actions for an environment.

    This base implementation hands out no actions at all; subclasses
    override ``getActions`` to return something more useful.
    """

    def __init__(self, env):
        """Remember the environment and the number of agents it holds.

        Args:
            env: environment object exposing an ``agents`` sequence.
        """
        self.env = env
        self.nAg = len(self.env.agents)

    def getActions(self):
        """Return an empty action dict (no agent is given an action)."""
        return dict()
[docs]
class AlwaysForward(Behaviour):
    """Behaviour in which every agent is told to move forward on each call."""

    def getActions(self):
        """Return ``{agent index: MOVE_FORWARD}`` for all ``self.nAg`` agents."""
        return dict.fromkeys(range(self.nAg), RailEnvActions.MOVE_FORWARD)
[docs]
class DelayedStartForward(AlwaysForward):
    """Move-forward behaviour in which agents join in one after another.

    On each call only the lowest-indexed agents receive ``MOVE_FORWARD``;
    how many is ``(elapsed steps + 1) // nStartDelay``, capped at the
    number of agents. The remaining agents get no entry in the action dict.
    """

    def __init__(self, env, nStartDelay=2):
        """Store the environment and the start delay.

        Args:
            env: environment handed on to the base class.
            nStartDelay: number of steps between successive agents being
                added to the moving set.
        """
        super().__init__(env)
        self.nStartDelay = nStartDelay

    def getActions(self):
        """Return ``MOVE_FORWARD`` for the agents that are moving so far."""
        # One more than the steps the environment reports as elapsed.
        upcoming_step = self.env._elapsed_steps + 1
        released = upcoming_step // self.nStartDelay
        n_moving = min(self.nAg, released)
        return dict.fromkeys(range(n_moving), RailEnvActions.MOVE_FORWARD)
class AgentPause(NamedTuple):
    """One scheduled pause, as consumed by ``ForwardWithPause``."""

    # Index of the agent to pause.
    iAg: int
    # Step number at which the pause starts.
    iPauseAt: int
    # Length of the pause.
    iPauseFor: int
[docs]
class ForwardWithPause(Behaviour):
    """Move every agent forward, except while a scheduled pause is active.

    Each ``AgentPause`` names an agent, the step at which its pause starts
    and the pause length. While a pause is active the agent receives
    ``STOP_MOVING`` instead of ``MOVE_FORWARD``.
    """

    def __init__(self, env, lPauses: List[AgentPause]):
        """Store the environment and the pause schedule.

        Args:
            env: environment handed on to the base class, which records
                it together with the agent count.
            lPauses: pauses to apply; the list is kept by reference.
        """
        # Let the base class set env / nAg rather than duplicating it here.
        super().__init__(env)
        self.lPauses = lPauses
        # Maps agent index -> number of calls the agent remains paused for.
        self.dAgPaused = {}

    def getActions(self):
        """Return the action dict for the upcoming step.

        Returns:
            dict mapping each agent index to ``MOVE_FORWARD``, overridden
            with ``STOP_MOVING`` for every agent that is currently paused.
        """
        iStep = self.env._elapsed_steps + 1  # add one because this is called before step()

        # Activate the pauses that begin on this step.
        for tPause in self.lPauses:
            if tPause.iPauseAt == iStep:
                self.dAgPaused[tPause.iAg] = tPause.iPauseFor

        # Default action is to move forward.
        dAction = dict.fromkeys(range(self.nAg), RailEnvActions.MOVE_FORWARD)

        # Stop every paused agent, count its pause down and drop it once the
        # counter is used up. Iterate over a snapshot of the keys so entries
        # can be deleted from the dict inside the loop.
        for iAg in list(self.dAgPaused):
            dAction[iAg] = RailEnvActions.STOP_MOVING
            self.dAgPaused[iAg] -= 1
            if self.dAgPaused[iAg] <= 0:
                del self.dAgPaused[iAg]

        return dAction
[docs]
class Deterministic(Behaviour):
    """Replay a fixed, pre-defined list of actions for each agent."""

    def __init__(self, env, dAg_lActions):
        """Store the environment and the scripted actions.

        Args:
            env: environment handed on to the base class.
            dAg_lActions: mapping from agent index to that agent's list of
                actions, indexed by the environment's elapsed step count.
        """
        super().__init__(env)
        self.dAg_lActions = dAg_lActions

    def getActions(self):
        """Return each scripted agent's action for the current step.

        Agents whose list is shorter than the current step index receive
        ``DO_NOTHING``. Only agents present in ``dAg_lActions`` appear in
        the result.
        """
        step_index = self.env._elapsed_steps
        return {
            iAg: (lActions[step_index]
                  if step_index < len(lActions)
                  else RailEnvActions.DO_NOTHING)
            for iAg, lActions in self.dAg_lActions.items()
        }
[docs]
class EnvCanvas():
    """Render an environment onto an ipycanvas canvas and step it.

    Combines the environment, a ``Behaviour`` that supplies the actions,
    a ``RenderTool`` that draws the environment and the canvas the image
    is copied onto.
    """

    def __init__(self, env, behaviour: Behaviour = None):
        """Create the renderer and canvas, then draw the initial state.

        Args:
            env: environment to render and step.
            behaviour: source of actions for ``step``; when ``None`` an
                ``AlwaysForward`` behaviour is created for ``env``.
        """
        self.env = env
        # Initialised to 0; no method of this class updates it.
        self.iStep = 0
        if behaviour is None:
            behaviour = AlwaysForward(env)
        self.behaviour = behaviour
        self.oRT = RenderTool(env, show_debug=True)
        self.oCan = canvas.Canvas(size=(600, 300))
        self.render()

    def render(self):
        """Redraw the environment and copy the image onto the canvas."""
        self.oRT.render_env(show_rowcols=True, show_inactive_agents=False, show_observations=False)
        gIm = self.oRT.get_image()
        # Keep the first three channels in their original order, dropping
        # any further channel (presumably alpha -- confirm against
        # RenderTool.get_image). The copy yields a fresh contiguous array.
        image_data = gIm[:, :, :3].copy()
        self.oCan.put_image_data(image_data)

    def step(self):
        """Ask the behaviour for actions and advance the environment once."""
        dAction = self.behaviour.getActions()
        self.env.step(dAction)

    def show(self):
        """Re-render the environment and display the canvas."""
        self.render()
        display.display(self.oCan)