Source code for flatland.envs.agent_chains

import networkx as nx
import numpy as np

from typing import List, Tuple, Set, Union
import graphviz as gv


[docs] class MotionCheck(object): """ Class to find chains of agents which are "colliding" with a stopped agent. This is to allow close-packed chains of agents, ie a train of agents travelling at the same speed with no gaps between them, """ def __init__(self): self.G = nx.DiGraph() self.Grev = nx.DiGraph() # reversed graph for finding predecessors self.nDeadlocks = 0 self.svDeadlocked = set() self._G_reversed: Union[nx.DiGraph, None] = None
[docs] def get_G_reversed(self): #if self._G_reversed is None: # self._G_reversed = self.G.reverse() #return self._G_reversed return self.Grev
[docs] def reset_G_reversed(self): self._G_reversed = None
[docs] def addAgent(self, iAg, rc1, rc2, xlabel=None): """ add an agent and its motion as row,col tuples of current and next position. The agent's current position is given an "agent" attribute recording the agent index. If an agent does not want to move this round (rc1 == rc2) then a self-loop edge is created. xlabel is used for test cases to give a label (see graphviz) """ # Agents which have not yet entered the env have position None. # Substitute this for the row = -1, column = agent index if rc1 is None: rc1 = (-1, iAg) if rc2 is None: rc2 = (-1, iAg) self.G.add_node(rc1, agent=iAg) if xlabel: self.G.nodes[rc1]["xlabel"] = xlabel self.G.add_edge(rc1, rc2) self.Grev.add_edge(rc2, rc1)
[docs] def find_stops(self): """ find all the stopped agents as a set of rc position nodes A stopped agent is a self-loop on a cell node. """ # get the (sparse) adjacency matrix spAdj = nx.linalg.adjacency_matrix(self.G) # the stopped agents appear as 1s on the diagonal # the where turns this into a list of indices of the 1s giStops = np.where(spAdj.diagonal())[0] # convert the cell/node indices into the node rc values lvAll = list(self.G.nodes()) # pick out the stops by their indices lvStops = [lvAll[i] for i in giStops] # make it into a set ready for a set intersection svStops = set(lvStops) return svStops
[docs] def find_stops2(self): """ alternative method to find stopped agents, using a networkx call to find selfloop edges """ svStops = {u for u, v in nx.classes.function.selfloop_edges(self.G)} return svStops
[docs] def find_stop_preds(self, svStops=None): """ Find the predecessors to a list of stopped agents (ie the nodes / vertices) Returns the set of predecessors. Includes "chained" predecessors. """ if svStops is None: svStops = self.find_stops2() # Get all the chains of agents - weakly connected components. # Weakly connected because it's a directed graph and you can traverse a chain of agents # in only one direction lWCC = list(nx.algorithms.components.weakly_connected_components(self.G)) svBlocked = set() reversed_G = None for oWCC in lWCC: if (len(oWCC) == 1): continue # print("Component:", len(oWCC), oWCC) # Get the node details for this WCC in a subgraph Gwcc = self.G.subgraph(oWCC) # Find all the stops in this chain or tree svCompStops = svStops.intersection(Gwcc) # print(svCompStops) if len(svCompStops) > 0: if reversed_G is None: reversed_G = self.get_G_reversed() # We need to traverse it in reverse - back up the movement edges Gwcc_rev = reversed_G.subgraph(oWCC) # Gwcc.reverse() for vStop in svCompStops: # Find all the agents stopped by vStop by following the (reversed) edges # This traverses a tree - dfs = depth first seearch iter_stops = nx.algorithms.traversal.dfs_postorder_nodes(Gwcc_rev, vStop) lStops = list(iter_stops) svBlocked.update(lStops) # the set of all the nodes/agents blocked by this set of stopped nodes return svBlocked
[docs] def find_swaps(self): """ find all the swap conflicts where two agents are trying to exchange places. These appear as simple cycles of length 2. These agents are necessarily deadlocked (since they can't change direction in flatland) - meaning they will now be stuck for the rest of the episode. """ # svStops = self.find_stops2() llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G)) llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2] svSwaps = {v for lvSwap in llvSwaps for v in lvSwap} return svSwaps
[docs] def find_swaps2(self) -> Set[Tuple[int, int]]: svSwaps = set() sEdges = self.G.edges() for u, v in sEdges: if u == v: # print("self loop", u, v) pass else: if (v, u) in sEdges: # print("swap", uv) svSwaps.update([u, v]) return svSwaps
[docs] def find_same_dest(self): """ find groups of agents which are trying to land on the same cell. ie there is a gap of one cell between them and they are both landing on it. """ pass
[docs] def block_preds(self, svStops, color="red"): """ Take a list of stopped agents, and apply a stop color to any chains/trees of agents trying to head toward those cells. Count the number of agents blocked, ignoring those which are already marked. (Otherwise it can double count swaps) """ iCount = 0 svBlocked = set() if len(svStops) == 0: return svBlocked # The reversed graph allows us to follow directed edges to find affected agents. Grev = self.get_G_reversed() for v in svStops: # Use depth-first-search to find a tree of agents heading toward the blocked cell. lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v)) svBlocked |= set(lvPred) svBlocked.add(v) # print("node:", v, "set", svBlocked) # only count those not already marked for v2 in [v] + lvPred: if self.G.nodes[v2].get("color") != color: self.G.nodes[v2]["color"] = color iCount += 1 return svBlocked
[docs] def find_conflicts(self): self.reset_G_reversed() svStops = self.find_stops2() # voluntarily stopped agents - have self-loops # svSwaps = self.find_swaps() # deadlocks - adjacent head-on collisions svSwaps = self.find_swaps2() # faster version of find_swaps # Block all swaps and their tree of predessors self.svDeadlocked = self.block_preds(svSwaps, color="purple") # Take the union of the above, and find all the predecessors # svBlocked = self.find_stop_preds(svStops.union(svSwaps)) # Just look for the tree of preds for each voluntarily stopped agent svBlocked = self.find_stop_preds(svStops) # iterate the nodes v with their predecessors dPred (dict of nodes->{}) for (v, dPred) in self.G.pred.items(): # mark any swaps with purple - these are directly deadlocked # if v in svSwaps: # self.G.nodes[v]["color"] = "purple" # If they are not directly deadlocked, but are in the union of stopped + deadlocked # elif v in svBlocked: # if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting if v in svBlocked: self.G.nodes[v]["color"] = "red" # not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node elif len(dPred) > 1: # if this agent is already red/blocked, ignore. CHECK: why? # certainly we want to ignore purple so we don't overwrite with red. if self.G.nodes[v].get("color") in ("red", "purple"): continue # if this node has no agent, and >=2 want to enter it. if self.G.nodes[v].get("agent") is None: self.G.nodes[v]["color"] = "blue" # this node has an agent and >=2 want to enter else: self.G.nodes[v]["color"] = "magenta" # predecessors of a contended cell: {agent index -> node} diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred} # remove the agent with the lowest index, who wins iAgWinner = min(diAgCell) diAgCell.pop(iAgWinner) # Block all the remaining predessors, and their tree of preds # for iAg, v in diAgCell.items(): # self.G.nodes[v]["color"] = "red" # for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v): # self.G.nodes[vPred]["color"] = "red" self.block_preds(diAgCell.values(), "red")
[docs] def check_motion(self, iAgent, rcPos): """ Returns tuple of boolean can the agent move, and the cell it will move into. If agent position is None, we use a dummy position of (-1, iAgent) """ if rcPos is None: rcPos = (-1, iAgent) dAttr = self.G.nodes.get(rcPos) # print("pos:", rcPos, "dAttr:", dAttr) if dAttr is None: dAttr = {} # If it's been marked red or purple then it can't move if "color" in dAttr: sColor = dAttr["color"] if sColor in ["red", "purple"]: return False dSucc = self.G.succ[rcPos] # This should never happen - only the next cell of an agent has no successor if len(dSucc) == 0: print(f"error condition - agent {iAgent} node {rcPos} has no successor") return False # This agent has a successor rcNext = self.G.successors(rcPos).__next__() if rcNext == rcPos: # the agent didn't want to move return False # The agent wanted to move, and it can return True
[docs] def render(omc: MotionCheck, horizontal=True): try: oAG = nx.drawing.nx_agraph.to_agraph(omc.G) oAG.layout("dot") sDot = oAG.to_string() if horizontal: sDot = sDot.replace('{', '{ rankdir="LR" ') # return oAG.draw(format="png") # This returns a graphviz object which implements __repr_svg return gv.Source(sDot) except ImportError as oError: print("Flatland agent_chains ignoring ImportError - install pygraphviz to render graphs") return None
[docs] class ChainTestEnv(object): """ Just for testing agent chains """ def __init__(self, omc: MotionCheck): self.iAgNext = 0 self.iRowNext = 1 self.omc = omc
[docs] def addAgent(self, rc1, rc2, xlabel=None): self.omc.addAgent(self.iAgNext, rc1, rc2, xlabel=xlabel) self.iAgNext += 1
[docs] def addAgentToRow(self, c1, c2, xlabel=None): self.addAgent((self.iRowNext, c1), (self.iRowNext, c2), xlabel=xlabel)
[docs] def create_test_chain(self, nAgents: int, rcVel: Tuple[int] = (0, 1), liStopped: List[int] = [], xlabel=None): """ create a chain of agents """ lrcAgPos = [(self.iRowNext, i * rcVel[1]) for i in range(nAgents)] for iAg, rcPos in zip(range(nAgents), lrcAgPos): if iAg in liStopped: rcVel1 = (0, 0) else: rcVel1 = rcVel self.omc.addAgent(iAg + self.iAgNext, rcPos, (rcPos[0] + rcVel1[0], rcPos[1] + rcVel1[1])) if xlabel: self.omc.G.nodes[lrcAgPos[0]]["xlabel"] = xlabel self.iAgNext += nAgents self.iRowNext += 1
[docs] def nextRow(self): self.iRowNext += 1
[docs] def create_test_agents(omc: MotionCheck): # blocked chain omc.addAgent(1, (1, 2), (1, 3)) omc.addAgent(2, (1, 3), (1, 4)) omc.addAgent(3, (1, 4), (1, 5)) omc.addAgent(31, (1, 5), (1, 5)) # unblocked chain omc.addAgent(4, (2, 1), (2, 2)) omc.addAgent(5, (2, 2), (2, 3)) # blocked short chain omc.addAgent(6, (3, 1), (3, 2)) omc.addAgent(7, (3, 2), (3, 2)) # solitary agent omc.addAgent(8, (4, 1), (4, 2)) # solitary stopped agent omc.addAgent(9, (5, 1), (5, 1)) # blocked short chain (opposite direction) omc.addAgent(10, (6, 4), (6, 3)) omc.addAgent(11, (6, 3), (6, 3)) # swap conflict omc.addAgent(12, (7, 1), (7, 2)) omc.addAgent(13, (7, 2), (7, 1))
[docs] def create_test_agents2(omc: MotionCheck): # blocked chain cte = ChainTestEnv(omc) cte.create_test_chain(4, liStopped=[3], xlabel="stopped\nchain") cte.create_test_chain(4, xlabel="running\nchain") cte.create_test_chain(2, liStopped=[1], xlabel="stopped \nshort\n chain") cte.addAgentToRow(1, 2, "swap") cte.addAgentToRow(2, 1) cte.nextRow() cte.addAgentToRow(1, 2, "chain\nswap") cte.addAgentToRow(2, 3) cte.addAgentToRow(3, 2) cte.nextRow() cte.addAgentToRow(1, 2, "midchain\nstop") cte.addAgentToRow(2, 3) cte.addAgentToRow(3, 4) cte.addAgentToRow(4, 4) cte.addAgentToRow(5, 6) cte.addAgentToRow(6, 7) cte.nextRow() cte.addAgentToRow(1, 2, "midchain\nswap") cte.addAgentToRow(2, 3) cte.addAgentToRow(3, 4) cte.addAgentToRow(4, 3) cte.addAgentToRow(5, 4) cte.addAgentToRow(6, 5) cte.nextRow() cte.addAgentToRow(1, 2, "Land on\nSame") cte.addAgentToRow(3, 2) cte.nextRow() cte.addAgentToRow(1, 2, "chains\nonto\nsame") cte.addAgentToRow(2, 3) cte.addAgentToRow(3, 4) cte.addAgentToRow(5, 4) cte.addAgentToRow(6, 5) cte.addAgentToRow(7, 6) cte.nextRow() cte.addAgentToRow(1, 2, "3-way\nsame") cte.addAgentToRow(3, 2) cte.addAgent((cte.iRowNext + 1, 2), (cte.iRowNext, 2)) cte.nextRow() if False: cte.nextRow() cte.nextRow() cte.addAgentToRow(1, 2, "4-way\nsame") cte.addAgentToRow(3, 2) cte.addAgent((cte.iRowNext + 1, 2), (cte.iRowNext, 2)) cte.addAgent((cte.iRowNext - 1, 2), (cte.iRowNext, 2)) cte.nextRow() cte.nextRow() cte.addAgentToRow(1, 2, "Tee") cte.addAgentToRow(2, 3) cte.addAgentToRow(3, 4) cte.addAgent((cte.iRowNext + 1, 3), (cte.iRowNext, 3)) cte.nextRow() cte.nextRow() cte.addAgentToRow(1, 2, "Tree") cte.addAgentToRow(2, 3) cte.addAgentToRow(3, 4) r1 = cte.iRowNext r2 = cte.iRowNext + 1 r3 = cte.iRowNext + 2 cte.addAgent((r2, 3), (r1, 3)) cte.addAgent((r2, 2), (r2, 3)) cte.addAgent((r3, 2), (r2, 3)) cte.nextRow()
[docs] def test_agent_following(): omc = MotionCheck() create_test_agents2(omc) svStops = omc.find_stops() svBlocked = omc.find_stop_preds() llvSwaps = omc.find_swaps() svSwaps = {v for lvSwap in llvSwaps for v in lvSwap} print(list(svBlocked)) lvCells = omc.G.nodes() lColours = ["magenta" if v in svStops else "red" if v in svBlocked else "purple" if v in svSwaps else "lightblue" for v in lvCells] dPos = dict(zip(lvCells, lvCells)) nx.draw(omc.G, with_labels=True, arrowsize=20, pos=dPos, node_color=lColours)
[docs] def main(): test_agent_following()
if __name__ == "__main__": main()