Source code for flatland.envs.graph.graph_simplification

from typing import List, Tuple

import networkx as nx
from attr import attrs, attrib

from flatland.envs.graph.rail_graph_transition_map import GraphTransitionMap

GridNode = Tuple[int, int, int]  # row, column, heading (at cell entry)


[docs] @attrs class DecisionPointGraphEdgeData: """ The edge data of a decision point overlay graph. Attributes ---------- path: List[GridNode] The list of collapsed cells, starting on a facing switch (e.g. where a decision has been taken upon entering). len: int Number of collapsed cells. """ path = attrib(type=List[GridNode]) len = attrib(type=int)
[docs] class DecisionPointGraph: """ Overlay on top of Flatland 3 grid where consecutive cells where agents cannot choose between alternative paths are collapsed into a single edge. A reference to the underlying grid nodes is maintained. The edge length is the number of cells "collapsed" into this edge. See `DecisionPointGraphEdgeData`. Attributes ---------- g: nx.MultiDiGraph The decision point graph. Nodes have type `GridNode` and edge data has an attributed `d` of type `DecisionPointGraphEdgeData`. """ def __init__(self, g: nx.MultiDiGraph): self.g = g @staticmethod def _explore_branch(g: nx.DiGraph, u: GridNode, v: GridNode) -> List[GridNode]: branch = [u, v] successors = list(g.successors(v)) assert len(successors) > 0 while len(successors) == 1: successor = successors[0] branch.append(successor) if successor == u: # loop break successors = list(g.successors(successor)) assert len(successors) > 0 return branch
[docs] @staticmethod def fromGraphTransitionMap(gtm: GraphTransitionMap) -> "DecisionPointGraph": """ Factory method to derive `DecisionPointGraph` from `GraphTransitionMap`. Parameters ---------- gtm: GraphTransitionMap The Flatland 3 graph transition map. Returns ------- The overlaid decision point graph. """ # multiple paths can be between consecutive decision points g = nx.MultiDiGraph() # find decision points (ie. nodes with more than one successor (=neighbor in the directed graph)) micro = gtm.g decision_nodes = {s for s in micro.nodes if len(list(micro.successors(s))) > 1} # add edge starting at each decision point closed = set() for dp in decision_nodes: for n in micro.successors(dp): branch = DecisionPointGraph._explore_branch(micro, dp, n) g.add_edge(branch[0], branch[-1], d=DecisionPointGraphEdgeData(path=branch, len=len(branch))) for u, v, in zip(branch, branch[1:]): closed.add((u, v)) # special cases closed loops open = set(micro.edges) - closed while not len(open) == 0: u, v = next(iter(open)) branch = DecisionPointGraph._explore_branch(micro, u, v) g.add_edge(branch[0], branch[-1], d=DecisionPointGraphEdgeData(path=branch, len=len(branch))) for u_, v_, in zip(branch, branch[1:]): open.discard((u_, v_)) return DecisionPointGraph(g)