%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib as mpl
import networkx as nx
import numpy as np

from flatland.core.graph.graph_rendering import get_positions, add_flatland_styling
from flatland.core.graph.graph_simplification import DecisionPointGraph
from flatland.core.graph.grid_to_graph import GraphTransitionMap
from flatland.envs.rail_env_utils import env_creator
from flatland.core.grid.rail_env_grid import RailEnvTransitions, RailEnvTransitionsEnum
from flatland.core.transition_map import GridTransitionMap
from flatland.utils.graphics_pil import PILSVG
mpl.rcParams['figure.max_open_warning'] = 0

Flatland Graph Demo

This notebook illustrates the directed graph representation of Flatland.

A Flatland (microscopic) topology can be represented by different kinds of graphs. Whatever the representation, it must reflect the possible paths through the rail network; in particular, it must not be possible to traverse a switch in the acute angle. With the graph, computing the shortest connection from node A to node B becomes a standard graph query, and the API makes such tasks efficient to solve.

Moreover, the graph can be simplified so that only decision-relevant nodes remain and all other nodes are merged. A decision node is a node (a Flatland track cell) at which it is reasonable for the agent to stop, go, or branch off. On straight track within a route, waiting rarely makes sense: an agent that stops before reaching the next decision point (e.g. the cell before a crossing) blocks all the track in between, which is undesirable from an optimization point of view.
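
To make the shortest-path and acute-angle points concrete, here is a minimal sketch in plain Python on a hand-written toy graph (not produced by Flatland). It assumes nodes are (r, c, d) entry pins as described below, with headings encoded 0=N, 1=E, 2=S, 3=W; the layout, the EDGES table and the shortest_path helper are hypothetical. Because the graph only contains edges for feasible moves, a move through the acute angle of a switch simply has no path.

```python
from collections import deque

# Hypothetical toy layout: a switch in cell (1, 1) whose trunk comes from the west
# (cell (1, 0)) and whose branches lead east to (1, 2) and south to (2, 1).
# Nodes are entry pins (r, c, d); d is the heading at cell entry (assumed 0=N, 1=E, 2=S, 3=W).
EDGES = {
    (1, 0, 1): [(1, 1, 1)],             # eastbound into the switch
    (1, 1, 1): [(1, 2, 1), (2, 1, 2)],  # facing move: straight on, or branch south
    (1, 2, 3): [(1, 1, 3)],             # westbound from the east branch
    (1, 1, 3): [(1, 0, 3)],             # trailing move: back onto the trunk
    (2, 1, 0): [(1, 1, 0)],             # northbound from the south branch
    (1, 1, 0): [(1, 0, 3)],             # trailing move: curve onto the trunk, heading west
}


def shortest_path(edges, source, target):
    """Breadth-first search; returns the list of nodes on a shortest path, or None."""
    parents = {source: None}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node == target:
            path = []
            while node is not None:  # walk the parent links back to the source
                path.append(node)
                node = parents[node]
            return path[::-1]
        for nxt in edges.get(node, []):
            if nxt not in parents:
                parents[nxt] = node
                queue.append(nxt)
    return None


print(shortest_path(EDGES, (1, 0, 1), (2, 1, 2)))  # [(1, 0, 1), (1, 1, 1), (2, 1, 2)]
# South branch -> east branch would traverse the switch in the acute angle: no path.
print(shortest_path(EDGES, (2, 1, 0), (1, 2, 1)))  # None
```

On the notebook's NetworkX graph (micro, built further down), the same query would be nx.shortest_path(micro, source, target), which raises nx.NetworkXNoPath when no path exists; this assumes its node keys are (r, c, d) tuples.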

Two (dual, equivalent) approaches are possible:

  • agents are positioned on the nodes

  • agents are positioned on the edges.

The second approach makes it easier to visualize agents moving forward along edges. Hence, we choose the second approach.

Our directed graph consists of nodes and edges:

  • A node in the graph is defined by position and direction. The position is that of the underlying cell in the original Flatland topology, and the direction is the direction in which an agent enters the cell. Thus, a node is a triple (r, c, d), where r (row) is the vertical cell index, c (column) is the horizontal cell index, and d (direction) is the direction of cell entry. In the Flatland 2D grid, not every neighboring cell can be reached from every entry direction, so the entry direction is essential.

  • An edge e = (u, v) goes from a “from-node” u to a “to-node” v. Edges reflect feasible transitions from node u to node v.
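
Why the entry direction matters can be seen on a diamond crossing. In the minimal sketch below (hand-written; Node, SUCCESSORS, pins_of_cell and can_branch are hypothetical names; headings assumed 0=N, 1=E, 2=S, 3=W), the single cell (1, 1) yields four nodes, one per entry heading, and each has a different successor. None of them has more than one successor, so a crossing on its own offers no branching choice.

```python
from typing import NamedTuple


class Node(NamedTuple):
    """Entry pin: cell (r, c) plus the heading d at cell entry (assumed 0=N, 1=E, 2=S, 3=W)."""
    r: int
    c: int
    d: int


# Hand-written successors for a diamond crossing in cell (1, 1):
# every entry heading continues straight on into the neighbouring cell.
SUCCESSORS = {
    Node(1, 1, 0): [Node(0, 1, 0)],  # entered heading N -> leaves to the north
    Node(1, 1, 1): [Node(1, 2, 1)],  # entered heading E -> leaves to the east
    Node(1, 1, 2): [Node(2, 1, 2)],  # entered heading S -> leaves to the south
    Node(1, 1, 3): [Node(1, 0, 3)],  # entered heading W -> leaves to the west
}


def pins_of_cell(r, c):
    """All nodes (entry pins) that belong to grid cell (r, c)."""
    return [n for n in SUCCESSORS if (n.r, n.c) == (r, c)]


def can_branch(node):
    """True if an agent entering through this pin can choose between several exits."""
    return len(SUCCESSORS.get(node, [])) > 1


print(len(pins_of_cell(1, 1)))                 # 4 nodes for one cell
print(any(can_branch(n) for n in SUCCESSORS))  # False
```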

The implementation uses networkX, so there are also many graph functions available.


Create env

env = env_creator()
/Users/che/workspaces/flatland-rl/flatland/envs/rail_generators.py:303: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl/flatland/envs/rail_generators.py:215: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")

Transform to directed graph

The directed graph is an equivalent representation of the grid transition map. It reflects the railway topology, i.e., the paths a train can take.

GraphTransitionMap represents a Flatland 3 transition map as a directed graph.

The grid transition map contains, for every cell, a set of pairs (heading at cell entry, heading at cell exit); e.g. a horizontal straight is {(E,E), (W,W)}. The directed graph’s nodes are entry pins (cell plus heading at entry). Edges always go from the entry pin of one cell to an entry pin of a neighboring cell: the outgoing heading in the grid transition map becomes the incoming heading at the neighboring cell.

Incoming heading:

               S
               ⌄
               |
       E   >---+---< W
               |
               ^
               N

Outgoing heading (=incoming at neighbor cell):

               N (of cell-to-the-north)
               ^
               |
       W   <---+---> E (of cell-to-the-east)
(of cell-to-   |
 the-west)     ⌄
               S (of cell-to-the-south)
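
The construction just described can be sketched in a few lines of plain Python. This is a minimal, hypothetical re-implementation for illustration, not the code of GraphTransitionMap.grid_to_digraph: it assumes headings N=0, E=1, S=2, W=3, that the row index grows towards the south, and it represents the grid as a dictionary mapping each cell to its set of (heading at entry, heading at exit) pairs.

```python
N, E, S, W = range(4)  # assumed heading order
# Row/column offset to the neighbouring cell for each outgoing heading (rows grow southwards).
OFFSET = {N: (-1, 0), E: (0, 1), S: (1, 0), W: (0, -1)}


def grid_to_edges(transitions):
    """Build entry-pin edges from {(r, c): {(heading_in, heading_out), ...}}."""
    edges = set()
    for (r, c), pairs in transitions.items():
        for d_in, d_out in pairs:
            dr, dc = OFFSET[d_out]
            target = (r + dr, c + dc)
            # The outgoing heading becomes the incoming heading at the neighbour;
            # keep the edge only if the neighbour accepts entry with that heading.
            if any(h_in == d_out for h_in, _ in transitions.get(target, ())):
                edges.add(((r, c, d_in), (target[0], target[1], d_out)))
    return edges


# Two horizontal straights side by side, each {(E, E), (W, W)}.
toy = {(0, 0): {(E, E), (W, W)}, (0, 1): {(E, E), (W, W)}}
for u, v in sorted(grid_to_edges(toy)):
    print(u, "->", v)
# (0, 0, 1) -> (0, 1, 1)
# (0, 1, 3) -> (0, 0, 3)
```

In this sketch, transitions that lead off the toy map, or into a cell that cannot be entered with that heading, produce no edge.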
micro = GraphTransitionMap.grid_to_digraph(env.rail)
fig, ax = plt.subplots(1)
# Layer 1: edges into entry pins with exactly one successor (no choice), drawn as plain lines
micro1 = nx.subgraph_view(micro, filter_edge=lambda u, v: len(list(micro.successors(v))) == 1)
nx.draw_networkx(micro1,
                 pos=get_positions(micro1),
                 ax=ax,
                 node_size=2,
                 with_labels=False,
                 arrows=False
                 )
# Layer 2: entry pins with two successors (the agent can choose between two exits), in red
micro2 = nx.subgraph_view(micro, filter_node=lambda v: len(list(micro.successors(v))) == 2)
nx.draw_networkx(micro2,
                 pos=get_positions(micro2),
                 ax=ax,
                 node_size=8,
                 node_color="red",
                 with_labels=False,
                 )
# Layer 3: edges into those two-successor pins, drawn with arrows to show the direction of travel
micro3 = nx.subgraph_view(micro, filter_edge=lambda u, v: len(list(micro.successors(v))) == 2)
nx.draw_networkx(micro3,
                 pos=get_positions(micro3),
                 ax=ax,
                 arrows=True,
                 node_size=3,
                 with_labels=False
                 )

add_flatland_styling(env, ax)

ax.set_title('Railway Topology as Directed Graph')

fig.set_size_inches(15,15)
# fig.savefig('graph_demo.png', dpi=100)
[Figure: Railway Topology as Directed Graph]

Deep Dive: Basic Railway Elements

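Here are the 10 basic railway elements, in the orientations listed in RailEnvTransitionsEnum (30 distinct transition values, including the empty cell), each with its graph equivalent.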


pil = PILSVG(1, 1)
# RailEnvTransitionsEnum contains 30 distinct transition values
assert len(set([e.value for e in RailEnvTransitionsEnum])) == 30
for e in RailEnvTransitionsEnum:
    transition = e.value

    fig, ax = plt.subplots(1)
    # place the element in the centre of a 3 x 3 grid of empty cells,
    # so that neighbour lookups never go to index -1
    empty = RailEnvTransitionsEnum.empty.value
    rail_map = np.array(
        [[empty] * 3] +
        [[empty, transition, empty]] +
        [[empty] * 3], dtype=np.uint16)

    gtm = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=RailEnvTransitions())
    gtm.grid = rail_map

    # cell centres as major ticks
    ax.set_xticks(np.arange(0, 3, 1))
    ax.set_yticks(np.arange(0, 3, 1))

    # rail tile of the element in the centre cell; rotating by 180 degrees and
    # mirroring left/right amounts to a vertical flip (same as np.flipud(img))
    img = pil.pil_rail[transition]
    img = np.fliplr(np.rot90(np.rot90(img)))
    ax.imshow(img, extent=[0.5, 1.5, 0.5, 1.5])

    # cell borders as minor ticks, drawn as grid lines
    ax.set_xticks(np.arange(-0.5, 2.5, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, 2.5, 1), minor=True)
    ax.set_xlim(-0.5, 2.5)
    ax.set_ylim(-0.5, 2.5)
    ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
    ax.grid(which="minor")

    # directed graph of the single element, drawn on top of the tile
    g = GraphTransitionMap.grid_to_digraph(gtm)
    nx.draw_networkx(
        g,
        pos=get_positions(g, delta=0.05),
        ax=ax,
        node_size=10,
        with_labels=False,
        arrows=True
    )
    ax.set_title(e)
[Figures: one plot per RailEnvTransitionsEnum member (30 in total), each titled with the member]
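
Each element drawn above is stored as a 16-bit transition value. The minimal sketch below decodes such a value into the (heading at entry, heading at exit) pairs used by the directed graph. The bit layout is an assumption, following Flatland's Grid4Transitions as we understand it: one 4-bit block per entry heading, ordered N, E, S, W from the most significant bits, and inside each block the exit headings again ordered N, E, S, W. The two sample values are, as far as we know, the straight and simple-switch cases of Flatland's RailEnvTransitions; decode_cell is a hypothetical helper.

```python
N, E, S, W = range(4)  # assumed heading order


def decode_cell(value):
    """Return the set of (heading at entry, heading at exit) pairs of a 16-bit cell value.

    Assumed layout: the 4-bit block for entry heading d_in sits (3 - d_in) * 4 bits
    from the right; inside a block, exit heading d_out is bit (3 - d_out).
    """
    pairs = set()
    for d_in in (N, E, S, W):
        block = (value >> ((3 - d_in) * 4)) & 0xF
        for d_out in (N, E, S, W):
            if block & (1 << (3 - d_out)):
                pairs.add((d_in, d_out))
    return pairs


vertical_straight = int("1000000000100000", 2)
simple_switch = int("1001001000100000", 2)
print(sorted(decode_cell(vertical_straight)))  # [(0, 0), (2, 2)]
print(sorted(decode_cell(simple_switch)))      # [(0, 0), (0, 3), (1, 2), (2, 2)]
```

Read this way, the vertical straight decodes to {(N,N), (S,S)}, the vertical counterpart of the horizontal {(E,E), (W,W)} example, and the switch offers two exits (N or W) only to an agent entering with heading N.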

Simplify Graph to Decision Point Graph

DecisionPointGraph is an overlay on top of the Flatland 3 grid in which consecutive cells where agents cannot choose between alternative paths are collapsed into a single edge. A reference to the underlying grid nodes is maintained. The edge length is the number of cells “collapsed” into this edge.
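
The collapsing idea can be sketched on a plain successor dictionary. This minimal sketch uses a simple rule of its own (an assumption, not necessarily what DecisionPointGraph does): a node is kept when its out-degree or in-degree differs from 1, and every run of other nodes between two kept nodes becomes one edge that remembers the fine nodes it replaces. The length reported here is the number of fine-graph hops; the library's exact counting convention may differ. The function name collapse is hypothetical.

```python
def collapse(successors):
    """Collapse non-branching runs of a directed graph into single edges.

    successors maps each node to the list of its successor nodes.
    Returns a list of (u, v, hops, path), where path lists the fine nodes from u to v.
    A list (not a dict keyed by (u, v)) keeps parallel routes between the same two nodes.
    """
    in_degree = {}
    for vs in successors.values():
        for v in vs:
            in_degree[v] = in_degree.get(v, 0) + 1
    nodes = set(successors) | set(in_degree)

    # Keep branching points, merges, sources and sinks; all other nodes lie inside a corridor.
    kept = {n for n in nodes
            if len(successors.get(n, [])) != 1 or in_degree.get(n, 0) != 1}

    edges = []
    for u in kept:
        for v in successors.get(u, []):
            path = [u, v]
            while v not in kept:  # interior nodes have exactly one successor
                v = successors[v][0]
                path.append(v)
            edges.append((u, v, len(path) - 1, path))
    return edges


# Toy graph: "a" branches into the corridor b -> c -> d and into "x".
graph = {"a": ["b", "x"], "b": ["c"], "c": ["d"], "d": [], "x": []}
for u, v, hops, path in sorted(collapse(graph)):
    print(u, "->", v, hops, path)
# a -> d 3 ['a', 'b', 'c', 'd']
# a -> x 1 ['a', 'x']
```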

gtm = GraphTransitionMap(micro)
decision_point_graph = DecisionPointGraph.fromGraphTransitionMap(gtm)
collapsed = decision_point_graph.g

Comparison

fig, axs = plt.subplots(1, 2)
micro1 = nx.subgraph_view(micro, filter_edge=lambda u, v: len(list(micro.successors(v))) == 1)
nx.draw_networkx(micro1,
                 pos=get_positions(micro1),
                 ax=axs[0],
                 node_size=2,
                 with_labels=False,
                 arrows=False
                 )
micro2 = nx.subgraph_view(micro, filter_node=lambda v: len(list(micro.successors(v))) == 2)
nx.draw_networkx(micro2,
                 pos=get_positions(micro2),
                 ax=axs[0],
                 node_size=8,
                 node_color="red",
                 with_labels=False,
                 )
micro3 = nx.subgraph_view(micro, filter_edge=lambda u, v: len(list(micro.successors(v))) == 2)
nx.draw_networkx(micro3,
                 pos=get_positions(micro3),
                 ax=axs[0],
                 arrows=True,
                 node_size=1,
                 with_labels=False
                 )

nx.draw_networkx(collapsed,
                 pos=get_positions(collapsed),
                 ax=axs[1],
                 node_size=2,
                 with_labels=False
                 )
add_flatland_styling(env, axs[1])
add_flatland_styling(env, axs[0])

axs[0].set_title('micro')
axs[1].set_title('collapsed')

fig.set_size_inches(30,15)
# fig.savefig('graph_demo.png', dpi=100)
[Figure: side-by-side panels titled 'micro' and 'collapsed']